13 #ifndef mitkPUImpurityLoss_cpp 14 #define mitkPUImpurityLoss_cpp 19 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
22 vigra::ProblemSpec<T>
const &ext,
24 m_UsePointWeights(false),
27 m_Counts(ext.class_count_, 0.0),
28 m_PUCounts(ext.class_count_, 0.0),
29 m_ClassWeights(ext.class_weights_),
32 m_ClassCount(ext.class_count_)
36 m_Kappa = vigra::ArrayVector<double>(purfdata->
m_Kappa);
39 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
47 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
52 for (
int i = 1; i < m_ClassCount; ++i)
54 m_PUCounts[i] = m_Kappa[i] * m_Counts[i];
55 m_PUTotalCount += m_PUCounts[i];
57 m_PUCounts[0] =
std::max(0.0, m_TotalCount - m_PUTotalCount);
58 m_PUTotalCount += m_PUCounts[0];
61 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
62 template <
class TDataIterator>
66 for (TDataIterator iter = begin; iter != end; ++iter)
68 double pointProbability = 1.0;
69 if (m_UsePointWeights)
71 pointProbability = m_PointWeights(*iter,0);
73 m_Counts[m_Labels(*iter,0)] += pointProbability;
74 m_TotalCount += pointProbability;
77 return m_LossFunction(m_PUCounts, m_ClassWeights, m_PUTotalCount);
80 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
81 template <
class TDataIterator>
85 for (TDataIterator iter = begin; iter != end; ++iter)
87 double pointProbability = 1.0;
88 if (m_UsePointWeights)
90 pointProbability = m_PointWeights(*iter,0);
92 m_Counts[m_Labels(*iter,0)] -= pointProbability;
93 m_TotalCount -= pointProbability;
96 return m_LossFunction(m_PUCounts, m_ClassWeights, m_PUTotalCount);
99 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
100 template <
class TArray>
105 std::copy(initCounts.begin(), initCounts.end(), m_Counts.begin());
106 m_TotalCount = std::accumulate(m_Counts.begin(), m_Counts.end(), 0.0);
107 return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
110 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
111 vigra::ArrayVector<double>
const&
117 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
121 m_UsePointWeights = useWeights;
124 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
128 return m_UsePointWeights;
131 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
135 m_PointWeights = weight;
138 template <
class TLossFunction,
class TLabelContainer,
class TWeightContainer>
142 return m_PointWeights;
146 #endif // mitkPUImpurityLoss_cpp vigra::ArrayVector< double > const & Response()
double Decrement(TDataIterator begin, TDataIterator end)
WeightContainerType GetPointWeights()
double Init(TArray initCounts)
bool IsUsingPointWeights()
TWeightContainer WeightContainerType
double Increment(TDataIterator begin, TDataIterator end)
void SetPointWeights(TWeightContainer weight)
void UsePointWeights(bool useWeights)
vigra::ArrayVector< double > m_Kappa
PUImpurityLoss(TLabelContainer const &labels, vigra::ProblemSpec< T > const &ext, AdditionalRFDataAbstract *data)