Medical Imaging Interaction Toolkit  2018.4.99-12ad79a3
Medical Imaging Interaction Toolkit
mitkLinearSplitting.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 #ifndef mitkLinearSplitting_cpp
14 #define mitkLinearSplitting_cpp
15 
16 #include <mitkLinearSplitting.h>
17 #include <mitkPUImpurityLoss.h>
18 
// Constructor whose signature line (Doxygen line 20) is missing from this
// rendering; presumably the default constructor of LinearSplitting — confirm
// against mitkLinearSplitting.h.
// Starts with point weighting and random splitting disabled and with no
// additional data attached (m_AdditionalData == nullptr).
19 template<class TLossAccumulator>
21  m_UsePointWeights(false),
22  m_UseRandomSplit(false),
23  m_AdditionalData(nullptr)
24 {
25 }
26 
// Constructor templated on T. Its signature (Doxygen line 29) and its only
// body statement (line 33) are missing from this rendering. It presumably
// takes a vigra::ProblemSpec<T> and forwards it to set_external_parameters.
// Confirm against the header.
// Disables point weighting and random splitting.
// NOTE(review): unlike the other constructor, the visible initializer list
// does not set m_AdditionalData. Unless the hidden body line or an in-class
// initializer assigns it, the pointer is indeterminate. operator() and
// LossOfRegion would then pass it to the loss accumulators. Consider adding
// m_AdditionalData(nullptr) to this list.
27 template<class TLossAccumulator>
28 template <class T>
30  m_UsePointWeights(false),
31  m_UseRandomSplit(false)
32 {
34 }
35 
// SetAdditionalData(AdditionalRFDataAbstract *data). The signature line
// (Doxygen line 38) is missing from this rendering.
// Stores the given pointer in m_AdditionalData as-is. Nothing here copies or
// frees the pointee, so the caller presumably keeps ownership (confirm).
// The stored pointer is passed to each loss accumulator constructed in
// operator() and LossOfRegion.
36 template<class TLossAccumulator>
37 void
39 {
40  m_AdditionalData = data;
41 }
42 
// GetAdditionalData() const. The return-type and signature lines (Doxygen
// lines 44-45) are missing from this rendering.
// Returns the stored m_AdditionalData pointer. It may be nullptr, which is
// the value the constructor with an initializer for it sets.
43 template<class TLossAccumulator>
46 {
47  return m_AdditionalData;
48 }
49 
// UsePointWeights(bool pointWeight). The signature line (Doxygen line 52) is
// missing from this rendering.
// Enables or disables per-sample point weights. When enabled, operator() and
// LossOfRegion switch their loss accumulators to point-weight mode and pass
// them m_PointWeights.
50 template<class TLossAccumulator>
51 void
53 {
54  m_UsePointWeights = pointWeight;
55 }
56 
// Getter for the point-weight flag. Its name is on Doxygen line 59, which is
// missing from this rendering.
// Returns true when point weights are enabled via UsePointWeights(true).
57 template<class TLossAccumulator>
58 bool
60 {
61  return m_UsePointWeights;
62 }
63 
64 
// UseRandomSplit(bool randomSplit). The signature line (Doxygen line 67) is
// missing from this rendering.
// When true, operator() evaluates one randomly chosen boundary instead of
// scanning every boundary in the sorted range (ExtraTrees-style splitting).
65 template<class TLossAccumulator>
66 void
68 {
69  m_UseRandomSplit = randomSplit;
70 }
71 
// Getter for the random-split flag. Its name is on Doxygen line 74, which is
// missing from this rendering.
// Returns true when random splitting is enabled via UseRandomSplit(true).
72 template<class TLossAccumulator>
73 bool
75 {
76  return m_UseRandomSplit;
77 }
78 
// SetPointWeights(WeightContainerType weight). The signature line (Doxygen
// line 81) is missing from this rendering.
// Stores a copy of the per-sample weight container. The weights are only
// used by operator() and LossOfRegion while m_UsePointWeights is true.
79 template<class TLossAccumulator>
80 void
82 {
83  m_PointWeights = weight;
84 }
85 
// GetPointWeights(). The return-type and signature lines (Doxygen lines
// 87-88) are missing from this rendering.
// Returns the stored weight container. Per the declared signature
// `WeightContainerType GetPointWeights()`, it is returned by value (a copy).
86 template<class TLossAccumulator>
89 {
90  return m_PointWeights;
91 }
92 
// set_external_parameters(vigra::ProblemSpec<T> const &ext). The signature
// line (Doxygen line 96) is missing from this rendering.
// Copies the problem specification into m_ExtParameter. That member is
// passed to every loss accumulator constructed in operator() and
// LossOfRegion.
93 template<class TLossAccumulator>
94 template <class T>
95 void
97 {
98  m_ExtParameter = ext;
99 }
100 
// operator()(column, labels, begin, end, regionResponse). Doxygen line 104,
// which carries the function name and the `column` parameter, is missing
// from this rendering.
//
// Searches for the best split threshold on one feature column for the
// samples in [begin, end):
//  - sorts the range in place (vigra::SortSamplesByDimensions on column 0),
//  - builds a left and a right loss accumulator of type TLossAccumulator,
//    with point weights applied when enabled,
//  - deterministic mode: evaluates every boundary reported by
//    compareNotEqual (presumably neighbouring samples whose column values
//    differ) and keeps the one with the lowest left + right loss,
//  - random mode (m_UseRandomSplit): evaluates one randomly chosen boundary.
// Results are written to the following members:
//  - m_MinimumLoss,
//  - m_MinimumIndex (number of samples on the left side),
//  - m_MinimumThreshold (midpoint of the two neighbouring column values),
//  - m_BestCurrentCounts[0]/[1] (left/right Response()).
// If no candidate beats the unsplit loss, m_MinimumIndex stays 0.
101 template<class TLossAccumulator>
102 template <class TDataSourceFeature, class TDataSourceLabel, class TDataIterator, class TArray>
103 void
105  TDataSourceLabel const &labels,
106  TDataIterator &begin,
107  TDataIterator &end,
108  TArray const &regionResponse)
109 {
110  typedef TLossAccumulator LineSearchLoss;
// Reorders the caller's range in place by this feature column.
111  std::sort(begin, end, vigra::SortSamplesByDimensions<TDataSourceFeature>(column, 0));
112 
113  LineSearchLoss left(labels, m_ExtParameter, m_AdditionalData);
114  LineSearchLoss right(labels, m_ExtParameter, m_AdditionalData);
115 
116  if (m_UsePointWeights)
117  {
118  left.UsePointWeights(true);
119  left.SetPointWeights(m_PointWeights);
120  right.UsePointWeights(true);
121  right.SetPointWeights(m_PointWeights);
122  }
123 
// Baseline: loss of the whole, unsplit region (all samples start on the
// right side).
124  m_MinimumLoss = right.Init(regionResponse);
// NOTE(review): *begin is a sample element of the range, not a column
// value. Column values are read via column(*it, 0) further below. This line
// also dereferences begin when the range is empty (begin == end), which is
// undefined behaviour. Confirm that callers never pass an empty range.
125  m_MinimumThreshold = *begin;
126  m_MinimumIndex = 0;
127 
128  vigra::DimensionNotEqual<TDataSourceFeature> compareNotEqual(column, 0);
129 
130  if (!m_UseRandomSplit)
131  {
132  TDataIterator iter = begin;
133  // Find the first position whose sample differs from its right-hand neighbour.
134  TDataIterator next = std::adjacent_find(iter, end, compareNotEqual);
135 
136  while(next != end)
137  {
138  // Move the run of samples [iter, next] from the right side to the left side.
139  double rightLoss = right.Decrement(iter, next +1);
140  double leftLoss = left.Increment(iter, next +1);
141  double currentLoss = rightLoss + leftLoss;
142 
// adjacent_find only returns next != end when next + 1 < end, so
// *(next + 1) below is always a valid element.
143  if (currentLoss < m_MinimumLoss)
144  {
145  m_BestCurrentCounts[0] = left.Response();
146  m_BestCurrentCounts[1] = right.Response();
147  m_MinimumLoss = currentLoss;
148  m_MinimumIndex = next - begin + 1;
149  m_MinimumThreshold = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
150  }
151 
152  iter = next + 1;
153  next = std::adjacent_find(iter, end, compareNotEqual);
154  }
155  }
156  else // If Random split is selected, e.g. ExtraTree behaviour
157  {
// NOTE(review): off-by-one. With n = end - begin, size is n + 1, so offset
// can be n or n - 1. Then iter == end (or iter + 1 == end), and
// Decrement/Increment(begin, iter + 1) run past the range while
// column(*(iter + 1), 0) dereferences end or beyond, which is undefined
// behaviour. A valid boundary needs iter + 1 < end, i.e. offset in
// [0, n - 2]. Require n >= 2 and use `rand() % (n - 1)`.
// NOTE(review): srand(time(nullptr)) reseeds on every call, so calls made
// within the same second pick the same offset. Seed once instead.
// Unlike the deterministic branch, ties are not skipped, so the boundary may
// fall between two equal column values.
158  int size = end - begin + 1;
159  srand(time(nullptr));
160  int offset = rand() % size;
161  TDataIterator iter = begin + offset;
162 
163  double rightLoss = right.Decrement(begin, iter+1);
164  double leftLoss = left.Increment(begin, iter+1);
165  double currentLoss = rightLoss + leftLoss;
166 
167  if (currentLoss < m_MinimumLoss)
168  {
169  m_BestCurrentCounts[0] = left.Response();
170  m_BestCurrentCounts[1] = right.Response();
171  m_MinimumLoss = currentLoss;
172  m_MinimumIndex = offset + 1;
173  m_MinimumThreshold = (double(column(*iter,0)) + double(column(*(iter+1), 0)))/2.0;
174  }
175  }
176 }
177 
// LossOfRegion(labels, begin, end, regionResponse). Doxygen line 181, which
// carries the function name and the `labels` parameter, is missing from this
// rendering.
// Returns the loss of the unsplit region as computed by
// TLossAccumulator::Init(regionResponse), with point weights applied when
// enabled. The accumulator is constructed the same way operator() builds its
// `right` accumulator, so this matches operator()'s baseline m_MinimumLoss.
// The sample range [begin, end) is ignored (the parameters are unnamed).
178 template<class TLossAccumulator>
179 template <class TDataSourceLabel, class TDataIterator, class TArray>
180 double
182  TDataIterator &/*begin*/,
183  TDataIterator &/*end*/,
184  TArray const & regionResponse)
185 {
186  typedef TLossAccumulator LineSearchLoss;
187  LineSearchLoss regionLoss(labels, m_ExtParameter, m_AdditionalData);
188  if (m_UsePointWeights)
189  {
190  regionLoss.UsePointWeights(true);
191  regionLoss.SetPointWeights(m_PointWeights);
192  }
193  return regionLoss.Init(regionResponse);
194 }
195 
196 #endif //mitkLinearSplitting_cpp
void operator()(TDataSourceFeature const &column, TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const &regionResponse)
void SetPointWeights(WeightContainerType weight)
void set_external_parameters(vigra::ProblemSpec< T > const &ext)
void UseRandomSplit(bool randomSplit)
void UsePointWeights(bool pointWeight)
static Vector3D offset
AdditionalRFDataAbstract * GetAdditionalData() const
WeightContainerType GetPointWeights()
void SetAdditionalData(AdditionalRFDataAbstract *data)
TWeightContainer WeightContainerType
double LossOfRegion(TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const &regionResponse)