Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
mitkVigraRandomForestClassifier.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 // MITK includes
19 #include <mitkThresholdSplit.h>
20 #include <mitkImpurityLoss.h>
21 #include <mitkLinearSplitting.h>
22 #include <mitkProperties.h>
23 
24 // Vigra includes
25 #include <vigra/random_forest.hxx>
26 #include <vigra/random_forest/rf_split.hxx>
27 
28 // ITK include
29 #include <itkFastMutexLock.h>
30 #include <itkMultiThreader.h>
31 #include <itkCommand.h>
32 
34 
35 struct mitk::VigraRandomForestClassifier::Parameter
36 {
37  vigra::RF_OptionTag Stratification;
38  bool SampleWithReplacement;
39  bool UseRandomSplit;
40  bool UsePointBasedWeights;
41  int TreeCount;
42  int MinimumSplitNodeSize;
43  int TreeDepth;
44  double Precision;
45  double WeightLambda;
46  double SamplesPerTree;
47 };
48 
// Shared state for the training worker threads started in Train().
// The forest and splitter are held by reference and the feature/label arrays
// are non-owning vigra views, so the referenced objects must outlive the
// worker threads. Workers append their trees to trees_ and store the class
// count while holding m_mutex (see TrainTreesCallback()).
// NOTE(review): listing line 75 is missing from this page; presumably it
// declares m_mutex (created with itk::FastMutexLock::New() below) — confirm.
49 struct mitk::VigraRandomForestClassifier::TrainingData
50 {
51  TrainingData(unsigned int numberOfTrees,
52  const vigra::RandomForest<int> & refRF,
53  const DefaultSplitType & refSplitter,
54  const vigra::MultiArrayView<2, double> refFeature,
55  const vigra::MultiArrayView<2, int> refLabel,
56  const Parameter parameter)
57  : m_ClassCount(0),
58  m_NumberOfTrees(numberOfTrees),
59  m_RandomForest(refRF),
60  m_Splitter(refSplitter),
61  m_Feature(refFeature),
62  m_Label(refLabel),
63  m_Parameter(parameter)
64  {
65  m_mutex = itk::FastMutexLock::New();
66  }
// Trees collected from all worker threads (appended under m_mutex).
67  vigra::ArrayVector<vigra::RandomForest<int>::DecisionTree_t> trees_;
68 
// Class count reported by the last worker that finished.
69  int m_ClassCount;
// Total number of trees to grow, split across the threads.
70  unsigned int m_NumberOfTrees;
// Forest each worker copies (options/parameters) before growing its own trees.
71  const vigra::RandomForest<int> & m_RandomForest;
// Splitter whose settings each worker copies into its own splitter.
72  const DefaultSplitType & m_Splitter;
73  const vigra::MultiArrayView<2, double> m_Feature;
74  const vigra::MultiArrayView<2, int> m_Label;
76  Parameter m_Parameter;
77 };
78 
79 struct mitk::VigraRandomForestClassifier::PredictionData
80 {
81  PredictionData(const vigra::RandomForest<int> & refRF,
82  const vigra::MultiArrayView<2, double> refFeature,
83  vigra::MultiArrayView<2, int> refLabel,
84  vigra::MultiArrayView<2, double> refProb,
85  vigra::MultiArrayView<2, double> refTreeWeights)
86  : m_RandomForest(refRF),
87  m_Feature(refFeature),
88  m_Label(refLabel),
89  m_Probabilities(refProb),
90  m_TreeWeights(refTreeWeights)
91  {
92  }
93  const vigra::RandomForest<int> & m_RandomForest;
94  const vigra::MultiArrayView<2, double> m_Feature;
95  vigra::MultiArrayView<2, int> m_Label;
96  vigra::MultiArrayView<2, double> m_Probabilities;
97  vigra::MultiArrayView<2, double> m_TreeWeights;
98 };
99 
// Constructor body. NOTE(review): the signature (listing line 100) and listing
// line 103 are missing from this page; line 103 presumably creates `command`,
// an ITK member command — confirm. m_Parameter starts as nullptr
// (ConvertParameter() allocates it). ConvertParameter() is registered as an
// observer of the property list's ModifiedEvent, so the parameters are
// re-read whenever the property list reports a modification.
101  :m_Parameter(nullptr)
102 {
104  command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter);
105  this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command );
106 }
107 
// NOTE(review): the signature (listing line 108) is missing from this page;
// presumably this is the destructor — confirm. The body is empty, so
// m_Parameter (allocated with `new` in ConvertParameter()) is not deleted
// here; check whether it is released elsewhere or leaks.
109 {
110 }
111 
// NOTE(review): the signature (listing line 112) is missing from this page; the
// Doxygen member index suggests SupportsPointWiseWeight() — confirm.
// Always reports support (returns true).
113 {
114  return true;
115 }
116 
// NOTE(review): the signature (listing line 117) is missing from this page; the
// Doxygen member index suggests SupportsPointWiseProbability() — confirm.
// Always reports support (returns true).
118 {
119  return true;
120 }
121 
122 void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in)
123 {
124  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
125  vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
126  m_RandomForest.onlineLearn(X,Y,0,true);
127 }
128 
129 void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in)
130 {
131  this->ConvertParameter();
132 
133  DefaultSplitType splitter;
134  splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights);
135  splitter.UseRandomSplit(m_Parameter->UseRandomSplit);
136  splitter.SetPrecision(m_Parameter->Precision);
137  splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth);
138 
139  // Weights handled as member variable
140  if (m_Parameter->UsePointBasedWeights)
141  {
142  // Set influence of the weight (0 no influenc to 1 max influence)
143  this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;});
144 
145  vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data());
146  splitter.SetWeights(W);
147  }
148 
149  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
150  vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
151 
152  m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated;
153 
154  m_RandomForest.set_options().use_stratification(m_Parameter->Stratification);
155  m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement);
156  m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree);
157  m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize);
158 
159  m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter);
160 
161  std::unique_ptr<TrainingData> data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter));
162 
164  threader->SetSingleMethod(this->TrainTreesCallback,data.get());
165  threader->SingleMethodExecute();
166 
167  // set result trees
168  m_RandomForest.set_options().tree_count(m_Parameter->TreeCount);
169  m_RandomForest.ext_param_.class_count_ = data->m_ClassCount;
170  m_RandomForest.trees_ = data->trees_;
171 
172  // Set Tree Weights to default
173  m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1);
174  m_TreeWeights.fill(1.0);
175 }
176 
// Predicts a class label for every row of X_in and returns the labels as a
// column vector. It also fills m_OutProbability (one row per sample, one
// column per class). Both outputs are zero-filled and wrapped as vigra views,
// then filled by PredictCallback(). itk::MultiThreader runs that callback on
// several threads, each handling a contiguous range of rows.
177 Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in)
178 {
179  // Initialize output Eigen matrices
180  m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
181  m_OutProbability.fill(0);
182  m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
183  m_OutLabel.fill(0);
184 
185  // If no matching weights were provided (row count != tree count), reset all tree weights to 1. PredictCallback() itself does not read them.
186  if(m_TreeWeights.rows() != m_RandomForest.tree_count())
187  {
188  m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
189  m_TreeWeights.fill(1);
190  }
191 
192 
193  vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
194  vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
195  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
196  vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
197 
198  std::unique_ptr<PredictionData> data;
199  data.reset( new PredictionData(m_RandomForest,X,Y,P,TW));
200 
// NOTE(review): listing line 201 is missing from this page; presumably it
// creates `threader` (an itk::MultiThreader) — confirm against the source.
202  threader->SetSingleMethod(this->PredictCallback,data.get());
203  threader->SingleMethodExecute();
204 
205  return m_OutLabel;
206 }
207 
// Like Predict(), but each tree's vote is scaled by its entry in m_TreeWeights
// (see VigraPredictWeighted()). Returns one label per row of X_in and also
// fills m_OutProbability. The work is split over threads by
// PredictWeightedCallback().
208 Eigen::MatrixXi mitk::VigraRandomForestClassifier::PredictWeighted(const Eigen::MatrixXd &X_in)
209 {
210  // Initialize output Eigen matrices
211  m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
212  m_OutProbability.fill(0);
213  m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
214  m_OutLabel.fill(0);
215 
216  // If no matching weights were provided (row count != tree count), fall back to uniform tree weights of 1.
217  if(m_TreeWeights.rows() != m_RandomForest.tree_count())
218  {
219  m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
220  m_TreeWeights.fill(1);
221  }
222 
223 
224  vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
225  vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
226  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
227  vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
228 
229  std::unique_ptr<PredictionData> data;
230  data.reset( new PredictionData(m_RandomForest,X,Y,P,TW));
231 
// NOTE(review): listing line 232 is missing from this page; presumably it
// creates `threader` (an itk::MultiThreader) — confirm against the source.
233  threader->SetSingleMethod(this->PredictWeightedCallback,data.get());
234  threader->SingleMethodExecute();
235 
236  return m_OutLabel;
237 }
238 
239 
240 
// NOTE(review): the signature (listing line 241) is missing from this page;
// presumably SetTreeWeights(weights) — confirm. It replaces all per-tree
// weights. Predict()/PredictWeighted() reset them to 1 if their row count does
// not match the forest's tree count.
242 {
243  m_TreeWeights = weights;
244 }
245 
// NOTE(review): the signature (listing line 246) is missing from this page;
// presumably GetTreeWeights() — confirm. Returns the per-tree weights
// (one row per tree, set up by Train()).
247 {
248  return m_TreeWeights;
249 }
250 
251 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg)
252 {
253  // Get the ThreadInfoStruct
254  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
255  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
256 
257  TrainingData * data = (TrainingData *)(infoStruct->UserData);
258  unsigned int numberOfTreesToCalculate = 0;
259 
260  // define the number of tress the forest have to calculate
261  numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads;
262 
263  // the 0th thread takes the residuals
264  if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads;
265 
266  if(numberOfTreesToCalculate != 0){
267  // Copy the Treestructure defined in userData
268  vigra::RandomForest<int> rf = data->m_RandomForest;
269 
270  // Initialize a splitter for the leraning process
271  DefaultSplitType splitter;
272  splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights());
273  splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit());
274  splitter.SetPrecision(data->m_Splitter.GetPrecision());
275  splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth());
276  splitter.SetWeights(data->m_Splitter.GetWeights());
277 
278  rf.trees_.clear();
279  rf.set_options().tree_count(numberOfTreesToCalculate);
280  rf.set_options().use_stratification(data->m_Parameter.Stratification);
281  rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement);
282  rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree);
283  rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize);
284  rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter);
285 
286  data->m_mutex->Lock();
287 
288  for(const auto & tree : rf.trees_)
289  data->trees_.push_back(tree);
290 
291  data->m_ClassCount = rf.class_count();
292  data->m_mutex->Unlock();
293  }
294 
295  return NULL;
296 
297 }
298 
299 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg)
300 {
301  // Get the ThreadInfoStruct
302  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
303  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
304  // assigne the thread id
305  const unsigned int threadId = infoStruct->ThreadID;
306 
307  // Get the user defined parameters containing all
308  // neccesary informations
309  PredictionData * data = (PredictionData *)(infoStruct->UserData);
310  unsigned int numberOfRowsToCalculate = 0;
311 
312  // Get number of rows to calculate
313  numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
314 
315  unsigned int start_index = numberOfRowsToCalculate * threadId;
316  unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
317 
318  // the last thread takes the residuals
319  if(threadId == infoStruct->NumberOfThreads-1) {
320  end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
321  }
322 
323  vigra::MultiArrayView<2, double> split_features;
324  vigra::MultiArrayView<2, int> split_labels;
325  vigra::MultiArrayView<2, double> split_probability;
326  {
327  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
328  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
329  split_features = data->m_Feature.subarray(lowerBound,upperBound);
330  }
331 
332  {
333  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
334  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
335  split_labels = data->m_Label.subarray(lowerBound,upperBound);
336  }
337 
338  {
339  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
340  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
341  split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
342  }
343 
344  data->m_RandomForest.predictLabels(split_features,split_labels);
345  data->m_RandomForest.predictProbabilities(split_features, split_probability);
346 
347 
348  return NULL;
349 
350 }
351 
352 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(void * arg)
353 {
354  // Get the ThreadInfoStruct
355  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
356  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
357  // assigne the thread id
358  const unsigned int threadId = infoStruct->ThreadID;
359 
360  // Get the user defined parameters containing all
361  // neccesary informations
362  PredictionData * data = (PredictionData *)(infoStruct->UserData);
363  unsigned int numberOfRowsToCalculate = 0;
364 
365  // Get number of rows to calculate
366  numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
367 
368  unsigned int start_index = numberOfRowsToCalculate * threadId;
369  unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
370 
371  // the last thread takes the residuals
372  if(threadId == infoStruct->NumberOfThreads-1) {
373  end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
374  }
375 
376  vigra::MultiArrayView<2, double> split_features;
377  vigra::MultiArrayView<2, int> split_labels;
378  vigra::MultiArrayView<2, double> split_probability;
379  {
380  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
381  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
382  split_features = data->m_Feature.subarray(lowerBound,upperBound);
383  }
384 
385  {
386  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
387  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
388  split_labels = data->m_Label.subarray(lowerBound,upperBound);
389  }
390 
391  {
392  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
393  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
394  split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
395  }
396 
397  VigraPredictWeighted(data, split_features,split_labels,split_probability);
398 
399  return NULL;
400 }
401 
402 
403 void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P)
404 {
405 
406  int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_;
407 //#pragma omp parallel for
408  for(int row=0; row < vigra::rowCount(X); ++row)
409  {
410  vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row));
411 
412  vigra::ArrayVector<double>::const_iterator weights;
413 
414  //totalWeight == totalVoteCount!
415  double totalWeight = 0.0;
416 
417  //Let each tree classify...
418  for(int k=0; k<data->m_RandomForest.options_.tree_count_; ++k)
419  {
420  //get weights predicted by single tree
421  weights = data->m_RandomForest.trees_[k /*tree_indices_[k]*/].predict(currentRow);
422  double numberOfLeafObservations = (*(weights-1));
423 
424  //update votecount.
425  for(int l=0; l<data->m_RandomForest.ext_param_.class_count_; ++l)
426  {
427  // Either the original weights are taken or the tree is additional weighted by the number of Observations in the leaf node.
428  double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted));
429  cur_w = cur_w * data->m_TreeWeights(k,0);
430  P(row, l) += (int)cur_w;
431  //every weight in totalWeight.
432  totalWeight += cur_w;
433  }
434  }
435 
436  //Normalise votes in each row by total VoteCount (totalWeight
437  for(int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l)
438  {
439  P(row, l) /= vigra::detail::RequiresExplicitCast<double>::cast(totalWeight);
440  }
441  int erg;
442  int maxCol = 0;
443  for (int col=0;col<data->m_RandomForest.class_count();++col)
444  {
445  if (data->m_Probabilities(row,col) > data->m_Probabilities(row, maxCol))
446  maxCol = col;
447  }
448  data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg);
449  Y(row,0) = erg;
450  }
451 }
452 
// NOTE(review): the signature (listing line 453) is missing from this page;
// presumably this is ConvertParameter(). The constructor registers it as a
// property-list observer, and Train()/PrintParameter() call it — confirm.
// It allocates m_Parameter on first use, then copies each property from the
// property list into it, using the default given here when a property is not
// set. It logs an info message on every call.
454 {
455  if(this->m_Parameter == nullptr)
456  this->m_Parameter = new Parameter();
457  // Read each property; if it is not set, fall back to the default given after it.
458 
459  MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter";
460  if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false;
461  if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false;
462  if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20;
463  if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100;
464  if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5;
465  if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps;
466  if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6;
467  if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true;
468  if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // "Not used yet" — NOTE(review): Train() does reference WeightLambda for point-wise weights; verify it takes effect
469  // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification))
// NOTE(review): the disabled line above reuses the "samplewithreplacement" key
// (looks like a copy-paste slip). Stratification has no property and is always RF_NONE.
470  this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given
471 }
472 
// NOTE(review): the signature (listing line 473) is missing from this page; per
// the member index this is likely PrintParameter(std::ostream &str) — confirm.
// Writes one line per training parameter to `str`. Each line shows the property
// value if it is set; otherwise it shows "NOT SET" together with the value
// currently in m_Parameter. That value is presumably the default
// ConvertParameter() just filled in, assuming Get() leaves it untouched on failure.
// NOTE(review): the function only warns and returns when m_Parameter is null,
// although ConvertParameter(), called right after, would allocate it.
474 {
475  if(this->m_Parameter == nullptr)
476  {
477  MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!";
478  return;
479  }
480 
481  this->ConvertParameter();
482 
483  // Print each property; if it is not set, print the default instead
484  if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights))
485  str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n";
486  else
487  str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n";
488 
489  if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit))
490  str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n";
491  else
492  str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n";
493 
494  if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth))
495  str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n";
496  else
497  str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n";
498 
499  if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize))
500  str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n";
501  else
502  str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n";
503 
504  if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision))
505  str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n";
506  else
507  str << "precision\t\t" << this->m_Parameter->Precision << "\n";
508 
509  if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree))
510  str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n";
511  else
512  str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n";
513 
514  if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement))
515  str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n";
516  else
517  str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n";
518 
519  if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount))
520  str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n";
521  else
522  str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n";
523 
524  if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda))
525  str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n";
526  else
527  str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n";
528 
529  // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification))
530  // this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given
531 }
532 
// NOTE(review): the signature (listing line 533) and listing line 535 are missing
// from this page; per the member index this is likely UsePointWiseWeight(bool val) — confirm.
// Stores `val` as the "usepointbasedweight" bool property (read by ConvertParameter()).
534 {
536  this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val);
537 }
538 
// NOTE(review): the signature (listing line 539) is missing from this page;
// presumably SetMaximumTreeDepth(int val), which SetRandomForest() calls — confirm.
// Stores `val` as the "treedepth" int property (read by ConvertParameter()).
540 {
541  this->GetPropertyList()->SetIntProperty("treedepth",val);
542 }
543 
// NOTE(review): the signature (listing line 544) is missing from this page;
// presumably SetMinimumSplitNodeSize(int val), which SetRandomForest() calls — confirm.
// Stores `val` as the "minimalsplitnodesize" int property (read by ConvertParameter()).
545 {
546  this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val);
547 }
548 
// NOTE(review): the signature (listing line 549) is missing from this page;
// presumably a precision setter taking a double `val` — confirm.
// Stores `val` as the "precision" double property (read by ConvertParameter()).
550 {
551  this->GetPropertyList()->SetDoubleProperty("precision",val);
552 }
553 
// NOTE(review): the signature (listing line 554) is missing from this page;
// presumably SetSamplesPerTree(double val), which SetRandomForest() calls — confirm.
// Stores `val` as the "samplespertree" double property (read by ConvertParameter()).
555 {
556  this->GetPropertyList()->SetDoubleProperty("samplespertree",val);
557 }
558 
// NOTE(review): the signature (listing line 559) is missing from this page;
// presumably UseSampleWithReplacement(bool val), which SetRandomForest() calls — confirm.
// Stores `val` as the "samplewithreplacement" bool property (read by ConvertParameter()).
560 {
561  this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val);
562 }
563 
// NOTE(review): the signature (listing line 564) is missing from this page;
// presumably SetTreeCount(int val), which SetRandomForest() calls — confirm.
// Stores `val` as the "treecount" int property (read by ConvertParameter()).
565 {
566  this->GetPropertyList()->SetIntProperty("treecount",val);
567 }
568 
// NOTE(review): the signature (listing line 569) is missing from this page;
// presumably a setter for the point-wise weight exponent taking a double `val` — confirm.
// Stores `val` as the "lambda" double property (read by ConvertParameter()).
570 {
571  this->GetPropertyList()->SetDoubleProperty("lambda",val);
572 }
573 
// NOTE(review): the signature (listing line 574) is missing from this page;
// presumably SetTreeWeight(treeId, weight) — confirm.
// Overwrites the weight of a single tree. treeId is not range-checked here, so
// m_TreeWeights must already hold at least treeId + 1 rows.
575 {
576  m_TreeWeights(treeId,0) = weight;
577 }
578 
579 void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest<int> & rf)
580 {
581  this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth);
582  this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_);
583  this->SetTreeCount(rf.options().tree_count_);
584  this->SetSamplesPerTree(rf.options().training_set_proportion_);
585  this->UseSampleWithReplacement(rf.options().sample_with_replacement_);
586  this->m_RandomForest = rf;
587 }
588 
589 const vigra::RandomForest<int> & mitk::VigraRandomForestClassifier::GetRandomForest() const
590 {
591  return this->m_RandomForest;
592 }
bool SupportsPointWiseWeight()
SupportsPointWiseWeight.
itk::SmartPointer< Self > Pointer
#define MITK_INFO
Definition: mitkLogMacros.h:22
const vigra::RandomForest< int > & GetRandomForest() const
void UseRandomSplit(bool split)
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X)
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
void UsePointBasedWeights(bool weightsOn)
void PrintParameter(std::ostream &str=std::cout)
mitk::ThresholdSplit< mitk::LinearSplitting< mitk::ImpurityLoss<> >, int, vigra::ClassificationTag > DefaultSplitType
#define MITK_WARN
Definition: mitkLogMacros.h:23
mitk::PropertyList::Pointer GetPropertyList() const
Get the data's property list.
void SetWeights(vigra::MultiArrayView< 2, double > weights)
virtual void UsePointWiseWeight(bool value)
UsePointWiseWeight.
void SetRandomForest(const vigra::RandomForest< int > &rf)
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
MITKCORE_EXPORT const ScalarType eps
void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
bool SupportsPointWiseProbability()
SupportsPointWiseProbability.
Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X)
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.