Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
mitkVigraRandomForestClassifier.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 // MITK includes
15 #include <mitkThresholdSplit.h>
16 #include <mitkImpurityLoss.h>
17 #include <mitkLinearSplitting.h>
18 #include <mitkProperties.h>
19 
20 // Vigra includes
21 #include <vigra/random_forest.hxx>
22 #include <vigra/random_forest/rf_split.hxx>
23 
24 // ITK include
25 #include <itkFastMutexLock.h>
26 #include <itkMultiThreader.h>
27 #include <itkCommand.h>
28 
30 
31 struct mitk::VigraRandomForestClassifier::Parameter
32 {
33  vigra::RF_OptionTag Stratification;
34  bool SampleWithReplacement;
35  bool UseRandomSplit;
36  bool UsePointBasedWeights;
37  int TreeCount;
38  int MinimumSplitNodeSize;
39  int TreeDepth;
40  double Precision;
41  double WeightLambda;
42  double SamplesPerTree;
43 };
44 
45 struct mitk::VigraRandomForestClassifier::TrainingData
46 {
47  TrainingData(unsigned int numberOfTrees,
48  const vigra::RandomForest<int> & refRF,
49  const DefaultSplitType & refSplitter,
50  const vigra::MultiArrayView<2, double> refFeature,
51  const vigra::MultiArrayView<2, int> refLabel,
52  const Parameter parameter)
53  : m_ClassCount(0),
54  m_NumberOfTrees(numberOfTrees),
55  m_RandomForest(refRF),
56  m_Splitter(refSplitter),
57  m_Feature(refFeature),
58  m_Label(refLabel),
59  m_Parameter(parameter)
60  {
61  m_mutex = itk::FastMutexLock::New();
62  }
63  vigra::ArrayVector<vigra::RandomForest<int>::DecisionTree_t> trees_;
64 
65  int m_ClassCount;
66  unsigned int m_NumberOfTrees;
67  const vigra::RandomForest<int> & m_RandomForest;
68  const DefaultSplitType & m_Splitter;
69  const vigra::MultiArrayView<2, double> m_Feature;
70  const vigra::MultiArrayView<2, int> m_Label;
71  itk::FastMutexLock::Pointer m_mutex;
72  Parameter m_Parameter;
73 };
74 
75 struct mitk::VigraRandomForestClassifier::PredictionData
76 {
77  PredictionData(const vigra::RandomForest<int> & refRF,
78  const vigra::MultiArrayView<2, double> refFeature,
79  vigra::MultiArrayView<2, int> refLabel,
80  vigra::MultiArrayView<2, double> refProb,
81  vigra::MultiArrayView<2, double> refTreeWeights)
82  : m_RandomForest(refRF),
83  m_Feature(refFeature),
84  m_Label(refLabel),
85  m_Probabilities(refProb),
86  m_TreeWeights(refTreeWeights)
87  {
88  }
89  const vigra::RandomForest<int> & m_RandomForest;
90  const vigra::MultiArrayView<2, double> m_Feature;
91  vigra::MultiArrayView<2, int> m_Label;
92  vigra::MultiArrayView<2, double> m_Probabilities;
93  vigra::MultiArrayView<2, double> m_TreeWeights;
94 };
95 
97  :m_Parameter(nullptr)
98 {
99  itk::SimpleMemberCommand<mitk::VigraRandomForestClassifier>::Pointer command = itk::SimpleMemberCommand<mitk::VigraRandomForestClassifier>::New();
100  command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter);
101  this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command );
102 }
103 
105 {
106 }
107 
109 {
110  return true;
111 }
112 
114 {
115  return true;
116 }
117 
118 void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in)
119 {
120  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
121  vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
122  m_RandomForest.onlineLearn(X,Y,0,true);
123 }
124 
125 void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in)
126 {
127  this->ConvertParameter();
128 
129  DefaultSplitType splitter;
130  splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights);
131  splitter.UseRandomSplit(m_Parameter->UseRandomSplit);
132  splitter.SetPrecision(m_Parameter->Precision);
133  splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth);
134 
135  // Weights handled as member variable
136  if (m_Parameter->UsePointBasedWeights)
137  {
138  // Set influence of the weight (0 no influenc to 1 max influence)
139  this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;});
140 
141  vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data());
142  splitter.SetWeights(W);
143  }
144 
145  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
146  vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
147 
148  m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated;
149 
150  m_RandomForest.set_options().use_stratification(m_Parameter->Stratification);
151  m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement);
152  m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree);
153  m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize);
154 
155  m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter);
156 
157  std::unique_ptr<TrainingData> data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter));
158 
159  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
160  threader->SetSingleMethod(this->TrainTreesCallback,data.get());
161  threader->SingleMethodExecute();
162 
163  // set result trees
164  m_RandomForest.set_options().tree_count(m_Parameter->TreeCount);
165  m_RandomForest.ext_param_.class_count_ = data->m_ClassCount;
166  m_RandomForest.trees_ = data->trees_;
167 
168  // Set Tree Weights to default
169  m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1);
170  m_TreeWeights.fill(1.0);
171 }
172 
173 Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in)
174 {
175  // Initialize output Eigen matrices
176  m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
177  m_OutProbability.fill(0);
178  m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
179  m_OutLabel.fill(0);
180 
181  // If no weights provided
182  if(m_TreeWeights.rows() != m_RandomForest.tree_count())
183  {
184  m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
185  m_TreeWeights.fill(1);
186  }
187 
188 
189  vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
190  vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
191  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
192  vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
193 
194  std::unique_ptr<PredictionData> data;
195  data.reset(new PredictionData(m_RandomForest, X, Y, P, TW));
196 
197  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
198  threader->SetSingleMethod(this->PredictCallback, data.get());
199  threader->SingleMethodExecute();
200 
201  m_Probabilities = data->m_Probabilities;
202  return m_OutLabel;
203 }
204 
205 Eigen::MatrixXi mitk::VigraRandomForestClassifier::PredictWeighted(const Eigen::MatrixXd &X_in)
206 {
207  // Initialize output Eigen matrices
208  m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
209  m_OutProbability.fill(0);
210  m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
211  m_OutLabel.fill(0);
212 
213  // If no weights provided
214  if(m_TreeWeights.rows() != m_RandomForest.tree_count())
215  {
216  m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
217  m_TreeWeights.fill(1);
218  }
219 
220 
221  vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
222  vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
223  vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
224  vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
225 
226  std::unique_ptr<PredictionData> data;
227  data.reset( new PredictionData(m_RandomForest,X,Y,P,TW));
228 
229  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
230  threader->SetSingleMethod(this->PredictWeightedCallback,data.get());
231  threader->SingleMethodExecute();
232 
233  return m_OutLabel;
234 }
235 
236 
237 
239 {
240  m_TreeWeights = weights;
241 }
242 
244 {
245  return m_TreeWeights;
246 }
247 
248 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg)
249 {
250  // Get the ThreadInfoStruct
251  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
252  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
253 
254  TrainingData * data = (TrainingData *)(infoStruct->UserData);
255  unsigned int numberOfTreesToCalculate = 0;
256 
257  // define the number of tress the forest have to calculate
258  numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads;
259 
260  // the 0th thread takes the residuals
261  if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads;
262 
263  if(numberOfTreesToCalculate != 0){
264  // Copy the Treestructure defined in userData
265  vigra::RandomForest<int> rf = data->m_RandomForest;
266 
267  // Initialize a splitter for the leraning process
268  DefaultSplitType splitter;
269  splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights());
270  splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit());
271  splitter.SetPrecision(data->m_Splitter.GetPrecision());
272  splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth());
273  splitter.SetWeights(data->m_Splitter.GetWeights());
274 
275  rf.trees_.clear();
276  rf.set_options().tree_count(numberOfTreesToCalculate);
277  rf.set_options().use_stratification(data->m_Parameter.Stratification);
278  rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement);
279  rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree);
280  rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize);
281  rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter);
282 
283  data->m_mutex->Lock();
284 
285  for(const auto & tree : rf.trees_)
286  data->trees_.push_back(tree);
287 
288  data->m_ClassCount = rf.class_count();
289  data->m_mutex->Unlock();
290  }
291 
292  return ITK_THREAD_RETURN_VALUE;
293 
294 }
295 
296 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg)
297 {
298  // Get the ThreadInfoStruct
299  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
300  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
301  // assigne the thread id
302  const unsigned int threadId = infoStruct->ThreadID;
303 
304  // Get the user defined parameters containing all
305  // neccesary informations
306  PredictionData * data = (PredictionData *)(infoStruct->UserData);
307  unsigned int numberOfRowsToCalculate = 0;
308 
309  // Get number of rows to calculate
310  numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
311 
312  unsigned int start_index = numberOfRowsToCalculate * threadId;
313  unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
314 
315  // the last thread takes the residuals
316  if(threadId == infoStruct->NumberOfThreads-1) {
317  end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
318  }
319 
320  vigra::MultiArrayView<2, double> split_features;
321  vigra::MultiArrayView<2, int> split_labels;
322  vigra::MultiArrayView<2, double> split_probability;
323  {
324  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
325  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
326  split_features = data->m_Feature.subarray(lowerBound,upperBound);
327  }
328 
329  {
330  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
331  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
332  split_labels = data->m_Label.subarray(lowerBound,upperBound);
333  }
334 
335  {
336  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
337  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
338  split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
339  }
340 
341  data->m_RandomForest.predictLabels(split_features,split_labels);
342  data->m_RandomForest.predictProbabilities(split_features, split_probability);
343 
344 
345  return ITK_THREAD_RETURN_VALUE;
346 
347 }
348 
349 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(void * arg)
350 {
351  // Get the ThreadInfoStruct
352  typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
353  ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
354  // assigne the thread id
355  const unsigned int threadId = infoStruct->ThreadID;
356 
357  // Get the user defined parameters containing all
358  // neccesary informations
359  PredictionData * data = (PredictionData *)(infoStruct->UserData);
360  unsigned int numberOfRowsToCalculate = 0;
361 
362  // Get number of rows to calculate
363  numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
364 
365  unsigned int start_index = numberOfRowsToCalculate * threadId;
366  unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
367 
368  // the last thread takes the residuals
369  if(threadId == infoStruct->NumberOfThreads-1) {
370  end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
371  }
372 
373  vigra::MultiArrayView<2, double> split_features;
374  vigra::MultiArrayView<2, int> split_labels;
375  vigra::MultiArrayView<2, double> split_probability;
376  {
377  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
378  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
379  split_features = data->m_Feature.subarray(lowerBound,upperBound);
380  }
381 
382  {
383  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
384  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
385  split_labels = data->m_Label.subarray(lowerBound,upperBound);
386  }
387 
388  {
389  vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
390  vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
391  split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
392  }
393 
394  VigraPredictWeighted(data, split_features,split_labels,split_probability);
395 
396  return ITK_THREAD_RETURN_VALUE;
397 }
398 
399 
400 void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P)
401 {
402 
403  int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_;
404 //#pragma omp parallel for
405  for(int row=0; row < vigra::rowCount(X); ++row)
406  {
407  vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row));
408 
409  vigra::ArrayVector<double>::const_iterator weights;
410 
411  //totalWeight == totalVoteCount!
412  double totalWeight = 0.0;
413 
414  //Let each tree classify...
415  for(int k=0; k<data->m_RandomForest.options_.tree_count_; ++k)
416  {
417  //get weights predicted by single tree
418  weights = data->m_RandomForest.trees_[k /*tree_indices_[k]*/].predict(currentRow);
419  double numberOfLeafObservations = (*(weights-1));
420 
421  //update votecount.
422  for(int l=0; l<data->m_RandomForest.ext_param_.class_count_; ++l)
423  {
424  // Either the original weights are taken or the tree is additional weighted by the number of Observations in the leaf node.
425  double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted));
426  cur_w = cur_w * data->m_TreeWeights(k,0);
427  P(row, l) += (int)cur_w;
428  //every weight in totalWeight.
429  totalWeight += cur_w;
430  }
431  }
432 
433  //Normalise votes in each row by total VoteCount (totalWeight
434  for(int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l)
435  {
436  P(row, l) /= vigra::detail::RequiresExplicitCast<double>::cast(totalWeight);
437  }
438  int erg;
439  int maxCol = 0;
440  for (int col=0;col<data->m_RandomForest.class_count();++col)
441  {
442  if (data->m_Probabilities(row,col) > data->m_Probabilities(row, maxCol))
443  maxCol = col;
444  }
445  data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg);
446  Y(row,0) = erg;
447  }
448 }
449 
451 {
452  if(this->m_Parameter == nullptr)
453  this->m_Parameter = new Parameter();
454  // Get the proerty // Some defaults
455 
456  MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter";
457  if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false;
458  if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false;
459  if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20;
460  if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100;
461  if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5;
462  if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps;
463  if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6;
464  if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true;
465  if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet
466  // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification))
467  this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given
468 }
469 
471 {
472  if(this->m_Parameter == nullptr)
473  {
474  MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!";
475  return;
476  }
477 
478  this->ConvertParameter();
479 
480  // Get the proerty // Some defaults
481  if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights))
482  str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n";
483  else
484  str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n";
485 
486  if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit))
487  str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n";
488  else
489  str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n";
490 
491  if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth))
492  str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n";
493  else
494  str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n";
495 
496  if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize))
497  str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n";
498  else
499  str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n";
500 
501  if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision))
502  str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n";
503  else
504  str << "precision\t\t" << this->m_Parameter->Precision << "\n";
505 
506  if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree))
507  str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n";
508  else
509  str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n";
510 
511  if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement))
512  str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n";
513  else
514  str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n";
515 
516  if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount))
517  str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n";
518  else
519  str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n";
520 
521  if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda))
522  str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n";
523  else
524  str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n";
525 
526  // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification))
527  // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given
528 }
529 
531 {
533  this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val);
534 }
535 
537 {
538  this->GetPropertyList()->SetIntProperty("treedepth",val);
539 }
540 
542 {
543  this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val);
544 }
545 
547 {
548  this->GetPropertyList()->SetDoubleProperty("precision",val);
549 }
550 
552 {
553  this->GetPropertyList()->SetDoubleProperty("samplespertree",val);
554 }
555 
557 {
558  this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val);
559 }
560 
562 {
563  this->GetPropertyList()->SetIntProperty("treecount",val);
564 }
565 
567 {
568  this->GetPropertyList()->SetDoubleProperty("lambda",val);
569 }
570 
572 {
573  m_TreeWeights(treeId,0) = weight;
574 }
575 
576 void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest<int> & rf)
577 {
578  this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth);
579  this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_);
580  this->SetTreeCount(rf.options().tree_count_);
581  this->SetSamplesPerTree(rf.options().training_set_proportion_);
582  this->UseSampleWithReplacement(rf.options().sample_with_replacement_);
583  this->m_RandomForest = rf;
584 }
585 
586 const vigra::RandomForest<int> & mitk::VigraRandomForestClassifier::GetRandomForest() const
587 {
588  return this->m_RandomForest;
589 }
float k(1.0)
#define MITK_INFO
Definition: mitkLogMacros.h:18
void UseRandomSplit(bool split)
const vigra::RandomForest< int > & GetRandomForest() const
void UsePointBasedWeights(bool weightsOn)
void PrintParameter(std::ostream &str=std::cout)
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override
Predict class for X.
mitk::ThresholdSplit< mitk::LinearSplitting< mitk::ImpurityLoss<> >, int, vigra::ClassificationTag > DefaultSplitType
#define MITK_WARN
Definition: mitkLogMacros.h:19
bool SupportsPointWiseProbability() override
SupportsPointWiseProbability.
void SetWeights(vigra::MultiArrayView< 2, double > weights)
virtual void UsePointWiseWeight(bool value)
UsePointWiseWeight.
void SetRandomForest(const vigra::RandomForest< int > &rf)
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
mitk::PropertyList::Pointer GetPropertyList() const
Get the data's property list.
MITKCORE_EXPORT const ScalarType eps
void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
bool SupportsPointWiseWeight() override
SupportsPointWiseWeight.
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override
Build a forest of trees from the training set (X, y).
void UsePointWiseWeight(bool) override
UsePointWiseWeight.
Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X)