13 #ifndef mitkThresholdSplit_cpp 14 #define mitkThresholdSplit_cpp 18 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
20 m_CalculatingFeature(false),
22 m_UseRandomSplit(false),
24 m_MaximumTreeDepth(1000),
25 m_AdditionalData(nullptr)
40 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
44 bgfunc.SetAdditionalData(data);
45 m_AdditionalData = data;
48 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
52 return m_AdditionalData;
55 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
59 m_FeatureCalculator = processor;
62 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
66 return m_FeatureCalculator;
69 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
73 m_CalculatingFeature = calculate;
76 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
80 return m_CalculatingFeature;
83 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
87 m_UseWeights = weightsOn;
88 bgfunc.UsePointWeights(weightsOn);
91 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
98 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
105 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
112 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
116 m_MaximumTreeDepth = value;
119 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
123 return m_MaximumTreeDepth;
126 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
131 bgfunc.UsePointWeights(m_UseWeights);
132 bgfunc.SetPointWeights(weights);
135 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
136 vigra::MultiArrayView<2, double>
142 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
146 return min_gini_[bestSplitIndex];
149 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
153 return splitColumns[bestSplitIndex];
156 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
160 return min_thresholds_[bestSplitIndex];
163 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
168 SB::set_external_parameters(in);
169 bgfunc.set_external_parameters( SB::ext_param_);
170 int featureCount_ = SB::ext_param_.column_count_;
171 splitColumns.resize(featureCount_);
172 for(
int k=0;
k<featureCount_; ++
k)
174 min_gini_.resize(featureCount_);
175 min_indices_.resize(featureCount_);
176 min_thresholds_.resize(featureCount_);
179 template<
class TColumnDecisionFunctor,
class TFeatureCalculator,
class TTag>
180 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
183 vigra::MultiArrayView<2, T2, C2> labels,
185 vigra::ArrayVector<Region>& childRegions,
188 typedef typename Region::IndexIterator IndexIteratorType;
190 if (m_CalculatingFeature)
200 bgfunc.UsePointWeights(m_UseWeights);
201 bgfunc.UseRandomSplit(m_UseRandomSplit);
203 vigra::detail::Correction<TTag>::exec(region, labels);
205 for(std::size_t i = 0; i < region.classCounts_.size(); ++i)
207 region.classCounts_[i] = 0;
209 double regionSum = 0;
210 for (
typename Region::IndexIterator iter = region.begin(); iter != region.end(); ++iter)
212 double probability = 1.0;
215 probability = m_Weights(*iter, 0);
217 region.classCounts_[labels(*iter,0)] += probability;
218 regionSum += probability;
220 region.classCountsIsValid =
true;
221 vigra::ArrayVector<double> vec;
227 region.classCounts());
230 return this->makeTerminalNode(features, labels, region, randint);
234 for (
int i = 0; i < SB::ext_param_.actual_mtry_; ++i)
237 splitColumns[i+ randint(features.shape(1) - i)]);
243 int numberOfTrials = features.shape(1);
244 for (
int k = 0;
k < numberOfTrials; ++
k)
246 bgfunc(columnVector(features, splitColumns[
k]),
248 region.begin(), region.end(),
249 region.classCounts());
250 min_gini_[
k] = bgfunc.GetMinimumLoss();
251 min_indices_[
k] = bgfunc.GetMinimumIndex();
252 min_thresholds_[
k] = bgfunc.GetMinimumThreshold();
255 if (bgfunc.GetMinimumLoss() < currentMiniGini)
257 currentMiniGini = bgfunc.GetMinimumLoss();
258 childRegions[0].classCounts() = bgfunc.GetBestCurrentCounts()[0];
259 childRegions[1].classCounts() = bgfunc.GetBestCurrentCounts()[1];
260 childRegions[0].classCountsIsValid =
true;
261 childRegions[1].classCountsIsValid =
true;
264 numberOfTrials = SB::ext_param_.actual_mtry_;
269 if(vigra::closeAtTolerance(currentMiniGini,
region_gini_))
271 return this->makeTerminalNode(features, labels, region, randint);
274 vigra::Node<vigra::i_ThresholdNode> node(SB::t_data, SB::p_data);
276 node.threshold() = min_thresholds_[bestSplitIndex];
277 node.column() = splitColumns[bestSplitIndex];
280 vigra::SortSamplesByDimensions<vigra::MultiArrayView<2, T, C> >
281 sorter(features, node.column(), node.threshold());
282 IndexIteratorType bestSplit =
283 std::partition(region.begin(), region.end(), sorter);
285 childRegions[0].setRange( region.begin() , bestSplit );
286 childRegions[0].rule = region.rule;
287 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
288 childRegions[1].setRange( bestSplit , region.end() );
289 childRegions[1].rule = region.rule;
290 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
292 return vigra::i_ThresholdNode;
325 #endif //mitkThresholdSplit_cpp
void SetFeatureCalculator(TFeatureCalculator processor)
void SetAdditionalData(AdditionalRFDataAbstract *data)
void set_external_parameters(vigra::ProblemSpec< T > const &in)
int bestSplitColumn() const
void UsePointBasedWeights(bool weightsOn)
double GetPrecision() const
bool IsUsingPointBasedWeights() const
AdditionalRFDataAbstract * GetAdditionalData() const
int GetMaximumTreeDepth() const override
void SetWeights(vigra::MultiArrayView< 2, double > weights)
bool IsCalculatingFeature() const
double bestSplitThreshold() const
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
static bool in(Reader::Char c, Reader::Char c1, Reader::Char c2, Reader::Char c3, Reader::Char c4)
static void swap(T &x, T &y)
void SetCalculatingFeature(bool calculate)
vigra::MultiArrayView< 2, double > GetWeights() const
int findBestSplit(vigra::MultiArrayView< 2, T, C > features, vigra::MultiArrayView< 2, T2, C2 > labels, Region ®ion, vigra::ArrayVector< Region > &childRegions, Random &randint)
TFeatureCalculator GetFeatureCalculator() const