22 #include <vigra/random_forest.hxx> 23 #include <vigra/random_forest/rf_split.hxx> 26 #include <itkFastMutexLock.h> 27 #include <itkMultiThreader.h> 28 #include <itkCommand.h> 32 struct mitk::PURFClassifier::Parameter
// Parameter: settings bundle that the classifier fills from its property list
// and forwards to the vigra random forest options during training.
// NOTE(review): this listing omits several lines of the struct; members such as
// TreeCount, TreeDepth, Precision, WeightLambda and UseRandomSplit are used
// elsewhere in the file but are not visible here.
// Stratification mode passed to use_stratification().
34 vigra::RF_OptionTag Stratification;
// Passed to sample_with_replacement().
35 bool SampleWithReplacement;
// When true, training applies the point-wise weight handling.
37 bool UsePointBasedWeights;
// Passed to min_split_node_size().
39 int MinimumSplitNodeSize;
// Passed to samples_per_tree().
43 double SamplesPerTree;
// TrainingData: shared state handed to the training worker threads as the
// threader's user data. It carries the inputs (prototype forest, splitter,
// feature/label views, parameters) and collects the trees each worker learns.
// NOTE(review): the listing omits some lines of this struct (the refSplitter
// parameter, the m_Label initialiser and the m_Splitter / m_ClassCount members
// are used in the file but not visible here).
46 struct mitk::PURFClassifier::TrainingData
48 TrainingData(
unsigned int numberOfTrees,
49 const vigra::RandomForest<int> & refRF,
51 const vigra::MultiArrayView<2, double> refFeature,
52 const vigra::MultiArrayView<2, int> refLabel,
53 const Parameter parameter)
55 m_NumberOfTrees(numberOfTrees),
56 m_RandomForest(refRF),
57 m_Splitter(refSplitter),
58 m_Feature(refFeature),
60 m_Parameter(parameter)
// The mutex is used by the workers to serialise their writes into trees_.
62 m_mutex = itk::FastMutexLock::New();
// Output: trees appended by every worker thread.
64 vigra::ArrayVector<vigra::RandomForest<int>::DecisionTree_t> trees_;
// Total number of trees to learn across all threads.
67 unsigned int m_NumberOfTrees;
// Held by reference: the referenced forest must outlive this object.
68 const vigra::RandomForest<int> & m_RandomForest;
70 const vigra::MultiArrayView<2, double> m_Feature;
71 const vigra::MultiArrayView<2, int> m_Label;
72 itk::FastMutexLock::Pointer m_mutex;
// Stored by value (copied from the constructor argument).
73 Parameter m_Parameter;
// PredictionData: shared state handed to the prediction worker threads as the
// threader's user data: the trained forest, the feature view to classify and
// the label / probability views the workers fill.
// NOTE(review): the m_Label initialiser is not visible in this listing.
76 struct mitk::PURFClassifier::PredictionData
78 PredictionData(
const vigra::RandomForest<int> & refRF,
79 const vigra::MultiArrayView<2, double> refFeature,
80 vigra::MultiArrayView<2, int> refLabel,
81 vigra::MultiArrayView<2, double> refProb,
82 vigra::MultiArrayView<2, double> refTreeWeights)
83 : m_RandomForest(refRF),
84 m_Feature(refFeature),
86 m_Probabilities(refProb),
87 m_TreeWeights(refTreeWeights)
// Held by reference: the referenced forest must outlive this object.
90 const vigra::RandomForest<int> & m_RandomForest;
// Input rows to classify.
91 const vigra::MultiArrayView<2, double> m_Feature;
// Output views written by the worker threads.
92 vigra::MultiArrayView<2, int> m_Label;
93 vigra::MultiArrayView<2, double> m_Probabilities;
// NOTE(review): not read by PredictCallback in the lines visible in this listing.
94 vigra::MultiArrayView<2, double> m_TreeWeights;
// Fragment (presumably from the constructor; the enclosing signature is not
// visible in this listing): creates an ITK member command bound to this class.
// What the command is connected to is on lines the listing omits.
100 itk::SimpleMemberCommand<mitk::PURFClassifier>::Pointer command = itk::SimpleMemberCommand<mitk::PURFClassifier>::New();
// Fragment (presumably the body of SetClassProbabilities): stores the class
// probabilities that CalculateKappa later reads.
111 m_ClassProbabilities = probabilities;
// Fragment (presumably the body of GetClassProbabilites): returns the stored
// class probabilities.
116 return m_ClassProbabilities;
// Body fragment of CalculateKappa (signature and return statement are not
// visible in this listing). Computes one kappa value per label value 0..max.
// The largest label value in Y_in fixes the size of both arrays.
132 int maximumValue = Y_in.maxCoeff();
133 vigra::ArrayVector<double> kappa(maximumValue + 1);
134 vigra::ArrayVector<double> counts(maximumValue + 1);
// Histogram of the labels in the first column of Y_in.
// Assumes counts starts zero-filled and labels are non-negative - TODO confirm.
135 for (
int i = 0; i < Y_in.rows(); ++i)
137 counts[Y_in(i, 0)] += 1;
139 for (
int i = 0; i < maximumValue+1; ++i)
// Evaluates as (counts[0] * p_i / counts[i]) + 1 by operator precedence.
// NOTE(review): counts[i] is 0 for any label value in 0..max that never occurs
// in Y_in, which makes this a floating-point division by zero - confirm whether
// label values are guaranteed to be contiguous. m_ClassProbabilities must also
// hold at least maximumValue + 1 entries.
143 kappa[i] = counts[0] * m_ClassProbabilities[i] / counts[i] + 1;
// Body fragment of Train (signature, splitter set-up and braces are not visible
// in this listing). Learns the forest from X_in / Y_in using worker threads.
169 if (m_Parameter->UsePointBasedWeights)
// NOTE(review): the value returned by unaryExpr() is not assigned to anything.
// Eigen's unaryExpr yields a new expression rather than modifying the matrix
// in place, so this statement appears to leave m_PointWiseWeight unchanged and
// WeightLambda without effect - confirm and assign the result if intended.
172 this->
m_PointWiseWeight.unaryExpr([
this](
double t){
return std::pow(t, this->m_Parameter->WeightLambda) ;});
// Wrap the Eigen buffers in vigra views of the same shape.
178 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
179 vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data());
// Configure the member forest for a single tree and learn once on this thread;
// presumably this prepares the forest that each worker then copies - TODO confirm.
181 m_RandomForest.set_options().tree_count(1);
183 m_RandomForest.set_options().use_stratification(m_Parameter->Stratification);
184 m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement);
185 m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree);
186 m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize);
// `splitter` is declared on lines this listing omits.
188 m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter);
// Shared state for the workers; owned here and released when `data` goes out of scope.
190 std::unique_ptr<TrainingData> data(
new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter));
// Run TrainTreesCallback on the threader with `data` as user data.
192 itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
193 threader->SetSingleMethod(this->TrainTreesCallback,data.get());
194 threader->SingleMethodExecute();
// Copy the collected trees and the class count back into the member forest.
197 m_RandomForest.set_options().tree_count(m_Parameter->TreeCount);
198 m_RandomForest.ext_param_.class_count_ = data->m_ClassCount;
199 m_RandomForest.trees_ = data->trees_;
// Reset the per-tree weights to 1 (one row per configured tree).
202 m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1);
203 m_TreeWeights.fill(1.0);
// Body fragment of Predict (signature, the Y / P view declarations and the
// return statement are not visible in this listing). Classifies the rows of
// X_in using worker threads.
// One probability column per class of the trained forest.
210 m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count());
// Re-create the tree weights (all 1) when their row count does not match the
// number of trees in the forest.
216 if(m_TreeWeights.rows() != m_RandomForest.tree_count())
218 m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1);
219 m_TreeWeights.fill(1);
// Wrap the Eigen buffers in vigra views of the same shape.
224 vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data());
225 vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data());
// Shared state for the workers; Y and P are declared on lines this listing omits.
227 std::unique_ptr<PredictionData> data;
228 data.reset(
new PredictionData(m_RandomForest, X, Y, P, TW));
// Run PredictCallback on the threader with `data` as user data.
230 itk::MultiThreader::Pointer threader = itk::MultiThreader::New();
231 threader->SetSingleMethod(this->PredictCallback, data.get());
232 threader->SingleMethodExecute();
// Keep the probability view filled by the workers.
234 m_Probabilities = data->m_Probabilities;
// Thread entry point used by Train. `arg` is the threader's ThreadInfoStruct;
// its UserData points at the shared TrainingData. Each thread learns its share
// of the trees on a private copy of the forest and appends them to
// data->trees_ under the mutex. Returns ITK_THREAD_RETURN_VALUE.
// NOTE(review): braces and the splitter declaration/set-up lines are missing
// from this listing.
238 ITK_THREAD_RETURN_TYPE mitk::PURFClassifier::TrainTreesCallback(
void * arg)
241 typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
242 ThreadInfoType * infoStruct =
static_cast< ThreadInfoType *
>( arg );
244 TrainingData * data = (TrainingData *)(infoStruct->UserData);
245 unsigned int numberOfTreesToCalculate = 0;
// Even share of the trees per thread (integer division) ...
248 numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads;
// ... and thread 0 additionally takes the remainder.
251 if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads;
// Threads with no trees assigned skip the training.
253 if(numberOfTreesToCalculate != 0){
// Private copy of the shared forest so that learning does not touch shared state.
255 vigra::RandomForest<int> rf = data->m_RandomForest;
// `splitter` is declared on lines this listing omits; it takes over the
// weights of the shared splitter.
263 splitter.
SetWeights(data->m_Splitter.GetWeights());
// Same options as in Train, except for this thread's tree count.
267 rf.set_options().tree_count(numberOfTreesToCalculate);
268 rf.set_options().use_stratification(data->m_Parameter.Stratification);
269 rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement);
270 rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree);
271 rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize);
272 rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter);
// Publish the results. NOTE(review): Lock()/Unlock() are paired manually, so an
// exception thrown in between would leave the mutex held.
274 data->m_mutex->Lock();
276 for(
const auto & tree : rf.trees_)
277 data->trees_.push_back(tree);
// Every training thread writes the class count; the last writer wins.
279 data->m_ClassCount = rf.class_count();
280 data->m_mutex->Unlock();
283 return ITK_THREAD_RETURN_VALUE;
// Thread entry point used by Predict. `arg` is the threader's ThreadInfoStruct;
// its UserData points at the shared PredictionData. Each thread classifies a
// contiguous range of feature rows and writes labels and probabilities for
// exactly those rows. Returns ITK_THREAD_RETURN_VALUE.
// NOTE(review): braces are missing from this listing; lowerBound / upperBound
// are declared three times, presumably in separate scopes - TODO confirm.
287 ITK_THREAD_RETURN_TYPE mitk::PURFClassifier::PredictCallback(
void * arg)
290 typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType;
291 ThreadInfoType * infoStruct =
static_cast< ThreadInfoType *
>( arg );
293 const unsigned int threadId = infoStruct->ThreadID;
297 PredictionData * data = (PredictionData *)(infoStruct->UserData);
298 unsigned int numberOfRowsToCalculate = 0;
// Even share of the rows per thread (integer division).
301 numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads;
// Half-open row range [start_index, end_index) handled by this thread.
303 unsigned int start_index = numberOfRowsToCalculate * threadId;
304 unsigned int end_index = numberOfRowsToCalculate * (threadId+1);
// The last thread additionally takes the remainder rows.
307 if(threadId == infoStruct->NumberOfThreads-1) {
308 end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads;
311 vigra::MultiArrayView<2, double> split_features;
312 vigra::MultiArrayView<2, int> split_labels;
313 vigra::MultiArrayView<2, double> split_probability;
// Row slice of the features (all columns).
315 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
316 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Feature.shape(1));
317 split_features = data->m_Feature.subarray(lowerBound,upperBound);
// Matching row slice of the label output.
321 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
322 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index, data->m_Label.shape(1));
323 split_labels = data->m_Label.subarray(lowerBound,upperBound);
// Matching row slice of the probability output.
327 vigra::TinyVector<vigra::MultiArrayIndex, 2> lowerBound(start_index,0);
328 vigra::TinyVector<vigra::MultiArrayIndex, 2> upperBound(end_index,data->m_Probabilities.shape(1));
329 split_probability = data->m_Probabilities.subarray(lowerBound,upperBound);
// Fill both outputs for this thread's rows.
// NOTE(review): data->m_TreeWeights is not used in the visible lines.
332 data->m_RandomForest.predictLabels(split_features,split_labels);
333 data->m_RandomForest.predictProbabilities(split_features, split_probability);
336 return ITK_THREAD_RETURN_VALUE;
// Body fragment of ConvertParameter (signature and braces are not visible in
// this listing). Fills m_Parameter from the property list; every setting for
// which Get() returns false receives the hard-coded fallback shown below.
// Allocate the parameter object on first use.
// NOTE(review): raw `new`; the matching release is not visible in this listing -
// confirm that the destructor deletes m_Parameter.
342 if(this->m_Parameter ==
nullptr)
343 this->m_Parameter =
new Parameter();
346 MITK_INFO(
"PURFClassifier") <<
"Convert Parameter";
// Fallback: false.
347 if(!this->
GetPropertyList()->Get(
"usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights =
false;
// Fallback: false.
348 if(!this->
GetPropertyList()->Get(
"userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit =
false;
// Fallback: 20.
349 if(!this->
GetPropertyList()->Get(
"treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20;
// Fallback: 100.
350 if(!this->
GetPropertyList()->Get(
"treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100;
// Fallback: 5.
351 if(!this->
GetPropertyList()->Get(
"minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5;
// Fallback: mitk::eps.
352 if(!this->
GetPropertyList()->Get(
"precision",this->m_Parameter->Precision)) this->m_Parameter->Precision =
mitk::eps;
// Fallback: 0.6.
353 if(!this->
GetPropertyList()->Get(
"samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6;
// Fallback: true.
354 if(!this->
GetPropertyList()->Get(
"samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement =
true;
// Fallback: 1.0.
355 if(!this->
GetPropertyList()->Get(
"lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0;
// Stratification is not read from the property list; it is always RF_NONE.
357 this->m_Parameter->Stratification = vigra::RF_NONE;
// Body fragment of PrintParameter (signature, braces and the lines between the
// paired stream statements are not visible in this listing). For every setting
// it writes either a "NOT SET (default ...)" line or the plain value to `str`.
// NOTE(review): each `if` below is followed by two stream statements; the
// listing presumably dropped an `else` between them - TODO confirm.
// NOTE(review): Get() is called with the live m_Parameter fields as targets;
// presumably it stores the property value there (ConvertParameter relies on
// that), in which case printing also refreshes m_Parameter - confirm.
// Warn when ConvertParameter has not created m_Parameter yet; what happens
// after the warning is on lines this listing omits.
362 if(this->m_Parameter ==
nullptr)
364 MITK_WARN(
"PURFClassifier") <<
"Parameters are not initialized. Please call ConvertParameter() first!";
// usepointbasedweight
371 if(!this->
GetPropertyList()->Get(
"usepointbasedweight",this->m_Parameter->UsePointBasedWeights))
372 str <<
"usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights <<
")" <<
"\n";
374 str <<
"usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights <<
"\n";
// userandomsplit
376 if(!this->
GetPropertyList()->Get(
"userandomsplit",this->m_Parameter->UseRandomSplit))
377 str <<
"userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit <<
")" <<
"\n";
379 str <<
"userandomsplit\t" << this->m_Parameter->UseRandomSplit <<
"\n";
// treedepth
381 if(!this->
GetPropertyList()->Get(
"treedepth",this->m_Parameter->TreeDepth))
382 str <<
"treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth <<
")" <<
"\n";
384 str <<
"treedepth\t\t" << this->m_Parameter->TreeDepth <<
"\n";
// minimalsplitnodesize
386 if(!this->
GetPropertyList()->Get(
"minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize))
387 str <<
"minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize <<
")" <<
"\n";
389 str <<
"minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize <<
"\n";
// precision
391 if(!this->
GetPropertyList()->Get(
"precision",this->m_Parameter->Precision))
392 str <<
"precision\t\tNOT SET (default " << this->m_Parameter->Precision <<
")" <<
"\n";
394 str <<
"precision\t\t" << this->m_Parameter->Precision <<
"\n";
// samplespertree
396 if(!this->
GetPropertyList()->Get(
"samplespertree",this->m_Parameter->SamplesPerTree))
397 str <<
"samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree <<
")" <<
"\n";
399 str <<
"samplespertree\t" << this->m_Parameter->SamplesPerTree <<
"\n";
// samplewithreplacement
401 if(!this->
GetPropertyList()->Get(
"samplewithreplacement",this->m_Parameter->SampleWithReplacement))
402 str <<
"samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement <<
")" <<
"\n";
404 str <<
"samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement <<
"\n";
// treecount
406 if(!this->
GetPropertyList()->Get(
"treecount",this->m_Parameter->TreeCount))
407 str <<
"treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount <<
")" <<
"\n";
409 str <<
"treecount\t\t" << this->m_Parameter->TreeCount <<
"\n";
// lambda
411 if(!this->
GetPropertyList()->Get(
"lambda",this->m_Parameter->WeightLambda))
412 str <<
"lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda <<
")" <<
"\n";
414 str <<
"lambda\t\t" << this->m_Parameter->WeightLambda <<
"\n";
// Fragment (presumably the body of SetRandomForest; signature not visible in
// this listing): replaces the member forest with a copy of `rf`.
468 this->m_RandomForest = rf;
// Fragment (presumably the body of GetRandomForest; signature not visible in
// this listing): returns the member forest.
473 return this->m_RandomForest;
void UseSampleWithReplacement(bool)
vigra::ArrayVector< double > CalculateKappa(const Eigen::MatrixXd &X_in, const Eigen::MatrixXi &Y_in)
void SetAdditionalData(AdditionalRFDataAbstract *data)
void SetSamplesPerTree(double)
void UseRandomSplit(bool split)
Eigen::MatrixXi m_OutLabel
void SetWeightLambda(double)
void UsePointBasedWeights(bool weightsOn)
void PrintParameter(std::ostream &str=std::cout)
void SetMaximumTreeDepth(int)
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override
Build a forest of trees from the training set (X, y).
void SetRandomForest(const vigra::RandomForest< int > &rf)
Eigen::MatrixXd m_PointWiseWeight
void UsePointWiseWeight(bool) override
UsePointWiseWeight.
void SetWeights(vigra::MultiArrayView< 2, double > weights)
mitk::ThresholdSplit< mitk::LinearSplitting< mitk::PUImpurityLoss<> >, int, vigra::ClassificationTag > DefaultPUSplitType
virtual void UsePointWiseWeight(bool value)
UsePointWiseWeight.
void SetMinimumSplitNodeSize(int)
~PURFClassifier() override
const vigra::RandomForest< int > & GetRandomForest() const
void SetPrecision(double value)
void SetMaximumTreeDepth(int value)
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override
Predict class for X.
mitk::PropertyList::Pointer GetPropertyList() const
Get the data's property list.
void SetPrecision(double)
Eigen::MatrixXd m_OutProbability
vigra::ArrayVector< double > m_Kappa
bool SupportsPointWiseProbability() override
SupportsPointWiseProbability.
MITKCORE_EXPORT const ScalarType eps
void SetClassProbabilities(Eigen::VectorXd probabilities)
Eigen::VectorXd GetClassProbabilites()
bool SupportsPointWiseWeight() override
SupportsPointWiseWeight.