Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
mitkThresholdSplit.cpp
Go to the documentation of this file.
1 #ifndef mitkThresholdSplit_cpp
2 #define mitkThresholdSplit_cpp
3 
4 #include <mitkThresholdSplit.h>
5 
// Constructor: puts the split functor into its default state.
// Feature calculation, point-based weights and random splitting are disabled.
// The precision is 0.0 and the maximum tree depth is 1000.
// NOTE(review): the constructor's signature line is missing from this listing.
// NOTE(review): bestSplitIndex and region_gini_ are not initialised in this list.
// Accessors that read min_gini_/splitColumns/min_thresholds_ at bestSplitIndex
// would therefore be unsafe before the first findBestSplit call. Confirm they are
// initialised elsewhere (e.g. in the header or a base class).
6 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
8  m_CalculatingFeature(false),
9  m_UseWeights(false),
10  m_UseRandomSplit(false),
11  m_Precision(0.0),
12  m_MaximumTreeDepth(1000)
13 {
14 }
15 
16 //template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
17 //mitk::ThresholdSplit<TColumnDecisionFunctor, TFeatureCalculator, TTag>::ThresholdSplit(const ThresholdSplit & /*other*/)/*:
18 // m_CalculatingFeature(other.IsCalculatingFeature()),
19 // m_UseWeights(other.IsUsingPointBasedWeights()),
20 // m_UseRandomSplit(other.IsUsingRandomSplit()),
21 // m_Precision(other.GetPrecision()),
22 // m_MaximumTreeDepth(other.GetMaximumTreeDepth()),
23 // m_FeatureCalculator(other.GetFeatureCalculator()),
24 // m_Weights(other.GetWeights())*/
25 //{
26 //}
27 
28 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
29 void
31 {
32  m_FeatureCalculator = processor;
33 }
34 
35 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
36 TFeatureCalculator
38 {
39  return m_FeatureCalculator;
40 }
41 
42 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
43 void
45 {
46  m_CalculatingFeature = calculate;
47 }
48 
49 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
50 bool
52 {
53  return m_CalculatingFeature;
54 }
55 
56 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
57 void
59 {
60  m_UseWeights = weightsOn;
61  bgfunc.UsePointWeights(weightsOn);
62 }
63 
64 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
65 bool
67 {
68  return m_UseWeights;
69 }
70 
71 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
72 void
74 {
75  m_Precision = value;
76 }
77 
78 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
79 double
81 {
82  return m_Precision;
83 }
84 
85 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
86 void
88 {
89  m_MaximumTreeDepth = value;
90 }
91 
92 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
93 int
95 {
96  return m_MaximumTreeDepth;
97 }
98 
99 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
100 void
102 {
103  m_Weights = weights;
104  bgfunc.UsePointWeights(m_UseWeights);
105  bgfunc.SetPointWeights(weights);
106 }
107 
108 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
109 vigra::MultiArrayView<2, double>
111 {
112  return m_Weights;
113 }
114 
// Returns the minimum loss recorded for the best split candidate, i.e. min_gini_[bestSplitIndex].
// min_gini_ and bestSplitIndex are filled by findBestSplit.
// NOTE(review): this definition's signature line (and so its name) is missing from this listing.
// NOTE(review): findBestSplit returns early for a pure region before it resets bestSplitIndex.
// After such a call this value still refers to a previous split search.
115 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
116 double
118 {
119  return min_gini_[bestSplitIndex];
120 }
121 
// Returns the feature column of the best split candidate, i.e. splitColumns[bestSplitIndex].
// findBestSplit also uses this value as the node column.
// NOTE(review): this definition's signature line (and so its name) is missing from this listing.
// NOTE(review): after an early pure-region return from findBestSplit, bestSplitIndex is stale.
122 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
123 int
125 {
126  return splitColumns[bestSplitIndex];
127 }
128 
129 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
130 double
132 {
133  return min_thresholds_[bestSplitIndex];
134 }
135 
136 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
137 template<class T>
138 void
140 {
141  SB::set_external_parameters(in);
142  bgfunc.set_external_parameters( SB::ext_param_);
143  int featureCount_ = SB::ext_param_.column_count_;
144  splitColumns.resize(featureCount_);
145  for(int k=0; k<featureCount_; ++k)
146  splitColumns[k] = k;
147  min_gini_.resize(featureCount_);
148  min_indices_.resize(featureCount_);
149  min_thresholds_.resize(featureCount_);
150 }
151 
152 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
153 template<class T, class C, class T2, class C2, class Region, class Random>
154 int
156  vigra::MultiArrayView<2, T2, C2> labels,
157  Region & region,
158  vigra::ArrayVector<Region>& childRegions,
159  Random & randint)
160 {
161  typedef typename Region::IndexIterator IndexIteratorType;
162 
163  if (m_CalculatingFeature)
164  {
165  // Do some very fance stuff here!!
166 
167  // This is not so simple as it might look! We need to
168  // remember which feature has been used to be able to
169  // use it for testing again!!
170  // There, no Splitting class is used!!
171  }
172 
173  bgfunc.UsePointWeights(m_UseWeights);
174  bgfunc.UseRandomSplit(m_UseRandomSplit);
175 
176  vigra::detail::Correction<TTag>::exec(region, labels);
177  // Create initial class count.
178  for(std::size_t i = 0; i < region.classCounts_.size(); ++i)
179  {
180  region.classCounts_[i] = 0;
181  }
182  double regionSum = 0;
183  for (typename Region::IndexIterator iter = region.begin(); iter != region.end(); ++iter)
184  {
185  double probability = 1.0;
186  if (m_UseWeights)
187  {
188  probability = m_Weights(*iter, 0);
189  }
190  region.classCounts_[labels(*iter,0)] += probability;
191  regionSum += probability;
192  }
193  region.classCountsIsValid = true;
194  vigra::ArrayVector<double> vec;
195 
196  // Is pure region?
197  region_gini_ = bgfunc.LossOfRegion(labels,
198  region.begin(),
199  region.end(),
200  region.classCounts());
201  if (region_gini_ <= m_Precision * regionSum) // Necessary to fix wrong calculation of Gini-Index
202  {
203  return this->makeTerminalNode(features, labels, region, randint);
204  }
205 
206  // Randomize the order of columns
207  for (int i = 0; i < SB::ext_param_.actual_mtry_; ++i)
208  {
209  std::swap(splitColumns[i],
210  splitColumns[i+ randint(features.shape(1) - i)]);
211  }
212 
213  // find the split with the best evaluation value
214  bestSplitIndex = 0;
215  double currentMiniGini = region_gini_;
216  int numberOfTrials = features.shape(1);
217  for (int k = 0; k < numberOfTrials; ++k)
218  {
219  bgfunc(columnVector(features, splitColumns[k]),
220  labels,
221  region.begin(), region.end(),
222  region.classCounts());
223  min_gini_[k] = bgfunc.GetMinimumLoss();
224  min_indices_[k] = bgfunc.GetMinimumIndex();
225  min_thresholds_[k] = bgfunc.GetMinimumThreshold();
226 
227  // removed classifier test section, because not necessary
228  if (bgfunc.GetMinimumLoss() < currentMiniGini)
229  {
230  currentMiniGini = bgfunc.GetMinimumLoss();
231  childRegions[0].classCounts() = bgfunc.GetBestCurrentCounts()[0];
232  childRegions[1].classCounts() = bgfunc.GetBestCurrentCounts()[1];
233  childRegions[0].classCountsIsValid = true;
234  childRegions[1].classCountsIsValid = true;
235 
236  bestSplitIndex = k;
237  numberOfTrials = SB::ext_param_.actual_mtry_;
238  }
239  }
240 
241  //If only a small improvement, make terminal node...
242  if(vigra::closeAtTolerance(currentMiniGini, region_gini_))
243  {
244  return this->makeTerminalNode(features, labels, region, randint);
245  }
246 
247  vigra::Node<vigra::i_ThresholdNode> node(SB::t_data, SB::p_data);
248  SB::node_ = node;
249  node.threshold() = min_thresholds_[bestSplitIndex];
250  node.column() = splitColumns[bestSplitIndex];
251 
252  // partition the range according to the best dimension
253  vigra::SortSamplesByDimensions<vigra::MultiArrayView<2, T, C> >
254  sorter(features, node.column(), node.threshold());
255  IndexIteratorType bestSplit =
256  std::partition(region.begin(), region.end(), sorter);
257  // Save the ranges of the child stack entries.
258  childRegions[0].setRange( region.begin() , bestSplit );
259  childRegions[0].rule = region.rule;
260  childRegions[0].rule.push_back(std::make_pair(1, 1.0));
261  childRegions[1].setRange( bestSplit , region.end() );
262  childRegions[1].rule = region.rule;
263  childRegions[1].rule.push_back(std::make_pair(1, 1.0));
264 
265  return vigra::i_ThresholdNode;
266 
267  return 0;
268 }
269 
270 //template<class TRegion, class TRegionIterator, class TLabelHolder, class TWeightsHolder>
271 //static void UpdateRegionCounts(TRegion & region, TRegionIterator begin, TRegionIterator end, TLabelHolder labels, TWeightsHolder weights)
272 //{
273 // if(std::accumulate(region.classCounts().begin(),
274 // region.classCounts().end(), 0.0) != region.size())
275 // {
276 // RandomForestClassCounter< LabelT,
277 // ArrayVector<double> >
278 // counter(labels, region.classCounts());
279 // std::for_each( region.begin(), region.end(), counter);
280 // region.classCountsIsValid = true;
281 // }
282 //}
283 //
284 //template<class TRegion, class TLabel, class TWeights>
285 //static void exec(Region & region, LabelT & labels)
286 //{
287 // if(std::accumulate(region.classCounts().begin(),
288 // region.classCounts().end(), 0.0) != region.size())
289 // {
290 // RandomForestClassCounter< LabelT,
291 // ArrayVector<double> >
292 // counter(labels, region.classCounts());
293 // std::for_each( region.begin(), region.end(), counter);
294 // region.classCountsIsValid = true;
295 // }
296 //}
297 
298 #endif //mitkThresholdSplit_cpp
bool IsUsingPointBasedWeights() const
void SetFeatureCalculator(TFeatureCalculator processor)
void set_external_parameters(vigra::ProblemSpec< T > const &in)
vigra::MultiArrayView< 2, double > GetWeights() const
TFeatureCalculator GetFeatureCalculator() const
virtual int GetMaximumTreeDepth() const
void UsePointBasedWeights(bool weightsOn)
double GetPrecision() const
double bestSplitThreshold() const
void SetWeights(vigra::MultiArrayView< 2, double > weights)
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
static bool in(Reader::Char c, Reader::Char c1, Reader::Char c2, Reader::Char c3, Reader::Char c4)
Definition: jsoncpp.cpp:244
const char features[]
static void swap(T &x, T &y)
Definition: svm.cpp:72
void SetCalculatingFeature(bool calculate)
int findBestSplit(vigra::MultiArrayView< 2, T, C > features, vigra::MultiArrayView< 2, T2, C2 > labels, Region &region, vigra::ArrayVector< Region > &childRegions, Random &randint)
bool IsCalculatingFeature() const