1 #ifndef mitkImpurityLoss_cpp
2 #define mitkImpurityLoss_cpp
6 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
9 vigra::ProblemSpec<T>
const &ext) :
10 m_UsePointWeights(false),
12 m_Counts(ext.class_count_, 0.0),
13 m_ClassWeights(ext.class_weights_),
18 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
26 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
27 template <
class TDataIterator>
31 for (TDataIterator iter = begin; iter != end; ++iter)
33 double pointProbability = 1.0;
34 if (m_UsePointWeights)
36 pointProbability = m_PointWeights(*iter,0);
38 m_Counts[m_Labels(*iter,0)] += pointProbability;
39 m_TotalCount += pointProbability;
41 return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
44 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
45 template <
class TDataIterator>
49 for (TDataIterator iter = begin; iter != end; ++iter)
51 double pointProbability = 1.0;
52 if (m_UsePointWeights)
54 pointProbability = m_PointWeights(*iter,0);
56 m_Counts[m_Labels(*iter,0)] -= pointProbability;
57 m_TotalCount -= pointProbability;
59 return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
62 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
63 template <
class TArray>
68 std::copy(initCounts.begin(), initCounts.end(), m_Counts.begin());
69 m_TotalCount = std::accumulate(m_Counts.begin(), m_Counts.end(), 0.0);
70 return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
73 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
74 vigra::ArrayVector<double>
const&
80 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
84 m_UsePointWeights = useWeights;
87 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
91 return m_UsePointWeights;
94 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
98 m_PointWeights = weight;
101 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
105 return m_PointWeights;
109 #endif // mitkImpurityLoss_cpp
void UsePointWeights(bool useWeights)
void SetPointWeights(TWeightContainer weight)
double Decrement(TDataIterator begin, TDataIterator end)
vigra::ArrayVector< double > const & Response()
ImpurityLoss(TLabelContainer const &labels, vigra::ProblemSpec< T > const &ext)
double Increment(TDataIterator begin, TDataIterator end)
WeightContainerType GetPointWeights()
double Init(TArray initCounts)
bool IsUsingPointWeights()
TWeightContainer WeightContainerType