21 #include <vigra/random_forest.hxx> 22 #include <vigra/random_forest/rf_split.hxx> 25 #include <itkFastMutexLock.h> 26 #include <itkMultiThreader.h> 27 #include <itkCommand.h> 31 struct mitk::VigraRandomForestClassifier::Parameter
// Bundle of training options that the classifier forwards to the vigra forest.
// NOTE(review): other members used elsewhere in this file (TreeCount, TreeDepth,
// UseRandomSplit, Precision, WeightLambda) are declared on lines that are not
// part of this listing.
// Stratification mode; handed to set_options().use_stratification().
33 vigra::RF_OptionTag Stratification;
// Handed to set_options().sample_with_replacement().
34 bool SampleWithReplacement;
// When true, training runs the point-wise weight branch.
36 bool UsePointBasedWeights;
// Handed to set_options().min_split_node_size().
38 int MinimumSplitNodeSize;
// Handed to set_options().samples_per_tree().
42 double SamplesPerTree;
// Job description shared by all training worker threads: the template forest,
// the training views, a copy of the parameters, and the list into which every
// worker appends the trees it has trained.
// NOTE(review): the initializer list uses refSplitter and the workers use
// m_Splitter / m_ClassCount; their declarations (and the m_Label initializer)
// are on lines that are not part of this listing.
45 struct mitk::VigraRandomForestClassifier::TrainingData
// Stores the arguments in the members of the same name; the views and the
// parameter object are taken by value, the forest by reference.
47 TrainingData(
unsigned int numberOfTrees,
48 const vigra::RandomForest<int> & refRF,
50 const vigra::MultiArrayView<2, double> refFeature,
51 const vigra::MultiArrayView<2, int> refLabel,
52 const Parameter parameter)
54 m_NumberOfTrees(numberOfTrees),
55 m_RandomForest(refRF),
56 m_Splitter(refSplitter),
57 m_Feature(refFeature),
59 m_Parameter(parameter)
// The mutex guards trees_ (and the class count) while workers merge results.
61 m_mutex = itk::FastMutexLock::New();
// Trees collected from all workers; appended to under m_mutex.
63 vigra::ArrayVector<vigra::RandomForest<int>::DecisionTree_t> trees_;
// Total number of trees to train; the workers divide this among themselves.
66 unsigned int m_NumberOfTrees;
// Forest that each worker copies before training its share of trees.
67 const vigra::RandomForest<int> & m_RandomForest;
// Training features and labels, passed unchanged to each worker's learn() call.
69 const vigra::MultiArrayView<2, double> m_Feature;
70 const vigra::MultiArrayView<2, int> m_Label;
71 itk::FastMutexLock::Pointer m_mutex;
// Private copy of the training options read by the workers.
72 Parameter m_Parameter;
// Job description shared by all prediction worker threads: the trained forest,
// the input features, and the output views each worker writes its row slice to.
// NOTE(review): the initializer for m_Label is on a line that is not part of
// this listing.
75 struct mitk::VigraRandomForestClassifier::PredictionData
// Stores the arguments in the members of the same name; all views are taken
// by value, the forest by reference.
77 PredictionData(
const vigra::RandomForest<int> & refRF,
78 const vigra::MultiArrayView<2, double> refFeature,
79 vigra::MultiArrayView<2, int> refLabel,
80 vigra::MultiArrayView<2, double> refProb,
81 vigra::MultiArrayView<2, double> refTreeWeights)
82 : m_RandomForest(refRF),
83 m_Feature(refFeature),
85 m_Probabilities(refProb),
86 m_TreeWeights(refTreeWeights)
// Forest used for prediction; only read by the workers.
89 const vigra::RandomForest<int> & m_RandomForest;
// Input samples; the workers take disjoint row ranges of this view.
90 const vigra::MultiArrayView<2, double> m_Feature;
// Output label and probability views, sliced by the same row ranges.
91 vigra::MultiArrayView<2, int> m_Label;
92 vigra::MultiArrayView<2, double> m_Probabilities;
// Per-tree weights, indexed as (tree, 0) by the weighted prediction.
93 vigra::MultiArrayView<2, double> m_TreeWeights;
// Creates an ITK member-function command object for this classifier type.
// NOTE(review): which member function it is bound to and which event it
// observes is set up on lines that are not part of this listing.
99 itk::SimpleMemberCommand<mitk::VigraRandomForestClassifier>::Pointer command = itk::SimpleMemberCommand<mitk::VigraRandomForestClassifier>::New();
// Wrap the Eigen input buffers as 2-D vigra views (shape rows x cols over the
// matrices' own data pointers) and hand them to the forest's online learner.
120 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
121 vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
// NOTE(review): presumably 0 is the index of the first new sample and `true`
// enables threshold adjustment - confirm against the vigra onlineLearn docs.
122 m_RandomForest.onlineLearn(X,Y,0,
true);
// Point-wise weights: apply the exponent WeightLambda to every weight.
// NOTE(review): on the visible lines the value returned by unaryExpr() is not
// assigned to anything, so m_PointWiseWeight itself looks unchanged by this
// statement - confirm whether `m_PointWiseWeight = ...unaryExpr(...)` was meant.
136 if (m_Parameter->UsePointBasedWeights)
139 this->
m_PointWiseWeight.unaryExpr([
this](
double t){
return std::pow(t, this->m_Parameter->WeightLambda) ;});
// Wrap the Eigen input buffers as 2-D vigra views over the same memory.
145 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
146 vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
// Train a single tree on the calling thread with the configured options.
// NOTE(review): presumably this only serves to initialise the forest state
// that the worker threads copy; the tree list is replaced further down.
148 m_RandomForest.set_options().tree_count(1);
150 m_RandomForest.set_options().use_stratification(m_Parameter->Stratification);
151 m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement);
152 m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree);
153 m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize);
155 m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter);
// Package the job for the workers; the unique_ptr keeps ownership here and
// the threads only receive the raw pointer.
157 std::unique_ptr<TrainingData> data(
new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter));
// Run TrainTreesCallback on the ITK threader. The results are read right
// after, so SingleMethodExecute() presumably returns once all workers finish.
159 itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
160 threader->SetSingleMethod(this->TrainTreesCallback,data.get());
161 threader->SingleMethodExecute();
// Adopt the workers' results: restore the full tree count, take over the class
// count they reported, and replace the forest's trees with the collected ones.
164 m_RandomForest.set_options().tree_count(m_Parameter->TreeCount);
165 m_RandomForest.ext_param_.class_count_ = data->m_ClassCount;
166 m_RandomForest.trees_ = data->trees_;
// Reset the per-tree weights to a TreeCount x 1 column of ones.
169 m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1);
170 m_TreeWeights.fill(1.0);
// Allocate the probability output: one row per sample, one column per class.
// NOTE(review): this constructor does not set the coefficients; any fill and
// the declarations of the Y / P views are on lines not part of this listing.
176 m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
// If the stored tree weights do not match the forest's tree count, replace
// them with a column of ones.
182 if(m_TreeWeights.rows() != m_RandomForest.tree_count())
184 m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
185 m_TreeWeights.fill(1);
// Wrap the Eigen buffers as 2-D vigra views over the same memory.
191 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
192 vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
// Package the job; ownership stays with the unique_ptr, the worker threads
// only receive the raw pointer.
194 std::unique_ptr<PredictionData> data;
195 data.reset(
new PredictionData(m_RandomForest, X, Y, P, TW));
// Run PredictCallback on the ITK threader; each worker handles a row range.
197 itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
198 threader->SetSingleMethod(this->PredictCallback, data.get());
199 threader->SingleMethodExecute();
// Keep the job's probability view in the member m_Probabilities.
201 m_Probabilities = data->m_Probabilities;
// Allocate the probability output: one row per sample, one column per class.
// NOTE(review): the weighted prediction adds votes onto this buffer, so it
// must be zeroed first - presumably done on a line not part of this listing,
// together with the declarations of the Y / P views; confirm.
208 m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
// If the stored tree weights do not match the forest's tree count, replace
// them with a column of ones.
214 if(m_TreeWeights.rows() != m_RandomForest.tree_count())
216 m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
217 m_TreeWeights.fill(1);
// Wrap the Eigen buffers as 2-D vigra views over the same memory.
223 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
224 vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
// Package the job; ownership stays with the unique_ptr, the worker threads
// only receive the raw pointer.
226 std::unique_ptr<PredictionData> data;
227 data.reset(
new PredictionData(m_RandomForest,X,Y,P,TW));
// Run PredictWeightedCallback on the ITK threader; each worker handles a row
// range and applies the per-tree weights.
229 itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
230 threader->SetSingleMethod(this->PredictWeightedCallback,data.get());
231 threader->SingleMethodExecute();
// Replace the stored per-tree weights with the caller's matrix.
// NOTE(review): no size check is visible here; the prediction code resets the
// weights to ones when the row count differs from the forest's tree count.
240 m_TreeWeights = weights;
// Hand back the currently stored per-tree weight matrix.
245 return m_TreeWeights;
// Thread entry point for training. `arg` is the ITK ThreadInfoStruct of this
// worker; its UserData points to the shared TrainingData job. Each worker
// trains its share of the trees in a private copy of the forest and then
// appends them to the shared tree list under the job's mutex.
248 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(
void * arg)
251 typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
252 ThreadInfoType * infoStruct =
static_cast< ThreadInfoType *
>( arg );
254 TrainingData * data = (TrainingData *)(infoStruct->UserData);
255 unsigned int numberOfTreesToCalculate = 0;
// Even share per worker (integer division) ...
258 numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads;
// ... and the worker with ID 0 additionally takes the remainder.
261 if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads;
// Workers without any assigned tree skip training and merging entirely.
263 if(numberOfTreesToCalculate != 0){
// Private copy of the forest, so learn() does not touch shared state.
265 vigra::RandomForest<int> rf = data->m_RandomForest;
// NOTE(review): the declaration and remaining configuration of the local
// `splitter` are on lines that are not part of this listing.
273 splitter.
SetWeights(data->m_Splitter.GetWeights());
// Same options as the owning classifier, but only this worker's tree count.
276 rf.set_options().tree_count(numberOfTreesToCalculate);
277 rf.set_options().use_stratification(data->m_Parameter.Stratification);
278 rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement);
279 rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree);
280 rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize);
281 rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter);
// Merge under the mutex: append this worker's trees and publish the class count.
// NOTE(review): Lock()/Unlock() are paired by hand; if push_back threw, the
// mutex would stay locked - consider a scoped lock holder.
283 data->m_mutex->Lock();
285 for(
const auto & tree : rf.trees_)
286 data->trees_.push_back(tree);
288 data->m_ClassCount = rf.class_count();
289 data->m_mutex->Unlock();
292 return ITK_THREAD_RETURN_VALUE;
296 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(
void * arg)
299 typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
300 ThreadInfoType * infoStruct =
static_cast< ThreadInfoType *
>( arg );
302 const unsigned int threadId = infoStruct->ThreadID;
306 PredictionData * data = (PredictionData *)(infoStruct->UserData);
307 unsigned int numberOfRowsToCalculate = 0;
310 numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
312 unsigned int start_index = numberOfRowsToCalculate * threadId;
313 unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
316 if(threadId == infoStruct->NumberOfThreads-1) {
317 end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
320 vigra::MultiArrayView<2, double> split_features;
321 vigra::MultiArrayView<2, int> split_labels;
322 vigra::MultiArrayView<2, double> split_probability;
324 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
325 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
326 split_features = data->m_Feature.subarray(lowerBound,upperBound);
330 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
331 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
332 split_labels = data->m_Label.subarray(lowerBound,upperBound);
336 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
337 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
338 split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
341 data->m_RandomForest.predictLabels(split_features,split_labels);
342 data->m_RandomForest.predictProbabilities(split_features, split_probability);
345 return ITK_THREAD_RETURN_VALUE;
349 ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(
void * arg)
352 typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
353 ThreadInfoType * infoStruct =
static_cast< ThreadInfoType *
>( arg );
355 const unsigned int threadId = infoStruct->ThreadID;
359 PredictionData * data = (PredictionData *)(infoStruct->UserData);
360 unsigned int numberOfRowsToCalculate = 0;
363 numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
365 unsigned int start_index = numberOfRowsToCalculate * threadId;
366 unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
369 if(threadId == infoStruct->NumberOfThreads-1) {
370 end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
373 vigra::MultiArrayView<2, double> split_features;
374 vigra::MultiArrayView<2, int> split_labels;
375 vigra::MultiArrayView<2, double> split_probability;
377 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
378 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
379 split_features = data->m_Feature.subarray(lowerBound,upperBound);
383 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
384 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
385 split_labels = data->m_Label.subarray(lowerBound,upperBound);
389 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
390 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
391 split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
394 VigraPredictWeighted(data, split_features,split_labels,split_probability);
396 return ITK_THREAD_RETURN_VALUE;
400 void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P)
403 int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_;
405 for(
int row=0; row < vigra::rowCount(X); ++row)
407 vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row));
409 vigra::ArrayVector<double>::const_iterator weights;
412 double totalWeight = 0.0;
415 for(
int k=0;
k<data->m_RandomForest.options_.tree_count_; ++
k)
418 weights = data->m_RandomForest.trees_[
k ].predict(currentRow);
419 double numberOfLeafObservations = (*(weights-1));
422 for(
int l=0; l<data->m_RandomForest.ext_param_.class_count_; ++l)
425 double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted));
426 cur_w = cur_w * data->m_TreeWeights(
k,0);
427 P(row, l) += (int)cur_w;
429 totalWeight += cur_w;
434 for(
int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l)
436 P(row, l) /= vigra::detail::RequiresExplicitCast<double>::cast(totalWeight);
440 for (
int col=0;col<data->m_RandomForest.class_count();++col)
442 if (data->m_Probabilities(row,col) > data->m_Probabilities(row, maxCol))
445 data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg);
// Fill m_Parameter from the classifier's property list. Each option is looked
// up by its lower-case key; when the lookup reports failure the default
// written on the same line is used instead.
// The parameter object is created on first use.
// NOTE(review): raw owning pointer - its release is not visible in this
// listing; confirm the destructor deletes it.
452 if(this->m_Parameter ==
nullptr)
453 this->m_Parameter =
new Parameter();
456 MITK_INFO(
"VigraRandomForestClassifier") <<
"Convert Parameter";
457 if(!this->
GetPropertyList()->Get(
"usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights =
false;
458 if(!this->
GetPropertyList()->Get(
"userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit =
false;
459 if(!this->
GetPropertyList()->Get(
"treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20;
460 if(!this->
GetPropertyList()->Get(
"treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100;
461 if(!this->
GetPropertyList()->Get(
"minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5;
462 if(!this->
GetPropertyList()->Get(
"precision",this->m_Parameter->Precision)) this->m_Parameter->Precision =
mitk::eps;
463 if(!this->
GetPropertyList()->Get(
"samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6;
464 if(!this->
GetPropertyList()->Get(
"samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement =
true;
465 if(!this->
GetPropertyList()->Get(
"lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0;
// Stratification is not read from the property list; it is always RF_NONE.
467 this->m_Parameter->Stratification = vigra::RF_NONE;
// Write one line per option to `str`: the key followed either by its value or
// by "NOT SET (default ...)" when the property list has no entry for the key.
// NOTE(review): each lookup passes the live m_Parameter field as the output
// argument of Get(), so a successful lookup presumably overwrites that field
// while printing - confirm this side effect is intended.
// NOTE(review): the second `str <<` of every pair presumably sits in an else
// branch; the `else` lines are not part of this listing.
// Without a parameter object there is nothing to print; only a warning is logged.
472 if(this->m_Parameter ==
nullptr)
474 MITK_WARN(
"VigraRandomForestClassifier") <<
"Parameters are not initialized. Please call ConvertParameter() first!";
481 if(!this->
GetPropertyList()->Get(
"usepointbasedweight",this->m_Parameter->UsePointBasedWeights))
482 str <<
"usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights <<
")" <<
"\n";
484 str <<
"usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights <<
"\n";
486 if(!this->
GetPropertyList()->Get(
"userandomsplit",this->m_Parameter->UseRandomSplit))
487 str <<
"userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit <<
")" <<
"\n";
489 str <<
"userandomsplit\t" << this->m_Parameter->UseRandomSplit <<
"\n";
491 if(!this->
GetPropertyList()->Get(
"treedepth",this->m_Parameter->TreeDepth))
492 str <<
"treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth <<
")" <<
"\n";
494 str <<
"treedepth\t\t" << this->m_Parameter->TreeDepth <<
"\n";
496 if(!this->
GetPropertyList()->Get(
"minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize))
497 str <<
"minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize <<
")" <<
"\n";
499 str <<
"minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize <<
"\n";
501 if(!this->
GetPropertyList()->Get(
"precision",this->m_Parameter->Precision))
502 str <<
"precision\t\tNOT SET (default " << this->m_Parameter->Precision <<
")" <<
"\n";
504 str <<
"precision\t\t" << this->m_Parameter->Precision <<
"\n";
506 if(!this->
GetPropertyList()->Get(
"samplespertree",this->m_Parameter->SamplesPerTree))
507 str <<
"samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree <<
")" <<
"\n";
509 str <<
"samplespertree\t" << this->m_Parameter->SamplesPerTree <<
"\n";
511 if(!this->
GetPropertyList()->Get(
"samplewithreplacement",this->m_Parameter->SampleWithReplacement))
512 str <<
"samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement <<
")" <<
"\n";
514 str <<
"samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement <<
"\n";
516 if(!this->
GetPropertyList()->Get(
"treecount",this->m_Parameter->TreeCount))
517 str <<
"treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount <<
")" <<
"\n";
519 str <<
"treecount\t\t" << this->m_Parameter->TreeCount <<
"\n";
521 if(!this->
GetPropertyList()->Get(
"lambda",this->m_Parameter->WeightLambda))
522 str <<
"lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda <<
")" <<
"\n";
524 str <<
"lambda\t\t" << this->m_Parameter->WeightLambda <<
"\n";
// Store `weight` as the weight of tree `treeId` (row treeId, column 0).
// NOTE(review): no bounds check on treeId is visible in this listing.
573 m_TreeWeights(treeId,0) = weight;
// Replace the classifier's forest with a copy of `rf`.
583 this->m_RandomForest = rf;
// Give access to the classifier's current forest.
588 return this->m_RandomForest;
void SetMaximumTreeDepth(int)
void SetPrecision(double)
void SetTreeWeight(int treeId, double weight)
void UseRandomSplit(bool split)
Eigen::MatrixXi m_OutLabel
void UseSampleWithReplacement(bool)
const vigra::RandomForest< int > & GetRandomForest() const
void UsePointBasedWeights(bool weightsOn)
void PrintParameter(std::ostream &str=std::cout)
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override
Predict class for X.
mitk::ThresholdSplit< mitk::LinearSplitting< mitk::ImpurityLoss<> >, int, vigra::ClassificationTag > DefaultSplitType
void SetMinimumSplitNodeSize(int)
Eigen::MatrixXd GetTreeWeights() const
void SetWeightLambda(double)
bool SupportsPointWiseProbability() override
SupportsPointWiseProbability.
Eigen::MatrixXd m_PointWiseWeight
void SetTreeWeights(Eigen::MatrixXd weights)
void SetWeights(vigra::MultiArrayView< 2, double > weights)
virtual void UsePointWiseWeight(bool value)
UsePointWiseWeight.
VigraRandomForestClassifier()
void SetRandomForest(const vigra::RandomForest< int > &rf)
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
mitk::PropertyList::Pointer GetPropertyList() const
Get the data's property list.
~VigraRandomForestClassifier() override
Eigen::MatrixXd m_OutProbability
MITKCORE_EXPORT const ScalarType eps
void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
bool SupportsPointWiseWeight() override
SupportsPointWiseWeight.
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override
Build a forest of trees from the training set (X, y).
void UsePointWiseWeight(bool) override
UsePointWiseWeight.
void SetSamplesPerTree(double)
Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X)