Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
mitkLinearSplitting.cpp
Go to the documentation of this file.
1 #ifndef mitkLinearSplitting_cpp
2 #define mitkLinearSplitting_cpp
3 
4 #include <mitkLinearSplitting.h>
5 
// Constructor of mitk::LinearSplitting.
// NOTE(review): the signature line (listing line 7) is missing from this
// excerpt; presumably this is the default constructor — confirm.
// Point weighting and random splitting both start disabled.
6 template<class TLossAccumulator>
8  m_UsePointWeights(false),
9  m_UseRandomSplit(false)
10 {
11 }
12 
// Constructor overload templated on T.
// NOTE(review): the signature (listing line 15) and the only body statement
// (listing line 19) are missing from this excerpt. Presumably it takes a
// vigra::ProblemSpec<T> and forwards it to set_external_parameters — confirm
// against the header.
// Point weighting and random splitting both start disabled.
13 template<class TLossAccumulator>
14 template <class T>
16  m_UsePointWeights(false),
17  m_UseRandomSplit(false)
18 {
20 }
21 
// Enables or disables per-sample weighting by storing `pointWeight` in
// m_UsePointWeights. When it is true, operator() and LossOfRegion hand
// m_PointWeights to their loss accumulators.
// NOTE(review): the signature line (listing line 24) is missing from this excerpt.
22 template<class TLossAccumulator>
23 void
25 {
26  m_UsePointWeights = pointWeight;
27 }
28 
// Returns whether per-sample weighting is enabled (m_UsePointWeights).
// NOTE(review): the function-name line (listing line 31) is missing from this
// excerpt, so the accessor's exact name is not visible here.
29 template<class TLossAccumulator>
30 bool
32 {
33  return m_UsePointWeights;
34 }
35 
36 
// Enables or disables random (ExtraTree-like) split selection in operator()
// by storing `randomSplit` in m_UseRandomSplit.
// NOTE(review): the signature line (listing line 39) is missing from this excerpt.
37 template<class TLossAccumulator>
38 void
40 {
41  m_UseRandomSplit = randomSplit;
42 }
43 
// Returns whether random split selection is enabled (m_UseRandomSplit).
// NOTE(review): the function-name line (listing line 46) is missing from this
// excerpt, so the accessor's exact name is not visible here.
44 template<class TLossAccumulator>
45 bool
47 {
48  return m_UseRandomSplit;
49 }
50 
// Stores the per-sample weight container in m_PointWeights. It is only
// forwarded to the loss accumulators when m_UsePointWeights is set.
// NOTE(review): the signature line (listing line 53) is missing from this excerpt.
51 template<class TLossAccumulator>
52 void
54 {
55  m_PointWeights = weight;
56 }
57 
// Returns the stored per-sample weight container (m_PointWeights).
// NOTE(review): the return-type and name lines (listing lines 59-60) are
// missing from this excerpt. The member list suggests
// `WeightContainerType GetPointWeights()`, i.e. presumably a copy is returned.
58 template<class TLossAccumulator>
61 {
62  return m_PointWeights;
63 }
64 
// Stores the external problem specification in m_ExtParameter. The loss
// accumulators built in operator() and LossOfRegion are constructed with it.
// NOTE(review): the signature line (listing line 68) is missing from this
// excerpt; the member list shows
// set_external_parameters(vigra::ProblemSpec<T> const &ext).
65 template<class TLossAccumulator>
66 template <class T>
67 void
69 {
70  m_ExtParameter = ext;
71 }
72 
73 template<class TLossAccumulator>
74 template <class TDataSourceFeature, class TDataSourceLabel, class TDataIterator, class TArray>
75 void
76 mitk::LinearSplitting<TLossAccumulator>::operator()(TDataSourceFeature const &column,
77  TDataSourceLabel const &labels,
78  TDataIterator &begin,
79  TDataIterator &end,
80  TArray const &regionResponse)
81 {
82  typedef TLossAccumulator LineSearchLoss;
83  std::sort(begin, end, vigra::SortSamplesByDimensions<TDataSourceFeature>(column, 0));
84 
85  LineSearchLoss left(labels, m_ExtParameter);
86  LineSearchLoss right(labels, m_ExtParameter);
87 
88  if (m_UsePointWeights)
89  {
90  left.UsePointWeights(true);
91  left.SetPointWeights(m_PointWeights);
92  right.UsePointWeights(true);
93  right.SetPointWeights(m_PointWeights);
94  }
95 
96  m_MinimumLoss = right.Init(regionResponse);
97  m_MinimumThreshold = *begin;
98  m_MinimumIndex = 0;
99 
100  vigra::DimensionNotEqual<TDataSourceFeature> compareNotEqual(column, 0);
101 
102  if (!m_UseRandomSplit)
103  {
104  TDataIterator iter = begin;
105  // Find the next element that are NOT equal with his neightbour!
106  TDataIterator next = std::adjacent_find(iter, end, compareNotEqual);
107 
108  while(next != end)
109  {
110  // Remove or add the current segment are from the LineSearch
111  double rightLoss = right.Decrement(iter, next +1);
112  double leftLoss = left.Increment(iter, next +1);
113  double currentLoss = rightLoss + leftLoss;
114 
115  if (currentLoss < m_MinimumLoss)
116  {
117  m_BestCurrentCounts[0] = left.Response();
118  m_BestCurrentCounts[1] = right.Response();
119  m_MinimumLoss = currentLoss;
120  m_MinimumIndex = next - begin + 1;
121  m_MinimumThreshold = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
122  }
123 
124  iter = next + 1;
125  next = std::adjacent_find(iter, end, compareNotEqual);
126  }
127  }
128  else // If Random split is selected, e.g. ExtraTree behaviour
129  {
130  int size = end - begin + 1;
131  srand(time(NULL));
132  int offset = rand() % size;
133  TDataIterator iter = begin + offset;
134 
135  double rightLoss = right.Decrement(begin, iter+1);
136  double leftLoss = left.Increment(begin, iter+1);
137  double currentLoss = rightLoss + leftLoss;
138 
139  if (currentLoss < m_MinimumLoss)
140  {
141  m_BestCurrentCounts[0] = left.Response();
142  m_BestCurrentCounts[1] = right.Response();
143  m_MinimumLoss = currentLoss;
144  m_MinimumIndex = offset + 1;
145  m_MinimumThreshold = (double(column(*iter,0)) + double(column(*(iter+1), 0)))/2.0;
146  }
147  }
148 }
149 
// Computes the loss of a whole region without splitting it.
// It builds one loss accumulator from `labels` and m_ExtParameter, attaching
// m_PointWeights when m_UsePointWeights is set, and returns that
// accumulator's Init(regionResponse) value.
// The begin/end iterators are unused (their parameter names are commented out).
// NOTE(review): the name/first-parameter line (listing line 153) is missing
// from this excerpt; the member list shows
// LossOfRegion(TDataSourceLabel const &labels, TDataIterator &begin, ...).
150 template<class TLossAccumulator>
151 template <class TDataSourceLabel, class TDataIterator, class TArray>
152 double
154  TDataIterator &/*begin*/,
155  TDataIterator &/*end*/,
156  TArray const & regionResponse)
157 {
158  typedef TLossAccumulator LineSearchLoss;
159  LineSearchLoss regionLoss(labels, m_ExtParameter);
160  if (m_UsePointWeights)
161  {
162  regionLoss.UsePointWeights(true);
163  regionLoss.SetPointWeights(m_PointWeights);
164  }
165  return regionLoss.Init(regionResponse);
166 }
167 
168 #endif //mitkLinearSplitting_cpp
void operator()(TDataSourceFeature const &column, TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const &regionResponse)
void SetPointWeights(WeightContainerType weight)
void set_external_parameters(vigra::ProblemSpec< T > const &ext)
void UseRandomSplit(bool randomSplit)
void UsePointWeights(bool pointWeight)
static Vector3D offset
WeightContainerType GetPointWeights()
TWeightContainer WeightContainerType
double LossOfRegion(TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const &regionResponse)