1 #ifndef mitkThresholdSplit_cpp
2 #define mitkThresholdSplit_cpp
6 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
8 m_CalculatingFeature(false),
10 m_UseRandomSplit(false),
12 m_MaximumTreeDepth(1000)
28 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
32 m_FeatureCalculator = processor;
35 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
39 return m_FeatureCalculator;
42 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
46 m_CalculatingFeature = calculate;
49 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
53 return m_CalculatingFeature;
56 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
60 m_UseWeights = weightsOn;
61 bgfunc.UsePointWeights(weightsOn);
64 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
71 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
78 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
85 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
89 m_MaximumTreeDepth = value;
92 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
96 return m_MaximumTreeDepth;
99 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
104 bgfunc.UsePointWeights(m_UseWeights);
105 bgfunc.SetPointWeights(weights);
108 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
109 vigra::MultiArrayView<2, double>
115 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
119 return min_gini_[bestSplitIndex];
122 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
126 return splitColumns[bestSplitIndex];
129 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
133 return min_thresholds_[bestSplitIndex];
136 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
141 SB::set_external_parameters(in);
142 bgfunc.set_external_parameters( SB::ext_param_);
143 int featureCount_ = SB::ext_param_.column_count_;
144 splitColumns.resize(featureCount_);
145 for(
int k=0; k<featureCount_; ++k)
147 min_gini_.resize(featureCount_);
148 min_indices_.resize(featureCount_);
149 min_thresholds_.resize(featureCount_);
152 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
153 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
156 vigra::MultiArrayView<2, T2, C2> labels,
158 vigra::ArrayVector<Region>& childRegions,
161 typedef typename Region::IndexIterator IndexIteratorType;
163 if (m_CalculatingFeature)
173 bgfunc.UsePointWeights(m_UseWeights);
174 bgfunc.UseRandomSplit(m_UseRandomSplit);
176 vigra::detail::Correction<TTag>::exec(region, labels);
178 for(std::size_t i = 0; i < region.classCounts_.size(); ++i)
180 region.classCounts_[i] = 0;
182 double regionSum = 0;
183 for (
typename Region::IndexIterator iter = region.begin(); iter != region.end(); ++iter)
185 double probability = 1.0;
188 probability = m_Weights(*iter, 0);
190 region.classCounts_[labels(*iter,0)] += probability;
191 regionSum += probability;
193 region.classCountsIsValid =
true;
194 vigra::ArrayVector<double> vec;
197 region_gini_ = bgfunc.LossOfRegion(labels,
200 region.classCounts());
201 if (region_gini_ <= m_Precision * regionSum)
203 return this->makeTerminalNode(features, labels, region, randint);
207 for (
int i = 0; i < SB::ext_param_.actual_mtry_; ++i)
210 splitColumns[i+ randint(features.shape(1) - i)]);
215 double currentMiniGini = region_gini_;
216 int numberOfTrials = features.shape(1);
217 for (
int k = 0; k < numberOfTrials; ++k)
219 bgfunc(columnVector(features, splitColumns[k]),
221 region.begin(), region.end(),
222 region.classCounts());
223 min_gini_[k] = bgfunc.GetMinimumLoss();
224 min_indices_[k] = bgfunc.GetMinimumIndex();
225 min_thresholds_[k] = bgfunc.GetMinimumThreshold();
228 if (bgfunc.GetMinimumLoss() < currentMiniGini)
230 currentMiniGini = bgfunc.GetMinimumLoss();
231 childRegions[0].classCounts() = bgfunc.GetBestCurrentCounts()[0];
232 childRegions[1].classCounts() = bgfunc.GetBestCurrentCounts()[1];
233 childRegions[0].classCountsIsValid =
true;
234 childRegions[1].classCountsIsValid =
true;
237 numberOfTrials = SB::ext_param_.actual_mtry_;
242 if(vigra::closeAtTolerance(currentMiniGini, region_gini_))
244 return this->makeTerminalNode(features, labels, region, randint);
247 vigra::Node<vigra::i_ThresholdNode> node(SB::t_data, SB::p_data);
249 node.threshold() = min_thresholds_[bestSplitIndex];
250 node.column() = splitColumns[bestSplitIndex];
253 vigra::SortSamplesByDimensions<vigra::MultiArrayView<2, T, C> >
254 sorter(features, node.column(), node.threshold());
255 IndexIteratorType bestSplit =
256 std::partition(region.begin(), region.end(), sorter);
258 childRegions[0].setRange( region.begin() , bestSplit );
259 childRegions[0].rule = region.rule;
260 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
261 childRegions[1].setRange( bestSplit , region.end() );
262 childRegions[1].rule = region.rule;
263 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
265 return vigra::i_ThresholdNode;
298 #endif //mitkThresholdSplit_cpp
bool IsUsingPointBasedWeights() const
void SetFeatureCalculator(TFeatureCalculator processor)
void set_external_parameters(vigra::ProblemSpec< T > const &in)
vigra::MultiArrayView< 2, double > GetWeights() const
TFeatureCalculator GetFeatureCalculator() const
virtual int GetMaximumTreeDepth() const
void UsePointBasedWeights(bool weightsOn)
int bestSplitColumn() const
double GetPrecision() const
double bestSplitThreshold() const
void SetWeights(vigra::MultiArrayView< 2, double > weights)
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
static bool in(Reader::Char c, Reader::Char c1, Reader::Char c2, Reader::Char c3, Reader::Char c4)
static void swap(T &x, T &y)
void SetCalculatingFeature(bool calculate)
int findBestSplit(vigra::MultiArrayView< 2, T, C > features, vigra::MultiArrayView< 2, T2, C2 > labels, Region ®ion, vigra::ArrayVector< Region > &childRegions, Random &randint)
bool IsCalculatingFeature() const