Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
mitkThresholdSplit.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 #ifndef mitkThresholdSplit_cpp
14 #define mitkThresholdSplit_cpp
15 
16 #include <mitkThresholdSplit.h>
17 
// Constructor: every optional behaviour starts disabled. That means no on-the-fly
// feature calculation, no point-based weights and no random split. Precision is
// 0.0, maximum tree depth is 1000 and no additional data pointer is set.
18 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
20  m_CalculatingFeature(false),
21  m_UseWeights(false),
22  m_UseRandomSplit(false),
23  m_Precision(0.0),
24  m_MaximumTreeDepth(1000),
25  m_AdditionalData(nullptr)
26 {
27 }
28 
29 //template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
30 //mitk::ThresholdSplit<TColumnDecisionFunctor, TFeatureCalculator, TTag>::ThresholdSplit(const ThresholdSplit & /*other*/)/*:
31 // m_CalculatingFeature(other.IsCalculatingFeature()),
32 // m_UseWeights(other.IsUsingPointBasedWeights()),
33 // m_UseRandomSplit(other.IsUsingRandomSplit()),
34 // m_Precision(other.GetPrecision()),
35 // m_MaximumTreeDepth(other.GetMaximumTreeDepth()),
36 // m_FeatureCalculator(other.GetFeatureCalculator()),
37 // m_Weights(other.GetWeights())*/
38 //{
39 //}
// SetAdditionalData: forwards `data` to the split functor (bgfunc) and keeps the
// same pointer in m_AdditionalData. Only the pointer is copied here; ownership
// handling, if any, is not visible in this function.
40 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
41 void
43 {
44  bgfunc.SetAdditionalData(data);
45  m_AdditionalData = data;
46 }
47 
// GetAdditionalData: returns the pointer held in m_AdditionalData. It is nullptr
// after construction until SetAdditionalData stores one.
48 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
51 {
52  return m_AdditionalData;
53 }
54 
// SetFeatureCalculator: replaces the stored feature calculator with `processor`
// (copy-assigned into m_FeatureCalculator).
55 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
56 void
58 {
59  m_FeatureCalculator = processor;
60 }
61 
// GetFeatureCalculator: returns the stored feature calculator by value (a copy).
62 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
63 TFeatureCalculator
65 {
66  return m_FeatureCalculator;
67 }
68 
// SetCalculatingFeature: sets the m_CalculatingFeature flag. findBestSplit checks
// this flag, but the code path it is meant to enable is not implemented there.
69 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
70 void
72 {
73  m_CalculatingFeature = calculate;
74 }
75 
// IsCalculatingFeature: returns the m_CalculatingFeature flag (false by default).
76 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
77 bool
79 {
80  return m_CalculatingFeature;
81 }
82 
// UsePointBasedWeights: enables or disables per-sample weights and forwards the
// same setting to the split functor (bgfunc).
83 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
84 void
86 {
87  m_UseWeights = weightsOn;
88  bgfunc.UsePointWeights(weightsOn);
89 }
90 
// IsUsingPointBasedWeights: returns whether per-sample weights are enabled
// (false by default).
91 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
92 bool
94 {
95  return m_UseWeights;
96 }
97 
// SetPrecision: sets m_Precision. findBestSplit turns a region into a terminal
// node when its loss is <= m_Precision * (summed sample weight of the region).
98 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
99 void
101 {
102  m_Precision = value;
103 }
104 
// GetPrecision: returns m_Precision (0.0 by default).
105 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
106 double
108 {
109  return m_Precision;
110 }
111 
// SetMaximumTreeDepth: stores the maximum tree depth reported by
// GetMaximumTreeDepth (1000 by default). Where the depth limit is enforced is
// not visible in this file.
112 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
113 void
115 {
116  m_MaximumTreeDepth = value;
117 }
118 
// GetMaximumTreeDepth: returns m_MaximumTreeDepth (1000 by default).
119 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
120 int
122 {
123  return m_MaximumTreeDepth;
124 }
125 
// SetWeights: stores the per-sample weight view. It then passes both the current
// m_UseWeights flag and the weights to the split functor (bgfunc).
// findBestSplit reads the weight of sample i from m_Weights(i, 0).
// NOTE(review): vigra::MultiArrayView does not own its data, so the underlying
// array presumably has to outlive this object — confirm with callers.
126 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
127 void
129 {
130  m_Weights = weights;
131  bgfunc.UsePointWeights(m_UseWeights);
132  bgfunc.SetPointWeights(weights);
133 }
134 
// GetWeights: returns the stored per-sample weight view (a view copy; the data
// itself is not copied).
135 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
136 vigra::MultiArrayView<2, double>
138 {
139  return m_Weights;
140 }
141 
// Returns the minimum loss recorded by findBestSplit for the chosen candidate
// column (min_gini_[bestSplitIndex]).
// NOTE(review): findBestSplit returns early for pure regions before it resets
// bestSplitIndex, so after such a call this may still refer to an earlier split.
142 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
143 double
145 {
146  return min_gini_[bestSplitIndex];
147 }
148 
// Returns the feature column chosen by the last findBestSplit call
// (splitColumns[bestSplitIndex]).
// NOTE(review): findBestSplit returns early for pure regions before it resets
// bestSplitIndex, so after such a call this may still refer to an earlier split.
149 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
150 int
152 {
153  return splitColumns[bestSplitIndex];
154 }
155 
// bestSplitThreshold: returns the threshold recorded by findBestSplit for the
// chosen candidate column (min_thresholds_[bestSplitIndex]).
// NOTE(review): findBestSplit returns early for pure regions before it resets
// bestSplitIndex, so after such a call this may still refer to an earlier split.
156 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
157 double
159 {
160  return min_thresholds_[bestSplitIndex];
161 }
162 
// set_external_parameters: stores the problem specification in the base class
// (SB) and hands the resulting parameters to the split functor (bgfunc).
// It then sizes all per-column buffers to column_count_.
163 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
164 template<class T>
165 void
167 {
168  SB::set_external_parameters(in);
169  bgfunc.set_external_parameters( SB::ext_param_);
170  int featureCount_ = SB::ext_param_.column_count_;
// splitColumns starts as the identity permutation 0..featureCount_-1;
// findBestSplit shuffles it in place.
171  splitColumns.resize(featureCount_);
172  for(int k=0; k<featureCount_; ++k)
173  splitColumns[k] = k;
// Per-column results written by findBestSplit for each evaluated candidate.
174  min_gini_.resize(featureCount_);
175  min_indices_.resize(featureCount_);
176  min_thresholds_.resize(featureCount_);
177 }
178 
179 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
180 template<class T, class C, class T2, class C2, class Region, class Random>
181 int
183  vigra::MultiArrayView<2, T2, C2> labels,
184  Region & region,
185  vigra::ArrayVector<Region>& childRegions,
186  Random & randint)
187 {
188  typedef typename Region::IndexIterator IndexIteratorType;
189 
190  if (m_CalculatingFeature)
191  {
192  // Do some very fance stuff here!!
193 
194  // This is not so simple as it might look! We need to
195  // remember which feature has been used to be able to
196  // use it for testing again!!
197  // There, no Splitting class is used!!
198  }
199 
200  bgfunc.UsePointWeights(m_UseWeights);
201  bgfunc.UseRandomSplit(m_UseRandomSplit);
202 
203  vigra::detail::Correction<TTag>::exec(region, labels);
204  // Create initial class count.
205  for(std::size_t i = 0; i < region.classCounts_.size(); ++i)
206  {
207  region.classCounts_[i] = 0;
208  }
209  double regionSum = 0;
210  for (typename Region::IndexIterator iter = region.begin(); iter != region.end(); ++iter)
211  {
212  double probability = 1.0;
213  if (m_UseWeights)
214  {
215  probability = m_Weights(*iter, 0);
216  }
217  region.classCounts_[labels(*iter,0)] += probability;
218  regionSum += probability;
219  }
220  region.classCountsIsValid = true;
221  vigra::ArrayVector<double> vec;
222 
223  // Is pure region?
224  region_gini_ = bgfunc.LossOfRegion(labels,
225  region.begin(),
226  region.end(),
227  region.classCounts());
228  if (region_gini_ <= m_Precision * regionSum) // Necessary to fix wrong calculation of Gini-Index
229  {
230  return this->makeTerminalNode(features, labels, region, randint);
231  }
232 
233  // Randomize the order of columns
234  for (int i = 0; i < SB::ext_param_.actual_mtry_; ++i)
235  {
236  std::swap(splitColumns[i],
237  splitColumns[i+ randint(features.shape(1) - i)]);
238  }
239 
240  // find the split with the best evaluation value
241  bestSplitIndex = 0;
242  double currentMiniGini = region_gini_;
243  int numberOfTrials = features.shape(1);
244  for (int k = 0; k < numberOfTrials; ++k)
245  {
246  bgfunc(columnVector(features, splitColumns[k]),
247  labels,
248  region.begin(), region.end(),
249  region.classCounts());
250  min_gini_[k] = bgfunc.GetMinimumLoss();
251  min_indices_[k] = bgfunc.GetMinimumIndex();
252  min_thresholds_[k] = bgfunc.GetMinimumThreshold();
253 
254  // removed classifier test section, because not necessary
255  if (bgfunc.GetMinimumLoss() < currentMiniGini)
256  {
257  currentMiniGini = bgfunc.GetMinimumLoss();
258  childRegions[0].classCounts() = bgfunc.GetBestCurrentCounts()[0];
259  childRegions[1].classCounts() = bgfunc.GetBestCurrentCounts()[1];
260  childRegions[0].classCountsIsValid = true;
261  childRegions[1].classCountsIsValid = true;
262 
263  bestSplitIndex = k;
264  numberOfTrials = SB::ext_param_.actual_mtry_;
265  }
266  }
267 
268  //If only a small improvement, make terminal node...
269  if(vigra::closeAtTolerance(currentMiniGini, region_gini_))
270  {
271  return this->makeTerminalNode(features, labels, region, randint);
272  }
273 
274  vigra::Node<vigra::i_ThresholdNode> node(SB::t_data, SB::p_data);
275  SB::node_ = node;
276  node.threshold() = min_thresholds_[bestSplitIndex];
277  node.column() = splitColumns[bestSplitIndex];
278 
279  // partition the range according to the best dimension
280  vigra::SortSamplesByDimensions<vigra::MultiArrayView<2, T, C> >
281  sorter(features, node.column(), node.threshold());
282  IndexIteratorType bestSplit =
283  std::partition(region.begin(), region.end(), sorter);
284  // Save the ranges of the child stack entries.
285  childRegions[0].setRange( region.begin() , bestSplit );
286  childRegions[0].rule = region.rule;
287  childRegions[0].rule.push_back(std::make_pair(1, 1.0));
288  childRegions[1].setRange( bestSplit , region.end() );
289  childRegions[1].rule = region.rule;
290  childRegions[1].rule.push_back(std::make_pair(1, 1.0));
291 
292  return vigra::i_ThresholdNode;
293 
294  return 0;
295 }
296 
297 //template<class TRegion, class TRegionIterator, class TLabelHolder, class TWeightsHolder>
298 //static void UpdateRegionCounts(TRegion & region, TRegionIterator begin, TRegionIterator end, TLabelHolder labels, TWeightsHolder weights)
299 //{
300 // if(std::accumulate(region.classCounts().begin(),
301 // region.classCounts().end(), 0.0) != region.size())
302 // {
303 // RandomForestClassCounter< LabelT,
304 // ArrayVector<double> >
305 // counter(labels, region.classCounts());
306 // std::for_each( region.begin(), region.end(), counter);
307 // region.classCountsIsValid = true;
308 // }
309 //}
310 //
311 //template<class TRegion, class TLabel, class TWeights>
312 //static void exec(Region & region, LabelT & labels)
313 //{
314 // if(std::accumulate(region.classCounts().begin(),
315 // region.classCounts().end(), 0.0) != region.size())
316 // {
317 // RandomForestClassCounter< LabelT,
318 // ArrayVector<double> >
319 // counter(labels, region.classCounts());
320 // std::for_each( region.begin(), region.end(), counter);
321 // region.classCountsIsValid = true;
322 // }
323 //}
324 
325 #endif //mitkThresholdSplit_cpp
float k(1.0)
void SetFeatureCalculator(TFeatureCalculator processor)
void SetAdditionalData(AdditionalRFDataAbstract *data)
void set_external_parameters(vigra::ProblemSpec< T > const &in)
void UsePointBasedWeights(bool weightsOn)
bool IsUsingPointBasedWeights() const
AdditionalRFDataAbstract * GetAdditionalData() const
int GetMaximumTreeDepth() const override
void SetWeights(vigra::MultiArrayView< 2, double > weights)
bool IsCalculatingFeature() const
double bestSplitThreshold() const
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
static bool in(Reader::Char c, Reader::Char c1, Reader::Char c2, Reader::Char c3, Reader::Char c4)
Definition: jsoncpp.cpp:244
const char features[]
static void swap(T &x, T &y)
Definition: svm.cpp:58
void SetCalculatingFeature(bool calculate)
vigra::MultiArrayView< 2, double > GetWeights() const
int findBestSplit(vigra::MultiArrayView< 2, T, C > features, vigra::MultiArrayView< 2, T2, C2 > labels, Region &region, vigra::ArrayVector< Region > &childRegions, Random &randint)
TFeatureCalculator GetFeatureCalculator() const