Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
mitkPURFClassifier.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
// MITK includes
#include <mitkPURFClassifier.h>
#include <mitkThresholdSplit.h>
#include <mitkPUImpurityLoss.h>
#include <mitkImpurityLoss.h>
#include <mitkLinearSplitting.h>
#include <mitkProperties.h>

// Vigra includes
#include <vigra/random_forest.hxx>
#include <vigra/random_forest/rf_split.hxx>

// ITK include
#include <itkFastMutexLock.h>
#include <itkMultiThreader.h>
#include <itkCommand.h>

// STL includes
#include <cmath>
#include <memory>
#include <stdexcept>
29 
31 
32 struct mitk::PURFClassifier::Parameter
33 {
34  vigra::RF_OptionTag Stratification;
35  bool SampleWithReplacement;
36  bool UseRandomSplit;
37  bool UsePointBasedWeights;
38  int TreeCount;
39  int MinimumSplitNodeSize;
40  int TreeDepth;
41  double Precision;
42  double WeightLambda;
43  double SamplesPerTree;
44 };
45 
46 struct mitk::PURFClassifier::TrainingData
47 {
48  TrainingData(unsigned int numberOfTrees,
49  const vigra::RandomForest<int> & refRF,
50  const DefaultPUSplitType & refSplitter,
51  const vigra::MultiArrayView<2, double> refFeature,
52  const vigra::MultiArrayView<2, int> refLabel,
53  const Parameter parameter)
54  : m_ClassCount(0),
55  m_NumberOfTrees(numberOfTrees),
56  m_RandomForest(refRF),
57  m_Splitter(refSplitter),
58  m_Feature(refFeature),
59  m_Label(refLabel),
60  m_Parameter(parameter)
61  {
62  m_mutex = itk::FastMutexLock::New();
63  }
64  vigra::ArrayVector<vigra::RandomForest<int>::DecisionTree_t> trees_;
65 
66  int m_ClassCount;
67  unsigned int m_NumberOfTrees;
68  const vigra::RandomForest<int> & m_RandomForest;
69  const DefaultPUSplitType & m_Splitter;
70  const vigra::MultiArrayView<2, double> m_Feature;
71  const vigra::MultiArrayView<2, int> m_Label;
72  itk::FastMutexLock::Pointer m_mutex;
73  Parameter m_Parameter;
74 };
75 
76 struct mitk::PURFClassifier::PredictionData
77 {
78  PredictionData(const vigra::RandomForest<int> & refRF,
79  const vigra::MultiArrayView<2, double> refFeature,
80  vigra::MultiArrayView<2, int> refLabel,
81  vigra::MultiArrayView<2, double> refProb,
82  vigra::MultiArrayView<2, double> refTreeWeights)
83  : m_RandomForest(refRF),
84  m_Feature(refFeature),
85  m_Label(refLabel),
86  m_Probabilities(refProb),
87  m_TreeWeights(refTreeWeights)
88  {
89  }
90  const vigra::RandomForest<int> & m_RandomForest;
91  const vigra::MultiArrayView<2, double> m_Feature;
92  vigra::MultiArrayView<2, int> m_Label;
93  vigra::MultiArrayView<2, double> m_Probabilities;
94  vigra::MultiArrayView<2, double> m_TreeWeights;
95 };
96 
98  :m_Parameter(nullptr)
99 {
  // NOTE(review): the declaration line of this definition is missing from this
  // listing; the initializer list and body presumably belong to the
  // PURFClassifier default constructor -- confirm against the original source.
  // m_Parameter starts out as nullptr; ConvertParameter() allocates it on first use.
  // Register ConvertParameter() as an observer of the property list's
  // itk::ModifiedEvent, so m_Parameter is re-read from the properties on change.
100  itk::SimpleMemberCommand<mitk::PURFClassifier>::Pointer command = itk::SimpleMemberCommand<mitk::PURFClassifier>::New();
101  command->SetCallbackFunction(this, &mitk::PURFClassifier::ConvertParameter);
102  this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command );
103 }
104
106 {
  // NOTE(review): the declaration line of this definition is missing from this
  // listing; presumably the PURFClassifier destructor -- confirm.
  // NOTE(review): m_Parameter is allocated with `new Parameter()` in
  // ConvertParameter() but is not released in this empty body, which looks like
  // a leak. Before adding `delete m_Parameter;`, confirm the class is never
  // copied (a shallow copy of the pointer would then be deleted twice).
107 }
108 
109 void mitk::PURFClassifier::SetClassProbabilities(Eigen::VectorXd probabilities)
110 {
111  m_ClassProbabilities = probabilities;
112 }
113 
115 {
  // NOTE(review): the declaration line is missing from this listing; presumably
  // GetClassProbabilites(), returning a copy of the vector stored by
  // SetClassProbabilities() -- confirm.
116  return m_ClassProbabilities;
117 }
118
120 {
  // NOTE(review): the declaration line is missing from this listing; presumably
  // one of the SupportsPointWise*() capability queries -- confirm.
  // Unconditionally reports support.
121  return true;
122 }
123
125 {
  // NOTE(review): the declaration line is missing from this listing; presumably
  // one of the SupportsPointWise*() capability queries -- confirm.
  // Unconditionally reports support.
126  return true;
127 }
128 
129 
130 vigra::ArrayVector<double> mitk::PURFClassifier::CalculateKappa(const Eigen::MatrixXd & /* X_in */, const Eigen::MatrixXi & Y_in)
131 {
132  int maximumValue = Y_in.maxCoeff();
133  vigra::ArrayVector<double> kappa(maximumValue + 1);
134  vigra::ArrayVector<double> counts(maximumValue + 1);
135  for (int i = 0; i < Y_in.rows(); ++i)
136  {
137  counts[Y_in(i, 0)] += 1;
138  }
139  for (int i = 0; i < maximumValue+1; ++i)
140  {
141  if (counts[i] > 0)
142  {
143  kappa[i] = counts[0] * m_ClassProbabilities[i] / counts[i] + 1;
144  }
145  else
146  {
147  kappa[i] = 1;
148  }
149  }
150  return kappa;
151 }
152 
153 
154 void mitk::PURFClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in)
155 {
156  this->ConvertParameter();
157 
158  PURFData* purfData = new PURFData;
159  purfData->m_Kappa = this->CalculateKappa(X_in, Y_in);
160 
161  DefaultPUSplitType splitter;
162  splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights);
163  splitter.UseRandomSplit(m_Parameter->UseRandomSplit);
164  splitter.SetPrecision(m_Parameter->Precision);
165  splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth);
166  splitter.SetAdditionalData(purfData);
167 
168  // Weights handled as member variable
169  if (m_Parameter->UsePointBasedWeights)
170  {
171  // Set influence of the weight (0 no influenc to 1 max influence)
172  this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;});
173 
174  vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data());
175  splitter.SetWeights(W);
176  }
177 
178  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
179  vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
180 
181  m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated;
182 
183  m_RandomForest.set_options().use_stratification(m_Parameter->Stratification);
184  m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement);
185  m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree);
186  m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize);
187 
188  m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter);
189 
190  std::unique_ptr<TrainingData> data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter));
191 
192  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
193  threader->SetSingleMethod(this->TrainTreesCallback,data.get());
194  threader->SingleMethodExecute();
195 
196  // set result trees
197  m_RandomForest.set_options().tree_count(m_Parameter->TreeCount);
198  m_RandomForest.ext_param_.class_count_ = data->m_ClassCount;
199  m_RandomForest.trees_ = data->trees_;
200 
201  // Set Tree Weights to default
202  m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1);
203  m_TreeWeights.fill(1.0);
204  delete purfData;
205 }
206 
207 Eigen::MatrixXi mitk::PURFClassifier::Predict(const Eigen::MatrixXd &X_in)
208 {
209  // Initialize output Eigen matrices
210  m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
211  m_OutProbability.fill(0);
212  m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
213  m_OutLabel.fill(0);
214 
215  // If no weights provided
216  if(m_TreeWeights.rows() != m_RandomForest.tree_count())
217  {
218  m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
219  m_TreeWeights.fill(1);
220  }
221 
222  vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
223  vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
224  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
225  vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
226 
227  std::unique_ptr<PredictionData> data;
228  data.reset(new PredictionData(m_RandomForest, X, Y, P, TW));
229 
230  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
231  threader->SetSingleMethod(this->PredictCallback, data.get());
232  threader->SingleMethodExecute();
233 
234  m_Probabilities = data->m_Probabilities;
235  return m_OutLabel;
236 }
237 
238 ITK_THREAD_RETURN_TYPE mitk::PURFClassifier::TrainTreesCallback(void * arg)
239 {
240  // Get the ThreadInfoStruct
241  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
242  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
243 
244  TrainingData * data = (TrainingData *)(infoStruct->UserData);
245  unsigned int numberOfTreesToCalculate = 0;
246 
247  // define the number of tress the forest have to calculate
248  numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads;
249 
250  // the 0th thread takes the residuals
251  if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads;
252 
253  if(numberOfTreesToCalculate != 0){
254  // Copy the Treestructure defined in userData
255  vigra::RandomForest<int> rf = data->m_RandomForest;
256 
257  // Initialize a splitter for the leraning process
258  DefaultPUSplitType splitter;
259  splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights());
260  splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit());
261  splitter.SetPrecision(data->m_Splitter.GetPrecision());
262  splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth());
263  splitter.SetWeights(data->m_Splitter.GetWeights());
264  splitter.SetAdditionalData(data->m_Splitter.GetAdditionalData());
265 
266  rf.trees_.clear();
267  rf.set_options().tree_count(numberOfTreesToCalculate);
268  rf.set_options().use_stratification(data->m_Parameter.Stratification);
269  rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement);
270  rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree);
271  rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize);
272  rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter);
273 
274  data->m_mutex->Lock();
275 
276  for(const auto & tree : rf.trees_)
277  data->trees_.push_back(tree);
278 
279  data->m_ClassCount = rf.class_count();
280  data->m_mutex->Unlock();
281  }
282 
283  return ITK_THREAD_RETURN_VALUE;
284 
285 }
286 
287 ITK_THREAD_RETURN_TYPE mitk::PURFClassifier::PredictCallback(void * arg)
288 {
289  // Get the ThreadInfoStruct
290  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
291  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
292  // assigne the thread id
293  const unsigned int threadId = infoStruct->ThreadID;
294 
295  // Get the user defined parameters containing all
296  // neccesary informations
297  PredictionData * data = (PredictionData *)(infoStruct->UserData);
298  unsigned int numberOfRowsToCalculate = 0;
299 
300  // Get number of rows to calculate
301  numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
302 
303  unsigned int start_index = numberOfRowsToCalculate * threadId;
304  unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
305 
306  // the last thread takes the residuals
307  if(threadId == infoStruct->NumberOfThreads-1) {
308  end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
309  }
310 
311  vigra::MultiArrayView<2, double> split_features;
312  vigra::MultiArrayView<2, int> split_labels;
313  vigra::MultiArrayView<2, double> split_probability;
314  {
315  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
316  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
317  split_features = data->m_Feature.subarray(lowerBound,upperBound);
318  }
319 
320  {
321  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
322  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
323  split_labels = data->m_Label.subarray(lowerBound,upperBound);
324  }
325 
326  {
327  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
328  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
329  split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
330  }
331 
332  data->m_RandomForest.predictLabels(split_features,split_labels);
333  data->m_RandomForest.predictProbabilities(split_features, split_probability);
334 
335 
336  return ITK_THREAD_RETURN_VALUE;
337 
338 }
339 
341 {
  // NOTE(review): the declaration line is missing from this listing; this is
  // presumably ConvertParameter() -- confirm.
  // Copies every classifier parameter from the property list into m_Parameter,
  // allocating m_Parameter on first use. Keys that are not set fall back to the
  // hard-coded defaults on the same line.
342  if(this->m_Parameter == nullptr)
343  this->m_Parameter = new Parameter();
344  // Read the properties; unset keys receive the defaults below
345
346  MITK_INFO("PURFClassifier") << "Convert Parameter";
347  if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false;
348  if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false;
349  if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20;
350  if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100;
351  if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5;
352  if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps;
353  if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6;
354  if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true;
355  if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // NOTE(review): previously marked "Not used yet"; Train() references WeightLambda for point-based weights -- confirm whether it takes effect
356  // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification))
  // NOTE(review): the disabled line above reuses the "samplewithreplacement" key,
  // which looks like a copy-paste slip; stratification currently has no property.
357  this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given
358 }
359 
360 void mitk::PURFClassifier::PrintParameter(std::ostream & str)
361 {
362  if(this->m_Parameter == nullptr)
363  {
364  MITK_WARN("PURFClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!";
365  return;
366  }
367 
368  this->ConvertParameter();
369 
370  // Get the proerty // Some defaults
371  if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights))
372  str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n";
373  else
374  str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n";
375 
376  if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit))
377  str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n";
378  else
379  str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n";
380 
381  if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth))
382  str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n";
383  else
384  str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n";
385 
386  if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize))
387  str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n";
388  else
389  str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n";
390 
391  if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision))
392  str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n";
393  else
394  str << "precision\t\t" << this->m_Parameter->Precision << "\n";
395 
396  if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree))
397  str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n";
398  else
399  str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n";
400 
401  if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement))
402  str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n";
403  else
404  str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n";
405 
406  if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount))
407  str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n";
408  else
409  str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n";
410 
411  if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda))
412  str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n";
413  else
414  str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n";
415 
416  // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification))
417  // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given
418 }
419 
421 {
  // NOTE(review): the declaration line is missing from this listing, and a body
  // line appears to be missing as well (the source numbering skips 422).
  // This is presumably UsePointWiseWeight(bool val), possibly also forwarding to
  // the base class -- confirm.
  // Like the setters below, this only writes the property. m_Parameter picks the
  // value up when ConvertParameter() runs (e.g. at the start of Train()).
423  this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val);
424 }
425
427 {
  // NOTE(review): declaration line missing; presumably SetMaximumTreeDepth(int val).
  // Stores the maximum tree depth under "treedepth".
428  this->GetPropertyList()->SetIntProperty("treedepth",val);
429 }
430
432 {
  // NOTE(review): declaration line missing; presumably SetMinimumSplitNodeSize(int val).
  // Stores the minimum split node size under "minimalsplitnodesize".
433  this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val);
434 }
435
437 {
  // NOTE(review): declaration line missing; presumably SetPrecision(double val).
  // Stores the splitter precision under "precision".
438  this->GetPropertyList()->SetDoubleProperty("precision",val);
439 }
440
442 {
  // NOTE(review): declaration line missing; presumably SetSamplesPerTree(double val).
  // Stores the per-tree sample proportion under "samplespertree".
443  this->GetPropertyList()->SetDoubleProperty("samplespertree",val);
444 }
445
447 {
  // NOTE(review): declaration line missing; presumably UseSampleWithReplacement(bool val).
  // Stores the bootstrap flag under "samplewithreplacement".
448  this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val);
449 }
450
452 {
  // NOTE(review): declaration line missing; presumably SetTreeCount(int val).
  // Stores the number of trees under "treecount".
453  this->GetPropertyList()->SetIntProperty("treecount",val);
454 }
455
457 {
  // NOTE(review): declaration line missing; presumably the setter for WeightLambda.
  // Stores the point-weight exponent under "lambda".
458  this->GetPropertyList()->SetDoubleProperty("lambda",val);
459 }
460 
461 void mitk::PURFClassifier::SetRandomForest(const vigra::RandomForest<int> & rf)
462 {
463  this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth);
464  this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_);
465  this->SetTreeCount(rf.options().tree_count_);
466  this->SetSamplesPerTree(rf.options().training_set_proportion_);
467  this->UseSampleWithReplacement(rf.options().sample_with_replacement_);
468  this->m_RandomForest = rf;
469 }
470 
471 const vigra::RandomForest<int> & mitk::PURFClassifier::GetRandomForest() const
472 {
473  return this->m_RandomForest;
474 }
vigra::ArrayVector< double > CalculateKappa(const Eigen::MatrixXd &X_in, const Eigen::MatrixXi &Y_in)
#define MITK_INFO
Definition: mitkLogMacros.h:18
void SetAdditionalData(AdditionalRFDataAbstract *data)
void UseRandomSplit(bool split)
void UsePointBasedWeights(bool weightsOn)
void PrintParameter(std::ostream &str=std::cout)
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override
Build a forest of trees from the training set (X, y).
#define MITK_WARN
Definition: mitkLogMacros.h:19
void SetRandomForest(const vigra::RandomForest< int > &rf)
void UsePointWiseWeight(bool) override
UsePointWiseWeight.
void SetWeights(vigra::MultiArrayView< 2, double > weights)
mitk::ThresholdSplit< mitk::LinearSplitting< mitk::PUImpurityLoss<> >, int, vigra::ClassificationTag > DefaultPUSplitType
virtual void UsePointWiseWeight(bool value)
UsePointWiseWeight.
const vigra::RandomForest< int > & GetRandomForest() const
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override
Predict class for X.
mitk::PropertyList::Pointer GetPropertyList() const
Get the data's property list.
vigra::ArrayVector< double > m_Kappa
bool SupportsPointWiseProbability() override
SupportsPointWiseProbability.
MITKCORE_EXPORT const ScalarType eps
void SetClassProbabilities(Eigen::VectorXd probabilities)
Eigen::VectorXd GetClassProbabilites()
bool SupportsPointWiseWeight() override
SupportsPointWiseWeight.