Medical Imaging Interaction Toolkit  2023.04.00
Medical Imaging Interaction Toolkit
mitkVigraRandomForestClassifier.h
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 #ifndef mitkVigraRandomForestClassifier_h
14 #define mitkVigraRandomForestClassifier_h
15 
16 #include <MitkCLVigraRandomForestExports.h>
17 #include <mitkAbstractClassifier.h>
18 
19 //#include <vigra/multi_array.hxx>
20 #include <vigra/random_forest.hxx>
21 
22 #include <mitkBaseData.h>
23 
24 namespace mitk
25 {
26  class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier
27  {
28  public:
29 
31 
32  itkFactorylessNewMacro(Self);
33 
34  itkCloneMacro(Self);
35 
37 
38  ~VigraRandomForestClassifier() override;
39 
40  void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override;
41  void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y);
42  Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override;
43  Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X);
44 
45 
46  bool SupportsPointWiseWeight() override;
47  bool SupportsPointWiseProbability() override;
48  void ConvertParameter();
49 
50  void SetRandomForest(const vigra::RandomForest<int> & rf);
51  const vigra::RandomForest<int> & GetRandomForest() const;
52 
53  void UsePointWiseWeight(bool) override;
54  void SetMaximumTreeDepth(int);
55  void SetMinimumSplitNodeSize(int);
56  void SetPrecision(double);
57  void SetSamplesPerTree(double);
58  void UseSampleWithReplacement(bool);
59  void SetTreeCount(int);
60  void SetWeightLambda(double);
61 
62  void SetTreeWeights(Eigen::MatrixXd weights);
63  void SetTreeWeight(int treeId, double weight);
64  Eigen::MatrixXd GetTreeWeights() const;
65 
66  void PrintParameter(std::ostream &str = std::cout);
67 
68  private:
69  // *-------------------
70  // * THREADING
71  // *-------------------
72 
73 
74  struct TrainingData;
75  struct PredictionData;
76  struct EigenToVigraTransform;
77  struct Parameter;
78 
79  vigra::MultiArrayView<2, double> m_Probabilities;
80  Eigen::MatrixXd m_TreeWeights;
81 
82  Parameter * m_Parameter;
83  vigra::RandomForest<int> m_RandomForest;
84 
85  static itk::ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *);
86  static itk::ITK_THREAD_RETURN_TYPE PredictCallback(void *);
87  static itk::ITK_THREAD_RETURN_TYPE PredictWeightedCallback(void *);
88  static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P);
89  };
90 }
91 
92 #endif
mitk
DataCollection - Class to facilitate loading/accessing structured data.
Definition: RenderingTests.dox:1
mitk::BaseData
Base of all data objects.
Definition: mitkBaseData.h:42
mitk::AbstractClassifier
Definition: mitkAbstractClassifier.h:32
mitkBaseData.h
mitkClassMacro
#define mitkClassMacro(className, SuperClassName)
Definition: mitkCommon.h:36
mitk::VigraRandomForestClassifier
Definition: mitkVigraRandomForestClassifier.h:26
mitkAbstractClassifier.h