13 #ifndef mitkImpurityLoss_cpp 14 #define mitkImpurityLoss_cpp 18 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
21 vigra::ProblemSpec<T>
const &ext,
23 m_UsePointWeights(false),
25 m_Counts(ext.class_count_, 0.0),
26 m_ClassWeights(ext.class_weights_),
31 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
39 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
40 template <
class TDataIterator>
44 for (TDataIterator iter = begin; iter != end; ++iter)
46 double pointProbability = 1.0;
47 if (m_UsePointWeights)
49 pointProbability = m_PointWeights(*iter,0);
51 m_Counts[m_Labels(*iter,0)] += pointProbability;
52 m_TotalCount += pointProbability;
54 return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
57 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
58 template <
class TDataIterator>
62 for (TDataIterator iter = begin; iter != end; ++iter)
64 double pointProbability = 1.0;
65 if (m_UsePointWeights)
67 pointProbability = m_PointWeights(*iter,0);
69 m_Counts[m_Labels(*iter,0)] -= pointProbability;
70 m_TotalCount -= pointProbability;
72 return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
75 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
76 template <
class TArray>
81 std::copy(initCounts.begin(), initCounts.end(), m_Counts.begin());
82 m_TotalCount = std::accumulate(m_Counts.begin(), m_Counts.end(), 0.0);
83 return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
86 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
87 vigra::ArrayVector<double>
const&
93 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
97 m_UsePointWeights = useWeights;
100 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
104 return m_UsePointWeights;
107 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
111 m_PointWeights = weight;
114 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
118 return m_PointWeights;
122 #endif // mitkImpurityLoss_cpp void UsePointWeights(bool useWeights)
vigra::ArrayVector< double > const & Response()
void SetPointWeights(TWeightContainer weight)
double Decrement(TDataIterator begin, TDataIterator end)
double Increment(TDataIterator begin, TDataIterator end)
WeightContainerType GetPointWeights()
double Init(TArray initCounts)
bool IsUsingPointWeights()
TWeightContainer WeightContainerType
ImpurityLoss(TLabelContainer const &labels, vigra::ProblemSpec< T > const &ext, AdditionalRFDataAbstract *data)