Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
mitkThresholdSplit.cpp
Go to the documentation of this file.
1 #ifndef mitkThresholdSplit_cpp
2 #define mitkThresholdSplit_cpp
3 
4 #include <mitkThresholdSplit.h>
5 
6 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
8  m_CalculatingFeature(false),
9  m_UseWeights(false),
10  m_UseRandomSplit(false),
11  m_Precision(0.0),
12  m_MaximumTreeDepth(1000)
13 {
14 }
15 
16 //template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
17 //mitk::ThresholdSplit<TColumnDecisionFunctor, TFeatureCalculator, TTag>::ThresholdSplit(const ThresholdSplit & /*other*/)/*:
18 // m_CalculatingFeature(other.IsCalculatingFeature()),
19 // m_UseWeights(other.IsUsingPointBasedWeights()),
20 // m_UseRandomSplit(other.IsUsingRandomSplit()),
21 // m_Precision(other.GetPrecision()),
22 // m_MaximumTreeDepth(other.GetMaximumTreeDepth()),
23 // m_FeatureCalculator(other.GetFeatureCalculator()),
24 // m_Weights(other.GetWeights())*/
25 //{
26 //}
27 
28 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
29 void
31 {
32  m_FeatureCalculator = processor;
33 }
34 
35 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
36 TFeatureCalculator
38 {
39  return m_FeatureCalculator;
40 }
41 
42 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
43 void
45 {
46  m_CalculatingFeature = calculate;
47 }
48 
49 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
50 bool
52 {
53  return m_CalculatingFeature;
54 }
55 
56 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
57 void
59 {
60  m_UseWeights = weightsOn;
61  bgfunc.UsePointWeights(weightsOn);
62 }
63 
64 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
65 bool
67 {
68  return m_UseWeights;
69 }
70 
71 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
72 void
74 {
75  m_Precision = value;
76 }
77 
78 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
79 double
81 {
82  return m_Precision;
83 }
84 
85 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
86 void
88 {
89  m_MaximumTreeDepth = value;
90 }
91 
92 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
93 int
95 {
96  return m_MaximumTreeDepth;
97 }
98 
99 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
100 void
102 {
103  m_Weights = weights;
104  bgfunc.UsePointWeights(m_UseWeights);
105  bgfunc.SetPointWeights(weights);
106 }
107 
108 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
109 vigra::MultiArrayView<2, double>
111 {
112  return m_Weights;
113 }
114 
// Accessor returning min_gini_[bestSplitIndex]: the minimum loss that
// findBestSplit() recorded for the column it selected as the best split.
// NOTE(review): the declaration line (listing line 117) is missing from this
// extracted listing, so the function name is not visible here; restore it
// from mitkThresholdSplit.h.
// Only meaningful after set_external_parameters() has sized min_gini_ and a
// findBestSplit() call has evaluated candidate columns. A findBestSplit() call
// that returns early for a pure region leaves bestSplitIndex unchanged.
115 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
116 double
118 {
119  return min_gini_[bestSplitIndex];
120 }
121 
// Accessor returning splitColumns[bestSplitIndex]: the feature column chosen
// by the last findBestSplit() call that evaluated candidate splits.
// NOTE(review): the declaration line (listing line 124) is missing from this
// extracted listing, so the function name is not visible here; restore it
// from mitkThresholdSplit.h.
// splitColumns is reshuffled at the start of every findBestSplit() candidate
// search, so read this right after the split it refers to.
122 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
123 int
125 {
126  return splitColumns[bestSplitIndex];
127 }
128 
129 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
130 double
132 {
133  return min_thresholds_[bestSplitIndex];
134 }
135 
136 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
137 template<class T>
138 void
140 {
141  SB::set_external_parameters(in);
142  bgfunc.set_external_parameters( SB::ext_param_);
143  int featureCount_ = SB::ext_param_.column_count_;
144  splitColumns.resize(featureCount_);
145  for(int k=0; k<featureCount_; ++k)
146  splitColumns[k] = k;
147  min_gini_.resize(featureCount_);
148  min_indices_.resize(featureCount_);
149  min_thresholds_.resize(featureCount_);
150 }
151 
152 template<class TColumnDecisionFunctor, class TFeatureCalculator, class TTag>
153 template<class T, class C, class T2, class C2, class Region, class Random>
154 int
156  vigra::MultiArrayView<2, T2, C2> labels,
157  Region & region,
158  vigra::ArrayVector<Region>& childRegions,
159  Random & randint)
160 {
161  typedef typename Region::IndexIterator IndexIteratorType;
162 
163  if (m_CalculatingFeature)
164  {
165  // Do some very fance stuff here!!
166 
167  // This is not so simple as it might look! We need to
168  // remember which feature has been used to be able to
169  // use it for testing again!!
170  // There, no Splitting class is used!!
171  }
172 
173  bgfunc.UsePointWeights(m_UseWeights);
174  bgfunc.UseRandomSplit(m_UseRandomSplit);
175 
176  vigra::detail::Correction<TTag>::exec(region, labels);
177  // Create initial class count.
178  for(std::size_t i = 0; i < region.classCounts_.size(); ++i)
179  {
180  region.classCounts_[i] = 0;
181  }
182  double regionSum = 0;
183  for (typename Region::IndexIterator iter = region.begin(); iter != region.end(); ++iter)
184  {
185  double probability = 1.0;
186  if (m_UseWeights)
187  {
188  probability = m_Weights(*iter, 0);
189  }
190  region.classCounts_[labels(*iter,0)] += probability;
191  regionSum += probability;
192  }
193  region.classCountsIsValid = true;
194  vigra::ArrayVector<double> vec;
195 
196  // Is pure region?
197  region_gini_ = bgfunc.LossOfRegion(labels,
198  region.begin(),
199  region.end(),
200  region.classCounts());
201  if (region_gini_ <= m_Precision * regionSum) // Necessary to fix wrong calculation of Gini-Index
202  {
203  return this->makeTerminalNode(features, labels, region, randint);
204  }
205 
206  // Randomize the order of columns
207  for (int i = 0; i < SB::ext_param_.actual_mtry_; ++i)
208  {
209  std::swap(splitColumns[i],
210  splitColumns[i+ randint(features.shape(1) - i)]);
211  }
212 
213  // find the split with the best evaluation value
214  bestSplitIndex = 0;
215  double currentMiniGini = region_gini_;
216  int numberOfTrials = features.shape(1);
217  for (int k = 0; k < numberOfTrials; ++k)
218  {
219  bgfunc(columnVector(features, splitColumns[k]),
220  labels,
221  region.begin(), region.end(),
222  region.classCounts());
223  min_gini_[k] = bgfunc.GetMinimumLoss();
224  min_indices_[k] = bgfunc.GetMinimumIndex();
225  min_thresholds_[k] = bgfunc.GetMinimumThreshold();
226 
227  // removed classifier test section, because not necessary
228  if (bgfunc.GetMinimumLoss() < currentMiniGini)
229  {
230  currentMiniGini = bgfunc.GetMinimumLoss();
231  childRegions[0].classCounts() = bgfunc.GetBestCurrentCounts()[0];
232  childRegions[1].classCounts() = bgfunc.GetBestCurrentCounts()[1];
233  childRegions[0].classCountsIsValid = true;
234  childRegions[1].classCountsIsValid = true;
235 
236  bestSplitIndex = k;
237  numberOfTrials = SB::ext_param_.actual_mtry_;
238  }
239  }
240 
241  //If only a small improvement, make terminal node...
242  if(vigra::closeAtTolerance(currentMiniGini, region_gini_))
243  {
244  return this->makeTerminalNode(features, labels, region, randint);
245  }
246 
247  vigra::Node<vigra::i_ThresholdNode> node(SB::t_data, SB::p_data);
248  SB::node_ = node;
249  node.threshold() = min_thresholds_[bestSplitIndex];
250  node.column() = splitColumns[bestSplitIndex];
251 
252  // partition the range according to the best dimension
253  vigra::SortSamplesByDimensions<vigra::MultiArrayView<2, T, C> >
254  sorter(features, node.column(), node.threshold());
255  IndexIteratorType bestSplit =
256  std::partition(region.begin(), region.end(), sorter);
257  // Save the ranges of the child stack entries.
258  childRegions[0].setRange( region.begin() , bestSplit );
259  childRegions[0].rule = region.rule;
260  childRegions[0].rule.push_back(std::make_pair(1, 1.0));
261  childRegions[1].setRange( bestSplit , region.end() );
262  childRegions[1].rule = region.rule;
263  childRegions[1].rule.push_back(std::make_pair(1, 1.0));
264 
265  return vigra::i_ThresholdNode;
266 
267  return 0;
268 }
269 
270 //template<class TRegion, class TRegionIterator, class TLabelHolder, class TWeightsHolder>
271 //static void UpdateRegionCounts(TRegion & region, TRegionIterator begin, TRegionIterator end, TLabelHolder labels, TWeightsHolder weights)
272 //{
273 // if(std::accumulate(region.classCounts().begin(),
274 // region.classCounts().end(), 0.0) != region.size())
275 // {
276 // RandomForestClassCounter< LabelT,
277 // ArrayVector<double> >
278 // counter(labels, region.classCounts());
279 // std::for_each( region.begin(), region.end(), counter);
280 // region.classCountsIsValid = true;
281 // }
282 //}
283 //
284 //template<class TRegion, class TLabel, class TWeights>
285 //static void exec(Region & region, LabelT & labels)
286 //{
287 // if(std::accumulate(region.classCounts().begin(),
288 // region.classCounts().end(), 0.0) != region.size())
289 // {
290 // RandomForestClassCounter< LabelT,
291 // ArrayVector<double> >
292 // counter(labels, region.classCounts());
293 // std::for_each( region.begin(), region.end(), counter);
294 // region.classCountsIsValid = true;
295 // }
296 //}
297 
298 #endif //mitkThresholdSplit_cpp
bool IsUsingPointBasedWeights() const
void SetFeatureCalculator(TFeatureCalculator processor)
void set_external_parameters(vigra::ProblemSpec< T > const &in)
vigra::MultiArrayView< 2, double > GetWeights() const
TFeatureCalculator GetFeatureCalculator() const
virtual int GetMaximumTreeDepth() const
void UsePointBasedWeights(bool weightsOn)
double GetPrecision() const
double bestSplitThreshold() const
void SetWeights(vigra::MultiArrayView< 2, double > weights)
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
static bool in(Reader::Char c, Reader::Char c1, Reader::Char c2, Reader::Char c3, Reader::Char c4)
Definition: jsoncpp.cpp:244
const char features[]
static void swap(T &x, T &y)
Definition: svm.cpp:72
void SetCalculatingFeature(bool calculate)
int findBestSplit(vigra::MultiArrayView< 2, T, C > features, vigra::MultiArrayView< 2, T2, C2 > labels, Region &region, vigra::ArrayVector< Region > &childRegions, Random &randint)
bool IsCalculatingFeature() const