13 #ifndef mitkVigraRandomForestClassifier_h
14 #define mitkVigraRandomForestClassifier_h
16 #include <MitkCLVigraRandomForestExports.h>
20 #include <vigra/random_forest.hxx>
/// ITK object-creation boilerplate, expanded with Self (the enclosing class type).
/// Presumably declares the static New() used to create instances, following ITK convention — confirm against the ITK macro definition.
/// NOTE(review): the numeric prefixes at the start of these lines (e.g. "32 ", "40 ") are not valid C++
/// and look like line numbers fused in by an extraction step — confirm against the original header.
32 itkFactorylessNewMacro(
Self);
/// Training entry point; overrides the base-class virtual Train (the base class is not visible here).
/// @param X  feature values as doubles; presumably one row per sample — TODO confirm.
/// @param Y  integer labels; presumably one row per sample of X — TODO confirm.
40 void Train(
const Eigen::MatrixXd &X,
const Eigen::MatrixXi &Y)
override;
/// Training variant taking the same inputs as Train; presumably updates an already-trained forest
/// incrementally rather than retraining from scratch — confirm in the implementation.
/// NOTE(review): not marked override; if the base class declares a virtual OnlineTrain, add override.
/// @param X  feature values as doubles.
/// @param Y  integer labels matching X.
41 void OnlineTrain(
const Eigen::MatrixXd &X,
const Eigen::MatrixXi &Y);
/// Prediction entry point; overrides the base-class virtual Predict.
/// @param X  feature values as doubles; presumably laid out like the X passed to Train — TODO confirm.
/// @return   integer label matrix (returned by value).
42 Eigen::MatrixXi Predict(
const Eigen::MatrixXd &X)
override;
/// Prediction variant that presumably applies per-tree weights (see SetTreeWeights / m_TreeWeights) — confirm in the implementation.
/// NOTE(review): not marked override; add it if the base class declares this virtual.
/// @param X  feature values as doubles.
/// @return   integer label matrix (returned by value).
43 Eigen::MatrixXi PredictWeighted(
const Eigen::MatrixXd &X);
/// Capability query (base-class override): whether this classifier supports per-sample (point-wise) weights.
46 bool SupportsPointWiseWeight()
override;
/// Capability query (base-class override): whether this classifier supports per-sample (point-wise) probabilities.
47 bool SupportsPointWiseProbability()
override;
/// Takes no arguments; presumably translates the stored parameter values (m_Parameter) into the
/// settings used by the vigra forest — confirm in the implementation.
48 void ConvertParameter();
/// Replaces the underlying vigra forest with the one given; presumably copied into m_RandomForest — confirm.
/// @param rf  forest to install (taken by const reference).
50 void SetRandomForest(
const vigra::RandomForest<int> & rf);
/// Read-only access to the underlying vigra forest.
/// NOTE(review): returns a reference; presumably to m_RandomForest, so it is only valid while this object is alive.
51 const vigra::RandomForest<int> & GetRandomForest()
const;
/// Hyper-parameter setters. Each takes a single value; where the value is stored
/// (presumably m_Parameter) is not visible here — confirm in the implementation.

/// Enables/disables per-sample weighting (base-class override).
53 void UsePointWiseWeight(
bool)
override;
/// Limit on tree depth.
54 void SetMaximumTreeDepth(
int);
/// Minimum node size required to attempt a split.
55 void SetMinimumSplitNodeSize(
int);
/// Precision setting (double); exact meaning is defined by the implementation — TODO confirm.
56 void SetPrecision(
double);
/// Samples drawn per tree; a double, presumably a fraction of the training set — TODO confirm.
57 void SetSamplesPerTree(
double);
/// Whether per-tree sampling draws with replacement.
58 void UseSampleWithReplacement(
bool);
/// Number of trees in the forest.
59 void SetTreeCount(
int);
/// Weighting lambda (double); used by the weighting scheme — TODO confirm its exact role.
60 void SetWeightLambda(
double);
/// Replaces all per-tree weights; presumably stored in m_TreeWeights — confirm.
/// NOTE(review): the matrix is taken by value (a copy); `const Eigen::MatrixXd &` would avoid it,
/// but the out-of-view definition must change in lockstep.
62 void SetTreeWeights(Eigen::MatrixXd weights);
/// Sets the weight of a single tree.
/// @param treeId  index of the tree; valid range is not checked here — see implementation.
/// @param weight  new weight for that tree.
63 void SetTreeWeight(
int treeId,
double weight);
/// Returns a copy of the per-tree weights.
64 Eigen::MatrixXd GetTreeWeights()
const;
/// Writes the parameter settings to the given stream (std::cout by default).
/// @param str  output stream to print to.
66 void PrintParameter(std::ostream &str = std::cout);
/// Helper types, only forward-declared here; their definitions are elsewhere (presumably the .cpp).
/// PredictionData is passed to VigraPredictWeighted.
75 struct PredictionData;
76 struct EigenToVigraTransform;
/// 2-D probability buffer, held as a vigra view.
/// NOTE(review): a MultiArrayView does not own its storage; confirm that the referenced buffer outlives every use of this member.
79 vigra::MultiArrayView<2, double> m_Probabilities;
/// Per-tree weights (see SetTreeWeights / GetTreeWeights).
80 Eigen::MatrixXd m_TreeWeights;
/// Parameter object (raw pointer).
/// NOTE(review): ownership and lifetime are not visible here; if this class owns it, prefer std::unique_ptr.
82 Parameter * m_Parameter;
/// The underlying vigra random forest with int labels.
83 vigra::RandomForest<int> m_RandomForest;
/// Static callbacks with the ITK thread-function signature (void* argument, ITK_THREAD_RETURN_TYPE result).
/// Presumably run on ITK worker threads, with the void* carrying per-call data — confirm in the implementation.
85 static itk::ITK_THREAD_RETURN_TYPE TrainTreesCallback(
void *);
86 static itk::ITK_THREAD_RETURN_TYPE PredictCallback(
void *);
87 static itk::ITK_THREAD_RETURN_TYPE PredictWeightedCallback(
void *);
/// Weighted prediction over vigra views.
/// Presumably X holds the features, Y receives the labels and P receives the probabilities — TODO confirm.
88 static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P);