Medical Imaging Interaction Toolkit  2018.4.99-dfa0c14e
Medical Imaging Interaction Toolkit
mitkPUImpurityLoss.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 #ifndef mitkPUImpurityLoss_cpp
14 #define mitkPUImpurityLoss_cpp
15 
16 #include <mitkPUImpurityLoss.h>
17 #include <mitkAdditionalRFData.h>
18 
// Constructor. The declarator lines (listing lines 21 and 23) are not visible
// in this view.
// The initialiser list stores the label container, builds m_Counts and
// m_PUCounts from (ext.class_count_, 0.0), takes the class weights and the
// class count from ext, zeroes both running totals and starts with point
// weighting switched off.
19 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
20 template <class T>
22  vigra::ProblemSpec<T> const &ext,
24  m_UsePointWeights(false),
25  m_Labels(labels),
26  //m_Kappa(ext.kappa_), // Not possible due to data type
27  m_Counts(ext.class_count_, 0.0),
28  m_PUCounts(ext.class_count_, 0.0),
29  m_ClassWeights(ext.class_weights_),
30  m_TotalCount(0.0),
31  m_PUTotalCount(0.0),
32  m_ClassCount(ext.class_count_)
33 {
// NOTE(review): the result of this dynamic_cast is dereferenced two lines
// below without a null check; presumably callers always hand in a PURFData --
// confirm, otherwise a failed cast is a null dereference.
34  mitk::PURFData * purfdata = dynamic_cast<PURFData *> (data);
35  //const PURFProblemSpec<T> *problem = static_cast<const PURFProblemSpec<T> * > (&ext);
// m_Kappa is copied from the additional data object here because, per the
// commented-out initialiser above, it could not be taken from ext.
36  m_Kappa = vigra::ArrayVector<double>(purfdata->m_Kappa);
37 }
38 
// Clears the accumulated per-class counts and their running total.
// The declarator line (listing line 41) is not visible in this view.
// Only m_Counts and m_TotalCount are reset here; m_PUCounts and
// m_PUTotalCount are left untouched by the visible lines.
39 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
40 void
42 {
43  m_Counts.init(0);
44  m_TotalCount = 0.0;
45 }
46 
// Recomputes m_PUCounts and m_PUTotalCount from the current m_Counts.
// The declarator line (listing line 49) is not visible in this view.
47 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
48 void
50 {
51  m_PUTotalCount = 0;
// Classes 1 .. m_ClassCount-1: the PU count is the raw count scaled by the
// per-class factor m_Kappa[i]; class 0 is handled separately below.
52  for (int i = 1; i < m_ClassCount; ++i)
53  {
54  m_PUCounts[i] = m_Kappa[i] * m_Counts[i];
55  m_PUTotalCount += m_PUCounts[i];
56  }
// Class 0 receives whatever remains of the raw total after the scaled counts
// of the other classes are subtracted, clamped so it never goes negative.
57  m_PUCounts[0] = std::max(0.0, m_TotalCount - m_PUTotalCount);
// After this line m_PUTotalCount is the sum over all entries of m_PUCounts.
58  m_PUTotalCount += m_PUCounts[0];
59 }
60 
// Adds the samples in [begin, end) to the running counts and returns the loss.
// The declarator line (listing line 64) is not visible in this view.
61 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
62 template <class TDataIterator>
63 double
65 {
66  for (TDataIterator iter = begin; iter != end; ++iter)
67  {
// Each sample contributes 1.0 unless point weighting is enabled, in which
// case its contribution is read from m_PointWeights(*iter,0).
68  double pointProbability = 1.0;
69  if (m_UsePointWeights)
70  {
71  pointProbability = m_PointWeights(*iter,0);
72  }
// The sample's label, m_Labels(*iter,0), selects which class count grows.
73  m_Counts[m_Labels(*iter,0)] += pointProbability;
74  m_TotalCount += pointProbability;
75  }
// NOTE(review): listing line 76 is absent from this view. The visible lines
// never refresh m_PUCounts / m_PUTotalCount before they are passed to the
// loss function -- confirm against the original file what the missing line does.
77  return m_LossFunction(m_PUCounts, m_ClassWeights, m_PUTotalCount);
78 }
79 
// Removes the samples in [begin, end) from the running counts and returns the
// loss. The declarator line (listing line 83) is not visible in this view.
80 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
81 template <class TDataIterator>
82 double
84 {
85  for (TDataIterator iter = begin; iter != end; ++iter)
86  {
// Same per-sample contribution as when adding: 1.0, or the value read from
// m_PointWeights(*iter,0) when point weighting is enabled.
87  double pointProbability = 1.0;
88  if (m_UsePointWeights)
89  {
90  pointProbability = m_PointWeights(*iter,0);
91  }
// Subtract the contribution from the class selected by m_Labels(*iter,0).
92  m_Counts[m_Labels(*iter,0)] -= pointProbability;
93  m_TotalCount -= pointProbability;
94  }
// NOTE(review): listing line 95 is absent from this view. The visible lines
// never refresh m_PUCounts / m_PUTotalCount before they are passed to the
// loss function -- confirm against the original file what the missing line does.
96  return m_LossFunction(m_PUCounts, m_ClassWeights, m_PUTotalCount);
97 }
98 
// Re-initialises the running counts from initCounts and returns the loss.
// The declarator line (listing line 102) is not visible in this view.
99 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
100 template <class TArray>
101 double
103 {
104  Reset();
// Assumes initCounts holds no more elements than m_Counts -- TODO confirm;
// std::copy does not check the destination size.
105  std::copy(initCounts.begin(), initCounts.end(), m_Counts.begin());
106  m_TotalCount = std::accumulate(m_Counts.begin(), m_Counts.end(), 0.0);
// NOTE(review): the loss here is evaluated on the raw m_Counts / m_TotalCount,
// whereas the other loss-returning functions in this file pass m_PUCounts /
// m_PUTotalCount -- confirm that this difference is intended.
107  return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
108 }
109 
// Returns the raw per-class counts (m_Counts, not m_PUCounts) by const
// reference. The declarator line (listing line 112) is not visible in this view.
110 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
111 vigra::ArrayVector<double> const&
113 {
114  return m_Counts;
115 }
116 
// Stores useWeights in m_UsePointWeights, the flag the counting loops consult
// to decide whether a sample contributes 1.0 or a value from m_PointWeights.
// The declarator line (listing line 119) is not visible in this view.
117 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
118 void
120 {
121  m_UsePointWeights = useWeights;
122 }
123 
// Returns the current value of the m_UsePointWeights flag.
// The declarator line (listing line 126) is not visible in this view, so the
// function's name cannot be confirmed from here.
124 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
125 bool
127 {
128  return m_UsePointWeights;
129 }
130 
// Replaces m_PointWeights with the given weight container (by assignment).
// The declarator line (listing line 133) is not visible in this view.
// Nothing here changes m_UsePointWeights, so the stored weights only take
// effect in the counting loops when that flag is set.
131 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
132 void
134 {
135  m_PointWeights = weight;
136 }
137 
// Returns the stored point-weight container m_PointWeights.
// The return-type and declarator lines (listing lines 139-140) are not visible
// in this view, so whether this is by value or by reference cannot be told
// from here.
138 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
141 {
142  return m_PointWeights;
143 }
144 
145 
146 #endif // mitkPUImpurityLoss_cpp
147 
148 
vigra::ArrayVector< double > const & Response()
double Decrement(TDataIterator begin, TDataIterator end)
WeightContainerType GetPointWeights()
double Init(TArray initCounts)
TWeightContainer WeightContainerType
double Increment(TDataIterator begin, TDataIterator end)
void SetPointWeights(TWeightContainer weight)
void UsePointWeights(bool useWeights)
static T max(T x, T y)
Definition: svm.cpp:56
vigra::ArrayVector< double > m_Kappa
PUImpurityLoss(TLabelContainer const &labels, vigra::ProblemSpec< T > const &ext, AdditionalRFDataAbstract *data)