25 #include <vigra/random_forest.hxx>
26 #include <vigra/random_forest/rf_split.hxx>
29 #include <itkFastMutexLock.h>
30 #include <itkMultiThreader.h>
31 #include <itkCommand.h>
// Tuning options for the Vigra random forest. Filled from the classifier's
// property list by ConvertParameter() and forwarded to vigra's RF options in
// Train() and TrainTreesCallback().
// NOTE(review): scraped listing - the braces and several members that
// ConvertParameter() assigns (e.g. TreeCount, TreeDepth, Precision,
// UseRandomSplit, WeightLambda) are missing between the numbered lines below.
35 struct mitk::VigraRandomForestClassifier::Parameter
// Passed to use_stratification(); ConvertParameter() always sets vigra::RF_NONE.
37 vigra::RF_OptionTag Stratification;
// Passed to sample_with_replacement(); ConvertParameter() defaults it to true.
38 bool SampleWithReplacement;
// When true, Train() hands m_PointWiseWeight to the splitter as per-sample weights.
40 bool UsePointBasedWeights;
// Passed to min_split_node_size(); ConvertParameter() defaults it to 5.
42 int MinimumSplitNodeSize;
// Passed to samples_per_tree() (proportion of the training set per tree);
// ConvertParameter() defaults it to 0.6.
46 double SamplesPerTree;
// Shared state handed to TrainTreesCallback() through
// ThreadInfoStruct::UserData. Each worker trains its share of trees on a
// private copy of m_RandomForest and appends the results to trees_.
// NOTE(review): scraped listing - braces and some original lines are missing,
// among them the refSplitter parameter and the m_Label, m_Splitter, m_mutex
// and m_ClassCount initialisers/declarations that TrainTreesCallback() uses.
49 struct mitk::VigraRandomForestClassifier::TrainingData
// Arguments: total number of trees to train across all threads, the forest
// whose options are copied by every worker, the splitter, the feature and
// label views, and a by-value copy of the tuning parameters.
51 TrainingData(
unsigned int numberOfTrees,
52 const vigra::RandomForest<int> & refRF,
54 const vigra::MultiArrayView<2, double> refFeature,
55 const vigra::MultiArrayView<2, int> refLabel,
56 const Parameter parameter)
58 m_NumberOfTrees(numberOfTrees),
59 m_RandomForest(refRF),
60 m_Splitter(refSplitter),
61 m_Feature(refFeature),
63 m_Parameter(parameter)
// Trees produced by all worker threads; appended under m_mutex in TrainTreesCallback().
67 vigra::ArrayVector<vigra::RandomForest<int>::DecisionTree_t> trees_;
// Total number of trees to distribute over the threads.
70 unsigned int m_NumberOfTrees;
// Non-owning reference: the referenced forest must outlive this object.
71 const vigra::RandomForest<int> & m_RandomForest;
// vigra views wrapping the caller's buffers (no copy of the data is made).
73 const vigra::MultiArrayView<2, double> m_Feature;
74 const vigra::MultiArrayView<2, int> m_Label;
76 Parameter m_Parameter;
// Shared state handed to PredictCallback()/PredictWeightedCallback() through
// ThreadInfoStruct::UserData. Each thread reads its own row range of
// m_Feature and writes the matching rows of m_Label and m_Probabilities.
// NOTE(review): scraped listing - the braces and (presumably) the
// m_Label(refLabel) initialiser at original line 88 are missing here.
79 struct mitk::VigraRandomForestClassifier::PredictionData
81 PredictionData(
const vigra::RandomForest<int> & refRF,
82 const vigra::MultiArrayView<2, double> refFeature,
83 vigra::MultiArrayView<2, int> refLabel,
84 vigra::MultiArrayView<2, double> refProb,
85 vigra::MultiArrayView<2, double> refTreeWeights)
86 : m_RandomForest(refRF),
87 m_Feature(refFeature),
89 m_Probabilities(refProb),
90 m_TreeWeights(refTreeWeights)
// Non-owning reference to the trained forest used for prediction.
93 const vigra::RandomForest<int> & m_RandomForest;
// Input features (read-only) and the output views the threads write into.
94 const vigra::MultiArrayView<2, double> m_Feature;
95 vigra::MultiArrayView<2, int> m_Label;
96 vigra::MultiArrayView<2, double> m_Probabilities;
// Per-tree weights, one row per tree (read as m_TreeWeights(k,0)).
97 vigra::MultiArrayView<2, double> m_TreeWeights;
// Constructor initialiser (the constructor signature is not part of this
// listing): m_Parameter starts null and is allocated lazily by ConvertParameter().
101 :m_Parameter(nullptr)
// Online-training fragment (enclosing signature not in this listing): wraps
// the Eigen inputs X_in / Y_in in vigra views built on their data() pointers
// and updates the existing forest through onlineLearn().
// NOTE(review): assumes Eigen's default column-major storage matches the
// default vigra view layout for Shape2(rows, cols) - confirm.
124 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
125 vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
// NOTE(review): 0 and true are presumably the start index of the new samples
// and the adjust-thresholds flag of vigra's onlineLearn() - confirm.
126 m_RandomForest.onlineLearn(X,Y,0,
true);
// Training fragment (enclosing signature and braces are not in this listing).
// Refreshes the parameters from the property list, optionally prepares the
// point-wise weights, trains one tree on the calling thread, then trains
// m_Parameter->TreeCount trees in parallel via TrainTreesCallback() and
// installs them in m_RandomForest.
131 this->ConvertParameter();
140 if (m_Parameter->UsePointBasedWeights)
// Raise every point-wise weight to the power WeightLambda. Eigen's
// unaryExpr() returns a new expression and leaves the matrix untouched, so
// the result has to be assigned back; it used to be discarded, which made
// the "lambda" parameter a no-op.
143 this->m_PointWiseWeight = this->m_PointWiseWeight.unaryExpr([this](double t){
return std::pow(t, this->m_Parameter->WeightLambda);});
// Non-owning vigra view over the (transformed) weight buffer.
145 vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data());
149 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
150 vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
// Train a single tree here. The workers copy this forest (TrainTreesCallback),
// and the tree itself is later replaced by the workers' trees below.
152 m_RandomForest.set_options().tree_count(1);
154 m_RandomForest.set_options().use_stratification(m_Parameter->Stratification);
155 m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement);
156 m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree);
157 m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize);
159 m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter);
161 std::unique_ptr<TrainingData> data(
new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter));
// Run TrainTreesCallback on every thread; each appends its trees to data->trees_.
164 threader->SetSingleMethod(this->TrainTreesCallback,data.get());
165 threader->SingleMethodExecute();
// Install the collected trees and the class count reported by the workers.
168 m_RandomForest.set_options().tree_count(m_Parameter->TreeCount);
169 m_RandomForest.ext_param_.class_count_ = data->m_ClassCount;
170 m_RandomForest.trees_ = data->trees_;
// New forest: every tree starts with weight 1.
173 m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1);
174 m_TreeWeights.fill(1.0);
// Prediction fragment (enclosing signature not in this listing): allocates
// m_OutProbability (rows x class_count, zero-filled) and m_OutLabel
// (rows x 1), then lets PredictCallback() fill them in parallel.
180 m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
181 m_OutProbability.fill(0);
182 m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
// Reset the tree weights to 1 when their count no longer matches the forest.
186 if(m_TreeWeights.rows() != m_RandomForest.tree_count())
188 m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
189 m_TreeWeights.fill(1);
// These views alias the Eigen buffers' data(), so the worker threads write
// their results directly into m_OutProbability / m_OutLabel.
193 vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
194 vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
195 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
196 vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
// Owned here; the callback only receives the raw pointer as UserData.
198 std::unique_ptr<PredictionData> data;
199 data.reset(
new PredictionData(m_RandomForest,X,Y,P,TW));
202 threader->SetSingleMethod(this->PredictCallback,data.get());
203 threader->SingleMethodExecute();
// Weighted-prediction fragment (enclosing signature not in this listing):
// same setup as the unweighted prediction, but the work is done by
// PredictWeightedCallback(). The zero fill matters here because
// VigraPredictWeighted() accumulates into the probabilities with +=.
211 m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
212 m_OutProbability.fill(0);
213 m_OutLabel = Eigen::MatrixXi(X_in.rows(),1);
// Reset the tree weights to 1 when their count no longer matches the forest.
217 if(m_TreeWeights.rows() != m_RandomForest.tree_count())
219 m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
220 m_TreeWeights.fill(1);
// Views over the Eigen buffers' data(); results are written in place.
224 vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data());
225 vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data());
226 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
227 vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
// Owned here; the callback only receives the raw pointer as UserData.
229 std::unique_ptr<PredictionData> data;
230 data.reset(
new PredictionData(m_RandomForest,X,Y,P,TW));
233 threader->SetSingleMethod(this->PredictWeightedCallback,data.get());
234 threader->SingleMethodExecute();
// SetTreeWeights body: replaces all per-tree weights. The prediction paths
// reset them to 1 if the row count does not match the forest's tree count.
243 m_TreeWeights = weights;
// GetTreeWeights body: returns the per-tree weight matrix (by value, per the declaration).
248 return m_TreeWeights;
// Worker entry point for parallel training (braces and some original lines
// are missing from this listing). The trees are split evenly across the
// threads, and thread 0 also takes the remainder. Each thread trains its
// trees on a private copy of the shared forest, then appends them to
// data->trees_ while holding data->m_mutex.
251 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg)
254 typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
255 ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
// UserData is the TrainingData pointer passed to SetSingleMethod(); use a
// named cast instead of the former C-style cast.
257 TrainingData * data = static_cast< TrainingData * >( infoStruct->UserData );
258 unsigned int numberOfTreesToCalculate = 0;
261 numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads;
// Thread 0 picks up the trees left over by the integer division.
264 if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads;
266 if(numberOfTreesToCalculate != 0){
// Private copy, so the threads never modify the shared forest.
268 vigra::RandomForest<int> rf = data->m_RandomForest;
276 splitter.SetWeights(data->m_Splitter.GetWeights());
279 rf.set_options().tree_count(numberOfTreesToCalculate);
280 rf.set_options().use_stratification(data->m_Parameter.Stratification);
281 rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement);
282 rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree);
283 rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize);
284 rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter);
// NOTE(review): manual Lock()/Unlock() leaves the mutex held if push_back
// throws; an RAII lock holder would be safer.
286 data->m_mutex->Lock();
288 for(const auto & tree : rf.trees_)
289 data->trees_.push_back(tree);
291 data->m_ClassCount = rf.class_count();
292 data->m_mutex->Unlock();
// Worker entry point for unweighted prediction (braces and some original
// lines are missing from this listing). Each thread handles a contiguous
// block of rows, and the last thread also takes the remainder rows. The
// thread predicts labels and probabilities for its block through subarray
// views of the shared input/output arrays.
299 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg)
302 typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
303 ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
305 const unsigned int threadId = infoStruct->ThreadID;
// UserData is the PredictionData pointer passed to SetSingleMethod(); use a
// named cast instead of the former C-style cast.
309 PredictionData * data = static_cast< PredictionData * >( infoStruct->UserData );
310 unsigned int numberOfRowsToCalculate = 0;
313 numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
// Half-open row range [start_index, end_index) for this thread.
315 unsigned int start_index = numberOfRowsToCalculate * threadId;
316 unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
// The last thread also processes the rows left over by the integer division.
319 if(threadId == infoStruct->NumberOfThreads-1) {
320 end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
// Views restricted to this thread's rows; writes go straight to the shared outputs.
323 vigra::MultiArrayView<2, double> split_features;
324 vigra::MultiArrayView<2, int> split_labels;
325 vigra::MultiArrayView<2, double> split_probability;
327 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
328 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
329 split_features = data->m_Feature.subarray(lowerBound,upperBound);
333 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
334 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
335 split_labels = data->m_Label.subarray(lowerBound,upperBound);
339 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
340 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
341 split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
344 data->m_RandomForest.predictLabels(split_features,split_labels);
345 data->m_RandomForest.predictProbabilities(split_features, split_probability);
// Worker entry point for tree-weighted prediction (braces and some original
// lines are missing from this listing). It splits the rows exactly like
// PredictCallback and delegates each block to VigraPredictWeighted().
352 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(void * arg)
355 typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
356 ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg );
358 const unsigned int threadId = infoStruct->ThreadID;
// UserData is the PredictionData pointer passed to SetSingleMethod(); use a
// named cast instead of the former C-style cast.
362 PredictionData * data = static_cast< PredictionData * >( infoStruct->UserData );
363 unsigned int numberOfRowsToCalculate = 0;
366 numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
// Half-open row range [start_index, end_index) for this thread.
368 unsigned int start_index = numberOfRowsToCalculate * threadId;
369 unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
// The last thread also processes the rows left over by the integer division.
372 if(threadId == infoStruct->NumberOfThreads-1) {
373 end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
// Views restricted to this thread's rows; row 0 of each view is global row start_index.
376 vigra::MultiArrayView<2, double> split_features;
377 vigra::MultiArrayView<2, int> split_labels;
378 vigra::MultiArrayView<2, double> split_probability;
380 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
381 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
382 split_features = data->m_Feature.subarray(lowerBound,upperBound);
386 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
387 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
388 split_labels = data->m_Label.subarray(lowerBound,upperBound);
392 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
393 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
394 split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
397 VigraPredictWeighted(data, split_features,split_labels,split_probability);
// Weighted prediction for one thread's slice of rows. X, Y and P are views
// of the same row range (built by PredictWeightedCallback()), so the row
// indices below are local to that slice.
// Each tree's vote is scaled by its weight m_TreeWeights(k,0). When the
// forest's predict_weighted_ option is set, the vote is also scaled by the
// observation count of the reached leaf. P(row, :) is normalised by the
// total weight and the arg-max class is converted to a label via
// to_classlabel().
// NOTE(review): scraped listing - braces and some original lines (presumably
// the maxCol/erg declarations, the "maxCol = col" update and the write of
// erg into Y) are not shown.
403 void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P)
406 int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_;
408 for(int row=0; row < vigra::rowCount(X); ++row)
410 vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row));
412 vigra::ArrayVector<double>::const_iterator weights;
// Sum of all votes for this row; used to normalise P(row, :).
415 double totalWeight = 0.0;
418 for(int k=0; k<data->m_RandomForest.options_.tree_count_; ++k)
// predict() yields the per-class weights of the reached leaf; the element
// just before them is read as the leaf's observation count.
421 weights = data->m_RandomForest.trees_[k ].predict(currentRow);
422 double numberOfLeafObservations = (*(weights-1));
425 for(int l=0; l<data->m_RandomForest.ext_param_.class_count_; ++l)
428 double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted));
429 cur_w = cur_w * data->m_TreeWeights(k,0);
// Accumulate the vote as a double. The former (int) cast truncated
// fractional votes (often to 0) while totalWeight kept the full value, so
// the normalised probabilities were wrong.
430 P(row, l) += cur_w;
432 totalWeight += cur_w;
437 for(int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l)
439 P(row, l) /= vigra::detail::RequiresExplicitCast<double>::cast(totalWeight);
443 for (int col=0;col<data->m_RandomForest.class_count();++col)
// Compare within this slice's probabilities. `row` is slice-local, so
// indexing the full data->m_Probabilities matrix read rows belonging to
// another thread's slice: wrong labels, plus a data race.
445 if (P(row,col) > P(row, maxCol))
448 data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg);
// ConvertParameter body (signature and braces not in this listing).
// Allocates m_Parameter on first use, then reads every option from the
// property list. Get() writes the stored value into the field when the key
// exists; otherwise the default on the right-hand side is assigned.
// NOTE(review): m_Parameter is a raw owning pointer here; presumably it is
// deleted in the destructor - confirm.
455 if(this->m_Parameter ==
nullptr)
456 this->m_Parameter =
new Parameter();
459 MITK_INFO(
"VigraRandomForestClassifier") <<
"Convert Parameter";
460 if(!this->GetPropertyList()->Get(
"usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights =
false;
461 if(!this->GetPropertyList()->Get(
"userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit =
false;
462 if(!this->GetPropertyList()->Get(
"treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20;
463 if(!this->GetPropertyList()->Get(
"treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100;
464 if(!this->GetPropertyList()->Get(
"minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5;
465 if(!this->GetPropertyList()->Get(
"precision",this->m_Parameter->Precision)) this->m_Parameter->Precision =
mitk::eps;
466 if(!this->GetPropertyList()->Get(
"samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6;
467 if(!this->GetPropertyList()->Get(
"samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement =
true;
468 if(!this->GetPropertyList()->Get(
"lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0;
// Stratification is not configurable through the property list.
470 this->m_Parameter->Stratification = vigra::RF_NONE;
// PrintParameter body (signature and braces not in this listing). Writes
// each tuning option to `str`. If an option's key is missing from the
// property list, the default held in m_Parameter is printed with a
// "NOT SET" marker; otherwise the stored value is printed.
// NOTE(review): the `else` lines that presumably pair each "NOT SET" branch
// with the plain branch are missing from this listing.
475 if(this->m_Parameter ==
nullptr)
477 MITK_WARN(
"VigraRandomForestClassifier") <<
"Parameters are not initialized. Please call ConvertParameter() first!";
// Make sure m_Parameter is allocated and populated before printing.
481 this->ConvertParameter();
484 if(!this->GetPropertyList()->Get(
"usepointbasedweight",this->m_Parameter->UsePointBasedWeights))
485 str <<
"usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights <<
")" <<
"\n";
487 str <<
"usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights <<
"\n";
489 if(!this->GetPropertyList()->Get(
"userandomsplit",this->m_Parameter->UseRandomSplit))
490 str <<
"userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit <<
")" <<
"\n";
492 str <<
"userandomsplit\t" << this->m_Parameter->UseRandomSplit <<
"\n";
494 if(!this->GetPropertyList()->Get(
"treedepth",this->m_Parameter->TreeDepth))
495 str <<
"treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth <<
")" <<
"\n";
497 str <<
"treedepth\t\t" << this->m_Parameter->TreeDepth <<
"\n";
499 if(!this->GetPropertyList()->Get(
"minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize))
500 str <<
"minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize <<
")" <<
"\n";
502 str <<
"minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize <<
"\n";
504 if(!this->GetPropertyList()->Get(
"precision",this->m_Parameter->Precision))
505 str <<
"precision\t\tNOT SET (default " << this->m_Parameter->Precision <<
")" <<
"\n";
507 str <<
"precision\t\t" << this->m_Parameter->Precision <<
"\n";
509 if(!this->GetPropertyList()->Get(
"samplespertree",this->m_Parameter->SamplesPerTree))
510 str <<
"samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree <<
")" <<
"\n";
512 str <<
"samplespertree\t" << this->m_Parameter->SamplesPerTree <<
"\n";
514 if(!this->GetPropertyList()->Get(
"samplewithreplacement",this->m_Parameter->SampleWithReplacement))
515 str <<
"samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement <<
")" <<
"\n";
517 str <<
"samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement <<
"\n";
519 if(!this->GetPropertyList()->Get(
"treecount",this->m_Parameter->TreeCount))
520 str <<
"treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount <<
")" <<
"\n";
522 str <<
"treecount\t\t" << this->m_Parameter->TreeCount <<
"\n";
524 if(!this->GetPropertyList()->Get(
"lambda",this->m_Parameter->WeightLambda))
525 str <<
"lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda <<
")" <<
"\n";
527 str <<
"lambda\t\t" << this->m_Parameter->WeightLambda <<
"\n";
// Setter bodies (their signatures are not in this listing). Each stores `val`
// in the property list under the key that ConvertParameter() reads, so the
// value takes effect the next time ConvertParameter() runs (Train() calls it).
536 this->GetPropertyList()->SetBoolProperty(
"usepointbasedweight",val);
541 this->GetPropertyList()->SetIntProperty(
"treedepth",val);
546 this->GetPropertyList()->SetIntProperty(
"minimalsplitnodesize",val);
551 this->GetPropertyList()->SetDoubleProperty(
"precision",val);
556 this->GetPropertyList()->SetDoubleProperty(
"samplespertree",val);
561 this->GetPropertyList()->SetBoolProperty(
"samplewithreplacement",val);
566 this->GetPropertyList()->SetIntProperty(
"treecount",val);
571 this->GetPropertyList()->SetDoubleProperty(
"lambda",val);
// SetTreeWeight body: overwrites the weight of one tree. No bounds check on
// treeId is visible here.
576 m_TreeWeights(treeId,0) = weight;
// SetRandomForest body: copies the given forest's options into the property
// list through the public setters, then stores a copy of the forest itself.
581 this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth);
582 this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_);
583 this->SetTreeCount(rf.options().tree_count_);
584 this->SetSamplesPerTree(rf.options().training_set_proportion_);
585 this->UseSampleWithReplacement(rf.options().sample_with_replacement_);
586 this->m_RandomForest = rf;
// GetRandomForest body: returns the stored forest (a const reference, per the declaration).
591 return this->m_RandomForest;
bool SupportsPointWiseWeight()
SupportsPointWiseWeight.
itk::SmartPointer< Self > Pointer
void SetMaximumTreeDepth(int)
const vigra::RandomForest< int > & GetRandomForest() const
void SetPrecision(double)
void SetTreeWeight(int treeId, double weight)
void UseRandomSplit(bool split)
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X)
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
void UseSampleWithReplacement(bool)
void UsePointBasedWeights(bool weightsOn)
void PrintParameter(std::ostream &str=std::cout)
~VigraRandomForestClassifier()
mitk::ThresholdSplit< mitk::LinearSplitting< mitk::ImpurityLoss<> >, int, vigra::ClassificationTag > DefaultSplitType
void SetMinimumSplitNodeSize(int)
void SetWeightLambda(double)
Eigen::MatrixXd GetTreeWeights() const
mitk::PropertyList::Pointer GetPropertyList() const
Get the data's property list.
void SetTreeWeights(Eigen::MatrixXd weights)
void SetWeights(vigra::MultiArrayView< 2, double > weights)
virtual void UsePointWiseWeight(bool value)
UsePointWiseWeight.
VigraRandomForestClassifier()
void SetRandomForest(const vigra::RandomForest< int > &rf)
void UsePointWiseWeight(bool)
UsePointWiseWeight.
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
MITKCORE_EXPORT const ScalarType eps
void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
bool SupportsPointWiseProbability()
SupportsPointWiseProbability.
void SetSamplesPerTree(double)
Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X)
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.