Medical Imaging Interaction Toolkit  2018.4.99-b20efe7f
Medical Imaging Interaction Toolkit
mitkImpurityLoss.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 #ifndef mitkImpurityLoss_cpp
14 #define mitkImpurityLoss_cpp
15 
16 #include <mitkImpurityLoss.h>
17 
18 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
19 template <class T>
21  vigra::ProblemSpec<T> const &ext,
22  AdditionalRFDataAbstract * /*data*/) :
23  m_UsePointWeights(false),
24  m_Labels(labels),
25  m_Counts(ext.class_count_, 0.0),
26  m_ClassWeights(ext.class_weights_),
27  m_TotalCount(0.0)
28 {
29 }
30 
31 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
32 void
34 {
35  m_Counts.init(0);
36  m_TotalCount = 0.0;
37 }
38 
39 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
40 template <class TDataIterator>
41 double
43 {
44  for (TDataIterator iter = begin; iter != end; ++iter)
45  {
46  double pointProbability = 1.0;
47  if (m_UsePointWeights)
48  {
49  pointProbability = m_PointWeights(*iter,0);
50  }
51  m_Counts[m_Labels(*iter,0)] += pointProbability;
52  m_TotalCount += pointProbability;
53  }
54  return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
55 }
56 
57 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
58 template <class TDataIterator>
59 double
61 {
62  for (TDataIterator iter = begin; iter != end; ++iter)
63  {
64  double pointProbability = 1.0;
65  if (m_UsePointWeights)
66  {
67  pointProbability = m_PointWeights(*iter,0);
68  }
69  m_Counts[m_Labels(*iter,0)] -= pointProbability;
70  m_TotalCount -= pointProbability;
71  }
72  return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
73 }
74 
75 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
76 template <class TArray>
77 double
79 {
80  Reset();
81  std::copy(initCounts.begin(), initCounts.end(), m_Counts.begin());
82  m_TotalCount = std::accumulate(m_Counts.begin(), m_Counts.end(), 0.0);
83  return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount);
84 }
85 
86 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
87 vigra::ArrayVector<double> const&
89 {
90  return m_Counts;
91 }
92 
93 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
94 void
96 {
97  m_UsePointWeights = useWeights;
98 }
99 
// Reports whether per-sample point weighting is currently enabled, i.e. the
// flag set by UsePointWeights() and consulted by Increment()/Decrement().
// NOTE(review): the signature line (listing line 102) was lost when this
// listing was extracted; restore the qualified name from mitkImpurityLoss.h.
100 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
101 bool
103 {
104  return m_UsePointWeights;
105 }
106 
107 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
108 void
110 {
111  m_PointWeights = weight;
112 }
113 
114 template <class TLossFunction, class TLabelContainer, class TWeightContainer>
117 {
118  return m_PointWeights;
119 }
120 
121 
122 #endif // mitkImpurityLoss_cpp
123 
124 
void UsePointWeights(bool useWeights)
vigra::ArrayVector< double > const & Response()
void SetPointWeights(TWeightContainer weight)
double Decrement(TDataIterator begin, TDataIterator end)
double Increment(TDataIterator begin, TDataIterator end)
WeightContainerType GetPointWeights()
double Init(TArray initCounts)
TWeightContainer WeightContainerType
ImpurityLoss(TLabelContainer const &labels, vigra::ProblemSpec< T > const &ext, AdditionalRFDataAbstract *data)