Medical Imaging Interaction Toolkit  2023.04.00
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
mitkVigraRandomForestClassifier.h
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 #ifndef mitkVigraRandomForestClassifier_h
14 #define mitkVigraRandomForestClassifier_h
15 
16 #include <MitkCLVigraRandomForestExports.h>
17 #include <mitkAbstractClassifier.h>
18 
19 //#include <vigra/multi_array.hxx>
20 #include <vigra/random_forest.hxx>
21 
22 #include <mitkBaseData.h>
23 
24 namespace mitk
25 {
26  class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier
27  {
28  public:
29 
31 
32  itkFactorylessNewMacro(Self);
33 
34  itkCloneMacro(Self);
35 
37 
38  ~VigraRandomForestClassifier() override;
39 
40  void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override;
41  void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y);
42  Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override;
43  Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X);
44 
45 
46  bool SupportsPointWiseWeight() override;
47  bool SupportsPointWiseProbability() override;
48  void ConvertParameter();
49 
50  void SetRandomForest(const vigra::RandomForest<int> & rf);
51  const vigra::RandomForest<int> & GetRandomForest() const;
52 
53  void UsePointWiseWeight(bool) override;
54  void SetMaximumTreeDepth(int);
55  void SetMinimumSplitNodeSize(int);
56  void SetPrecision(double);
57  void SetSamplesPerTree(double);
58  void UseSampleWithReplacement(bool);
59  void SetTreeCount(int);
60  void SetWeightLambda(double);
61 
62  void SetTreeWeights(Eigen::MatrixXd weights);
63  void SetTreeWeight(int treeId, double weight);
64  Eigen::MatrixXd GetTreeWeights() const;
65 
66  void PrintParameter(std::ostream &str = std::cout);
67 
68  private:
69  // *-------------------
70  // * THREADING
71  // *-------------------
72 
73 
74  struct TrainingData;
75  struct PredictionData;
76  struct EigenToVigraTransform;
77  struct Parameter;
78 
79  vigra::MultiArrayView<2, double> m_Probabilities;
80  Eigen::MatrixXd m_TreeWeights;
81 
82  Parameter * m_Parameter;
83  vigra::RandomForest<int> m_RandomForest;
84 
85  static itk::ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *);
86  static itk::ITK_THREAD_RETURN_TYPE PredictCallback(void *);
87  static itk::ITK_THREAD_RETURN_TYPE PredictWeightedCallback(void *);
88  static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P);
89  };
90 }
91 
92 #endif
mitk
DataCollection - Class to facilitate loading/accessing structured data.
Definition: RenderingTests.dox:1
mitk::BaseData
Base of all data objects.
Definition: mitkBaseData.h:42
mitk::AbstractClassifier
Definition: mitkAbstractClassifier.h:32
mitkBaseData.h
mitkClassMacro
#define mitkClassMacro(className, SuperClassName)
Definition: mitkCommon.h:36
mitk::VigraRandomForestClassifier
Definition: mitkVigraRandomForestClassifier.h:26
mitkAbstractClassifier.h