// NOTE(review): this span is a mangled listing -- original source line numbers
// are fused into the text, and this constructor's declarator, parameter list
// and braces were lost (source lines 15-18, 20 and 24-26 are missing). Restore
// from the real source before building.
// Opens the include guard, then starts a constructor of the class template.
13 #ifndef mitkLinearSplitting_cpp 14 #define mitkLinearSplitting_cpp 19 template<
class TLossAccumulator>
// Member-initializer list: point weights and random splitting start disabled,
// and the additional-data pointer starts out null.
21 m_UsePointWeights(false),
22 m_UseRandomSplit(false),
23 m_AdditionalData(nullptr)
// Second constructor of the class template. Only its template header and part
// of its member-initializer list survived this listing (source lines 28-29 and
// 32-35 are missing, so the remaining signature and the body are not visible).
// NOTE(review): unlike the constructor that initializes m_AdditionalData(nullptr),
// the visible initializer list here does not initialize m_AdditionalData. That
// pointer is later passed to every loss accumulator constructed in operator()
// and LossOfRegion, and is returned by GetAdditionalData(). If the missing lines
// do not assign it either, it is left indeterminate -- confirm and add
// m_AdditionalData(nullptr).
27 template<
class TLossAccumulator>
// Point weights and random splitting start disabled.
30 m_UsePointWeights(false),
31 m_UseRandomSplit(false)
36 template<
class TLossAccumulator>
40 m_AdditionalData = data;
43 template<
class TLossAccumulator>
47 return m_AdditionalData;
50 template<
class TLossAccumulator>
54 m_UsePointWeights = pointWeight;
57 template<
class TLossAccumulator>
61 return m_UsePointWeights;
65 template<
class TLossAccumulator>
69 m_UseRandomSplit = randomSplit;
72 template<
class TLossAccumulator>
76 return m_UseRandomSplit;
79 template<
class TLossAccumulator>
83 m_PointWeights = weight;
86 template<
class TLossAccumulator>
90 return m_PointWeights;
// Only the outer template header of this member definition survived the
// listing (source lines 94-100 are missing).
// NOTE(review): by position and by the member list this is presumably
// set_external_parameters(vigra::ProblemSpec<T> const &ext), which has no other
// visible definition. m_ExtParameter is passed to every loss accumulator in
// operator() and LossOfRegion, so confirm it is assigned here.
93 template<
class TLossAccumulator>
101 template<
class TLossAccumulator>
102 template <
class TDataSourceFeature,
class TDataSourceLabel,
class TDataIterator,
class TArray>
105 TDataSourceLabel
const &labels,
106 TDataIterator &begin,
108 TArray
const ®ionResponse)
110 typedef TLossAccumulator LineSearchLoss;
111 std::sort(begin, end, vigra::SortSamplesByDimensions<TDataSourceFeature>(column, 0));
113 LineSearchLoss left(labels, m_ExtParameter, m_AdditionalData);
114 LineSearchLoss right(labels, m_ExtParameter, m_AdditionalData);
116 if (m_UsePointWeights)
118 left.UsePointWeights(
true);
119 left.SetPointWeights(m_PointWeights);
120 right.UsePointWeights(
true);
121 right.SetPointWeights(m_PointWeights);
124 m_MinimumLoss = right.Init(regionResponse);
125 m_MinimumThreshold = *begin;
128 vigra::DimensionNotEqual<TDataSourceFeature> compareNotEqual(column, 0);
130 if (!m_UseRandomSplit)
132 TDataIterator iter = begin;
134 TDataIterator next = std::adjacent_find(iter, end, compareNotEqual);
139 double rightLoss = right.Decrement(iter, next +1);
140 double leftLoss = left.Increment(iter, next +1);
141 double currentLoss = rightLoss + leftLoss;
143 if (currentLoss < m_MinimumLoss)
145 m_BestCurrentCounts[0] = left.Response();
146 m_BestCurrentCounts[1] = right.Response();
147 m_MinimumLoss = currentLoss;
148 m_MinimumIndex = next - begin + 1;
149 m_MinimumThreshold = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
153 next = std::adjacent_find(iter, end, compareNotEqual);
158 int size = end - begin + 1;
159 srand(time(
nullptr));
160 int offset = rand() % size;
161 TDataIterator iter = begin +
offset;
163 double rightLoss = right.Decrement(begin, iter+1);
164 double leftLoss = left.Increment(begin, iter+1);
165 double currentLoss = rightLoss + leftLoss;
167 if (currentLoss < m_MinimumLoss)
169 m_BestCurrentCounts[0] = left.Response();
170 m_BestCurrentCounts[1] = right.Response();
171 m_MinimumLoss = currentLoss;
172 m_MinimumIndex = offset + 1;
173 m_MinimumThreshold = (double(column(*iter,0)) + double(column(*(iter+1), 0)))/2.0;
178 template<
class TLossAccumulator>
179 template <
class TDataSourceLabel,
class TDataIterator,
class TArray>
184 TArray
const & regionResponse)
186 typedef TLossAccumulator LineSearchLoss;
187 LineSearchLoss regionLoss(labels, m_ExtParameter, m_AdditionalData);
188 if (m_UsePointWeights)
190 regionLoss.UsePointWeights(
true);
191 regionLoss.SetPointWeights(m_PointWeights);
193 return regionLoss.Init(regionResponse);
196 #endif //mitkLinearSplitting_cpp void operator()(TDataSourceFeature const &column, TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const ®ionResponse)
void SetPointWeights(WeightContainerType weight)
void set_external_parameters(vigra::ProblemSpec< T > const &ext)
bool IsUsingRandomSplit()
void UseRandomSplit(bool randomSplit)
void UsePointWeights(bool pointWeight)
bool IsUsingPointWeights()
AdditionalRFDataAbstract * GetAdditionalData() const
WeightContainerType GetPointWeights()
void SetAdditionalData(AdditionalRFDataAbstract *data)
TWeightContainer WeightContainerType
double LossOfRegion(TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const ®ionResponse)