mitkVigraRandomForestClassifier.cpp
/*===================================================================

The Medical Imaging Interaction Toolkit (MITK)

Copyright (c) German Cancer Research Center,
Division of Medical and Biological Informatics.
All rights reserved.

This software is distributed WITHOUT ANY WARRANTY; without
even the implied warranty of MERCHANTABILITY or FITNESS FOR
A PARTICULAR PURPOSE.

See LICENSE.txt or http://www.mitk.org for details.

===================================================================*/

// MITK includes
#include <mitkVigraRandomForestClassifier.h>
#include <mitkThresholdSplit.h>
#include <mitkImpurityLoss.h>
#include <mitkLinearSplitting.h>
#include <mitkProperties.h>

// Vigra includes
#include <vigra/random_forest.hxx>
#include <vigra/random_forest/rf_split.hxx>

// ITK includes
#include <itkFastMutexLock.h>
#include <itkMultiThreader.h>
#include <itkCommand.h>

typedef mitk::ThresholdSplit<mitk::LinearSplitting< mitk::ImpurityLoss<> >,int,vigra::ClassificationTag> DefaultSplitType;

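// Collected random-forest options. The values are read from the classifier's
// property list by ConvertParameter() and passed on to the vigra random forest
// and the threaded training below.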
struct mitk::VigraRandomForestClassifier::Parameter
{
  vigra::RF_OptionTag Stratification;
  bool SampleWithReplacement;
  bool UseRandomSplit;
  bool UsePointBasedWeights;
  int TreeCount;
  int MinimumSplitNodeSize;
  int TreeDepth;
  double Precision;
  double WeightLambda;
  double SamplesPerTree;
};

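// Shared state for the threaded training step: each worker thread reads the
// feature/label views and parameters, trains its share of trees, and appends
// the resulting trees to trees_ while holding m_mutex.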
struct mitk::VigraRandomForestClassifier::TrainingData
{
  TrainingData(unsigned int numberOfTrees,
               const vigra::RandomForest<int> & refRF,
               const DefaultSplitType & refSplitter,
               const vigra::MultiArrayView<2, double> refFeature,
               const vigra::MultiArrayView<2, int> refLabel,
               const Parameter parameter)
    : m_ClassCount(0),
      m_NumberOfTrees(numberOfTrees),
      m_RandomForest(refRF),
      m_Splitter(refSplitter),
      m_Feature(refFeature),
      m_Label(refLabel),
      m_Parameter(parameter)
  {
    m_mutex = itk::FastMutexLock::New();
  }
  vigra::ArrayVector<vigra::RandomForest<int>::DecisionTree_t> trees_;

  int m_ClassCount;
  unsigned int m_NumberOfTrees;
  const vigra::RandomForest<int> & m_RandomForest;
  const DefaultSplitType & m_Splitter;
  const vigra::MultiArrayView<2, double> m_Feature;
  const vigra::MultiArrayView<2, int> m_Label;
  itk::FastMutexLock::Pointer m_mutex;
  Parameter m_Parameter;
};

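// Shared state for the threaded prediction step: each worker thread reads a
// block of feature rows and writes the predicted labels and class
// probabilities (and, for weighted prediction, uses the per-tree weights).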
struct mitk::VigraRandomForestClassifier::PredictionData
{
  PredictionData(const vigra::RandomForest<int> & refRF,
                 const vigra::MultiArrayView<2, double> refFeature,
                 vigra::MultiArrayView<2, int> refLabel,
                 vigra::MultiArrayView<2, double> refProb,
                 vigra::MultiArrayView<2, double> refTreeWeights)
    : m_RandomForest(refRF),
      m_Feature(refFeature),
      m_Label(refLabel),
      m_Probabilities(refProb),
      m_TreeWeights(refTreeWeights)
  {
  }
  const vigra::RandomForest<int> & m_RandomForest;
  const vigra::MultiArrayView<2, double> m_Feature;
  vigra::MultiArrayView<2, int> m_Label;
  vigra::MultiArrayView<2, double> m_Probabilities;
  vigra::MultiArrayView<2, double> m_TreeWeights;
};

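// The constructor registers ConvertParameter() as an observer of the property
// list, so modified properties are re-translated into the internal Parameter
// struct.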
mitk::VigraRandomForestClassifier::VigraRandomForestClassifier()
  : m_Parameter(nullptr)
{
  itk::SimpleMemberCommand<mitk::VigraRandomForestClassifier>::Pointer command = itk::SimpleMemberCommand<mitk::VigraRandomForestClassifier>::New();
  command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter);
  this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command );
}

mitk::VigraRandomForestClassifier::~VigraRandomForestClassifier()
{
}

bool mitk::VigraRandomForestClassifier::SupportsPointWiseWeight()
{
  return true;
}

bool mitk::VigraRandomForestClassifier::SupportsPointWiseProbability()
{
  return true;
}

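// Incrementally extends the already trained forest with additional samples
// using vigra's online learning interface.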
void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in)
{
  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
  vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
  m_RandomForest.onlineLearn(X,Y,0,true);
}

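// Full (re-)training: the splitter is configured from the current parameters,
// a single-tree learn() call initializes the forest, and the requested number
// of trees is then trained in parallel by TrainTreesCallback; the resulting
// trees replace the forest's trees.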
void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in)
{
  this->ConvertParameter();

  DefaultSplitType splitter;
  splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights);
  splitter.UseRandomSplit(m_Parameter->UseRandomSplit);
  splitter.SetPrecision(m_Parameter->Precision);
  splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth);

  // Weights are handled as a member variable
  if (m_Parameter->UsePointBasedWeights)
  {
    // Set the influence of the weight (0 = no influence, 1 = maximum influence)
    this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;});

    vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data());
    splitter.SetWeights(W);
  }

  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
  vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());

  m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated

  m_RandomForest.set_options().use_stratification(m_Parameter->Stratification);
  m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement);
  m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree);
  m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize);

  m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter);

  std::unique_ptr<TrainingData> data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter));

  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
  threader->SetSingleMethod(this->TrainTreesCallback,data.get());
  threader->SingleMethodExecute();

  // Set the resulting trees
  m_RandomForest.set_options().tree_count(m_Parameter->TreeCount);
  m_RandomForest.ext_param_.class_count_ = data->m_ClassCount;
  m_RandomForest.trees_ = data->trees_;

  // Set tree weights to their default value
  m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1);
  m_TreeWeights.fill(1.0);
}

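// Predicts labels and class probabilities for the given feature matrix. The
// rows are processed in parallel by PredictCallback; the label matrix is
// returned and the probabilities are kept in m_OutProbability.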
Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in)
{
  // Initialize output Eigen matrices
  m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
  m_OutProbability.fill(0);
  m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
  m_OutLabel.fill(0);

  // If no tree weights are provided, use uniform weights
  if(m_TreeWeights.rows() != m_RandomForest.tree_count())
  {
    m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
    m_TreeWeights.fill(1);
  }

  vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
  vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
  vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());

  std::unique_ptr<PredictionData> data;
  data.reset( new PredictionData(m_RandomForest,X,Y,P,TW));

  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
  threader->SetSingleMethod(this->PredictCallback,data.get());
  threader->SingleMethodExecute();

  return m_OutLabel;
}

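// Same as Predict(), but the votes of the individual trees are scaled by the
// user-supplied tree weights (see VigraPredictWeighted below).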
Eigen::MatrixXi mitk::VigraRandomForestClassifier::PredictWeighted(const Eigen::MatrixXd &X_in)
{
  // Initialize output Eigen matrices
  m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
  m_OutProbability.fill(0);
  m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
  m_OutLabel.fill(0);

  // If no tree weights are provided, use uniform weights
  if(m_TreeWeights.rows() != m_RandomForest.tree_count())
  {
    m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
    m_TreeWeights.fill(1);
  }

  vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
  vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
  vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());

  std::unique_ptr<PredictionData> data;
  data.reset( new PredictionData(m_RandomForest,X,Y,P,TW));

  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
  threader->SetSingleMethod(this->PredictWeightedCallback,data.get());
  threader->SingleMethodExecute();

  return m_OutLabel;
}

void mitk::VigraRandomForestClassifier::SetTreeWeights(Eigen::MatrixXd weights)
{
  m_TreeWeights = weights;
}

Eigen::MatrixXd mitk::VigraRandomForestClassifier::GetTreeWeights() const
{
  return m_TreeWeights;
}

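// Thread callback for training: the requested number of trees is divided
// evenly among the threads (thread 0 additionally takes the remainder), each
// thread trains its trees on a copy of the forest, and the results are
// appended to the shared tree list under the mutex.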
ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg)
{
  // Get the ThreadInfoStruct
  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );

  TrainingData * data = (TrainingData *)(infoStruct->UserData);
  unsigned int numberOfTreesToCalculate = 0;

  // Define the number of trees this thread has to calculate
  numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads;

  // The 0th thread takes the residual trees
  if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads;

  if(numberOfTreesToCalculate != 0){
    // Copy the tree structure defined in userData
    vigra::RandomForest<int> rf = data->m_RandomForest;

    // Initialize a splitter for the learning process
    DefaultSplitType splitter;
    splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights());
    splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit());
    splitter.SetPrecision(data->m_Splitter.GetPrecision());
    splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth());
    splitter.SetWeights(data->m_Splitter.GetWeights());

    rf.trees_.clear();
    rf.set_options().tree_count(numberOfTreesToCalculate);
    rf.set_options().use_stratification(data->m_Parameter.Stratification);
    rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement);
    rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree);
    rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize);
    rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter);

    data->m_mutex->Lock();

    for(const auto & tree : rf.trees_)
      data->trees_.push_back(tree);

    data->m_ClassCount = rf.class_count();
    data->m_mutex->Unlock();
  }

  return NULL;
}

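// Thread callback for prediction: the feature rows are divided evenly among
// the threads (the last thread takes the remainder) and each block is
// classified with vigra's predictLabels/predictProbabilities.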
ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg)
{
  // Get the ThreadInfoStruct
  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
  // Assign the thread id
  const unsigned int threadId = infoStruct->ThreadID;

  // Get the user defined parameters containing all
  // necessary information
  PredictionData * data = (PredictionData *)(infoStruct->UserData);
  unsigned int numberOfRowsToCalculate = 0;

  // Get the number of rows to calculate
  numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;

  unsigned int start_index = numberOfRowsToCalculate * threadId;
  unsigned int end_index = numberOfRowsToCalculate * (threadId+1);

  // The last thread takes the residual rows
  if(threadId == infoStruct->NumberOfThreads-1) {
    end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
  }

  vigra::MultiArrayView<2, double> split_features;
  vigra::MultiArrayView<2, int> split_labels;
  vigra::MultiArrayView<2, double> split_probability;
  {
    vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
    vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
    split_features = data->m_Feature.subarray(lowerBound,upperBound);
  }

  {
    vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
    vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
    split_labels = data->m_Label.subarray(lowerBound,upperBound);
  }

  {
    vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
    vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
    split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
  }

  data->m_RandomForest.predictLabels(split_features,split_labels);
  data->m_RandomForest.predictProbabilities(split_features, split_probability);

  return NULL;
}

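// Thread callback for weighted prediction: the rows are partitioned exactly as
// in PredictCallback, but the actual classification is delegated to
// VigraPredictWeighted() so the per-tree weights can be applied.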
ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(void * arg)
{
  // Get the ThreadInfoStruct
  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
  // Assign the thread id
  const unsigned int threadId = infoStruct->ThreadID;

  // Get the user defined parameters containing all
  // necessary information
  PredictionData * data = (PredictionData *)(infoStruct->UserData);
  unsigned int numberOfRowsToCalculate = 0;

  // Get the number of rows to calculate
  numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;

  unsigned int start_index = numberOfRowsToCalculate * threadId;
  unsigned int end_index = numberOfRowsToCalculate * (threadId+1);

  // The last thread takes the residual rows
  if(threadId == infoStruct->NumberOfThreads-1) {
    end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
  }

  vigra::MultiArrayView<2, double> split_features;
  vigra::MultiArrayView<2, int> split_labels;
  vigra::MultiArrayView<2, double> split_probability;
  {
    vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
    vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
    split_features = data->m_Feature.subarray(lowerBound,upperBound);
  }

  {
    vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
    vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
    split_labels = data->m_Label.subarray(lowerBound,upperBound);
  }

  {
    vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
    vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
    split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
  }

  VigraPredictWeighted(data, split_features,split_labels,split_probability);

  return NULL;
}

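// Weighted prediction for one block of rows: every tree votes for each class,
// the vote is optionally scaled by the number of observations in the reached
// leaf (if the forest was trained with predict_weighted_) and always by the
// per-tree weight, the accumulated votes are normalized to probabilities, and
// the class with the highest probability is written to the label matrix.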
void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P)
{

  int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_;
//#pragma omp parallel for
  for(int row=0; row < vigra::rowCount(X); ++row)
  {
    vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row));

    vigra::ArrayVector<double>::const_iterator weights;

    // totalWeight == totalVoteCount!
    double totalWeight = 0.0;

    // Let each tree classify...
    for(int k=0; k<data->m_RandomForest.options_.tree_count_; ++k)
    {
      // Get the weights predicted by a single tree
      weights = data->m_RandomForest.trees_[k /*tree_indices_[k]*/].predict(currentRow);
      double numberOfLeafObservations = (*(weights-1));

      // Update the vote count
      for(int l=0; l<data->m_RandomForest.ext_param_.class_count_; ++l)
      {
        // Either the original weights are taken, or the tree is additionally weighted by the number of observations in the leaf node.
        double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted));
        cur_w = cur_w * data->m_TreeWeights(k,0);
        P(row, l) += (int)cur_w;
        // Every weight contributes to totalWeight.
        totalWeight += cur_w;
      }
    }

    // Normalise the votes in each row by the total vote count (totalWeight)
    for(int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l)
    {
      P(row, l) /= vigra::detail::RequiresExplicitCast<double>::cast(totalWeight);
    }
    int erg;
    int maxCol = 0;
    for (int col=0;col<data->m_RandomForest.class_count();++col)
    {
      if (data->m_Probabilities(row,col) > data->m_Probabilities(row, maxCol))
        maxCol = col;
    }
    data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg);
    Y(row,0) = erg;
  }
}

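// Reads the classifier properties ("treecount", "treedepth", "precision",
// "samplespertree", ...) into the Parameter struct and fills in defaults for
// properties that are not set.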
void mitk::VigraRandomForestClassifier::ConvertParameter()
{
  if(this->m_Parameter == nullptr)
    this->m_Parameter = new Parameter();
  // Get the properties // Some defaults

  MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter";
  if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false;
  if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false;
  if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20;
  if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100;
  if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5;
  if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps;
  if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6;
  if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true;
  if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet
  // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification))
  this->m_Parameter->Stratification = vigra::RF_NONE; // no property given
}

void mitk::VigraRandomForestClassifier::PrintParameter(std::ostream & str)
{
  if(this->m_Parameter == nullptr)
  {
    MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!";
    return;
  }

  this->ConvertParameter();

  // Get the properties // Some defaults
  if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights))
    str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n";
  else
    str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n";

  if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit))
    str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n";
  else
    str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n";

  if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth))
    str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n";
  else
    str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n";

  if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize))
    str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n";
  else
    str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n";

  if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision))
    str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n";
  else
    str << "precision\t\t" << this->m_Parameter->Precision << "\n";

  if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree))
    str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n";
  else
    str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n";

  if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement))
    str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n";
  else
    str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n";

  if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount))
    str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n";
  else
    str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n";

  if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda))
    str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n";
  else
    str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n";

  // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification))
  //   this->m_Parameter->Stratification = vigra::RF_NONE; // no property given
}

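// The following setters only write the corresponding values into the property
// list; they are translated into the Parameter struct by ConvertParameter()
// (triggered via the observer registered in the constructor and explicitly at
// the start of Train()).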
void mitk::VigraRandomForestClassifier::UsePointWiseWeight(bool val)
{
  this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val);
}

void mitk::VigraRandomForestClassifier::SetMaximumTreeDepth(int val)
{
  this->GetPropertyList()->SetIntProperty("treedepth",val);
}

void mitk::VigraRandomForestClassifier::SetMinimumSplitNodeSize(int val)
{
  this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val);
}

void mitk::VigraRandomForestClassifier::SetPrecision(double val)
{
  this->GetPropertyList()->SetDoubleProperty("precision",val);
}

void mitk::VigraRandomForestClassifier::SetSamplesPerTree(double val)
{
  this->GetPropertyList()->SetDoubleProperty("samplespertree",val);
}

void mitk::VigraRandomForestClassifier::UseSampleWithReplacement(bool val)
{
  this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val);
}

void mitk::VigraRandomForestClassifier::SetTreeCount(int val)
{
  this->GetPropertyList()->SetIntProperty("treecount",val);
}

void mitk::VigraRandomForestClassifier::SetWeightLambda(double val)
{
  this->GetPropertyList()->SetDoubleProperty("lambda",val);
}

void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight)
{
  m_TreeWeights(treeId,0) = weight;
}

void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest<int> & rf)
{
  this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth);
  this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_);
  this->SetTreeCount(rf.options().tree_count_);
  this->SetSamplesPerTree(rf.options().training_set_proportion_);
  this->UseSampleWithReplacement(rf.options().sample_with_replacement_);
  this->m_RandomForest = rf;
}

const vigra::RandomForest<int> & mitk::VigraRandomForestClassifier::GetRandomForest() const
{
  return this->m_RandomForest;
}
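
// Minimal usage sketch (assuming the classifier is created through its New()
// smart-pointer factory, as is usual for MITK classes; feature and label
// matrices are Eigen::MatrixXd / Eigen::MatrixXi with one sample per row):
//
//   mitk::VigraRandomForestClassifier::Pointer classifier =
//     mitk::VigraRandomForestClassifier::New();
//   classifier->SetTreeCount(100);
//   classifier->SetMaximumTreeDepth(20);
//   classifier->Train(trainingFeatures, trainingLabels);
//   Eigen::MatrixXi predictedLabels = classifier->Predict(testFeatures);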