1 #ifndef mitkLinearSplitting_cpp
2 #define mitkLinearSplitting_cpp
6 template<
class TLossAccumulator>
8 m_UsePointWeights(false),
9 m_UseRandomSplit(false)
13 template<
class TLossAccumulator>
16 m_UsePointWeights(false),
17 m_UseRandomSplit(false)
22 template<
class TLossAccumulator>
26 m_UsePointWeights = pointWeight;
29 template<
class TLossAccumulator>
33 return m_UsePointWeights;
37 template<
class TLossAccumulator>
41 m_UseRandomSplit = randomSplit;
44 template<
class TLossAccumulator>
48 return m_UseRandomSplit;
51 template<
class TLossAccumulator>
55 m_PointWeights = weight;
58 template<
class TLossAccumulator>
62 return m_PointWeights;
65 template<
class TLossAccumulator>
73 template<
class TLossAccumulator>
74 template <
class TDataSourceFeature,
class TDataSourceLabel,
class TDataIterator,
class TArray>
77 TDataSourceLabel
const &labels,
80 TArray
const ®ionResponse)
82 typedef TLossAccumulator LineSearchLoss;
83 std::sort(begin, end, vigra::SortSamplesByDimensions<TDataSourceFeature>(column, 0));
85 LineSearchLoss left(labels, m_ExtParameter);
86 LineSearchLoss right(labels, m_ExtParameter);
88 if (m_UsePointWeights)
90 left.UsePointWeights(
true);
91 left.SetPointWeights(m_PointWeights);
92 right.UsePointWeights(
true);
93 right.SetPointWeights(m_PointWeights);
96 m_MinimumLoss = right.Init(regionResponse);
97 m_MinimumThreshold = *begin;
100 vigra::DimensionNotEqual<TDataSourceFeature> compareNotEqual(column, 0);
102 if (!m_UseRandomSplit)
104 TDataIterator iter = begin;
106 TDataIterator next = std::adjacent_find(iter, end, compareNotEqual);
111 double rightLoss = right.Decrement(iter, next +1);
112 double leftLoss = left.Increment(iter, next +1);
113 double currentLoss = rightLoss + leftLoss;
115 if (currentLoss < m_MinimumLoss)
117 m_BestCurrentCounts[0] = left.Response();
118 m_BestCurrentCounts[1] = right.Response();
119 m_MinimumLoss = currentLoss;
120 m_MinimumIndex = next - begin + 1;
121 m_MinimumThreshold = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
125 next = std::adjacent_find(iter, end, compareNotEqual);
130 int size = end - begin + 1;
132 int offset = rand() % size;
133 TDataIterator iter = begin +
offset;
135 double rightLoss = right.Decrement(begin, iter+1);
136 double leftLoss = left.Increment(begin, iter+1);
137 double currentLoss = rightLoss + leftLoss;
139 if (currentLoss < m_MinimumLoss)
141 m_BestCurrentCounts[0] = left.Response();
142 m_BestCurrentCounts[1] = right.Response();
143 m_MinimumLoss = currentLoss;
144 m_MinimumIndex = offset + 1;
145 m_MinimumThreshold = (double(column(*iter,0)) + double(column(*(iter+1), 0)))/2.0;
150 template<
class TLossAccumulator>
151 template <
class TDataSourceLabel,
class TDataIterator,
class TArray>
156 TArray
const & regionResponse)
158 typedef TLossAccumulator LineSearchLoss;
159 LineSearchLoss regionLoss(labels, m_ExtParameter);
160 if (m_UsePointWeights)
162 regionLoss.UsePointWeights(
true);
163 regionLoss.SetPointWeights(m_PointWeights);
165 return regionLoss.Init(regionResponse);
168 #endif //mitkLinearSplitting_cpp
void operator()(TDataSourceFeature const &column, TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const ®ionResponse)
void SetPointWeights(WeightContainerType weight)
void set_external_parameters(vigra::ProblemSpec< T > const &ext)
bool IsUsingRandomSplit()
void UseRandomSplit(bool randomSplit)
void UsePointWeights(bool pointWeight)
bool IsUsingPointWeights()
WeightContainerType GetPointWeights()
TWeightContainer WeightContainerType
double LossOfRegion(TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const ®ionResponse)