Medical Imaging Interaction Toolkit  2023.04.00
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
mitkPURFClassifier.h
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 #ifndef mitkPURFClassifier_h
14 #define mitkPURFClassifier_h
15 
16 #include <MitkCLVigraRandomForestExports.h>
17 #include <mitkAbstractClassifier.h>
18 
19 //#include <vigra/multi_array.hxx>
20 #include <vigra/random_forest.hxx>
21 
22 #include <mitkBaseData.h>
23 
24 namespace mitk
25 {
26  class MITKCLVIGRARANDOMFOREST_EXPORT PURFClassifier : public AbstractClassifier // Exported classifier derived from AbstractClassifier; holds a vigra::RandomForest<int> member.
27  {
28  public:
29  // NOTE(review): this listing jumps from line 29 to 31 and from 35 to 37; the skipped lines presumably hold the class macro and the constructor - confirm against the real header.
31  // ITK macros; presumably these provide the static New() and Clone() members - confirm in the ITK headers.
32  itkFactorylessNewMacro(Self);
33 
34  itkCloneMacro(Self);
35 
37 
38  ~PURFClassifier() override;
39  // Training on Eigen matrices X (double) and Y (int); presumably one row per sample - TODO confirm in the .cpp.
40  void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override;
41  // Prediction entry points. Predict overrides the base class; PredictWeighted is not marked override here.
42  Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override;
43  Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X);
44 
45  // Capability queries overriding the base class, followed by two helpers that are not overrides.
46  bool SupportsPointWiseWeight() override;
47  bool SupportsPointWiseProbability() override;
48  void ConvertParameter(); // NOTE(review): what is converted is not visible in this header - see the .cpp.
49  vigra::ArrayVector<double> CalculateKappa(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in);
50  // Accessors for a vigra::RandomForest<int>; the setter takes and the getter returns a const reference.
51  void SetRandomForest(const vigra::RandomForest<int> & rf);
52  const vigra::RandomForest<int> & GetRandomForest() const;
53  // Configuration setters; presumably stored in the forward-declared Parameter struct - TODO confirm.
54  void UsePointWiseWeight(bool) override;
55  void SetMaximumTreeDepth(int);
56  void SetMinimumSplitNodeSize(int);
57  void SetPrecision(double);
58  void SetSamplesPerTree(double);
59  void UseSampleWithReplacement(bool);
60  void SetTreeCount(int);
61  void SetWeightLambda(double);
62  // Stream defaults to std::cout. NOTE(review): <iostream> is not included directly by this header; it may only arrive transitively.
63  void PrintParameter(std::ostream &str = std::cout);
64  // Class-probability accessors. The getter is spelled "Probabilites" (sic); the name is public API, so it is left unchanged.
65  void SetClassProbabilities(Eigen::VectorXd probabilities);
66  Eigen::VectorXd GetClassProbabilites();
67 
68  private:
69  // *-------------------
70  // * THREADING
71  // *-------------------
72 
73  // Helper types are only forward-declared here; their definitions are not part of this header.
74  struct TrainingData;
75  struct PredictionData;
76  struct EigenToVigraTransform;
77  struct Parameter;
78  // Data members. NOTE(review): m_Probabilities is a view type (MultiArrayView); what storage it refers to is not visible here.
79  vigra::MultiArrayView<2, double> m_Probabilities;
80  Eigen::MatrixXd m_TreeWeights;
81  Eigen::VectorXd m_ClassProbabilities;
82  // NOTE(review): m_Parameter is a raw pointer; ownership and release are not visible in this header - confirm in the constructor/destructor.
83  Parameter * m_Parameter;
84  vigra::RandomForest<int> m_RandomForest;
85  // Static callbacks taking an opaque void* and returning itk::ITK_THREAD_RETURN_TYPE, plus a static weighted-prediction helper.
86  static itk::ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *);
87  static itk::ITK_THREAD_RETURN_TYPE PredictCallback(void *);
88  static itk::ITK_THREAD_RETURN_TYPE PredictWeightedCallback(void *);
89  static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P);
90  };
91 }
92 
93 #endif
mitk::PURFClassifier
Definition: mitkPURFClassifier.h:26
mitk
DataCollection - Class to facilitate loading/accessing structured data.
Definition: RenderingTests.dox:1
mitk::BaseData
Base of all data objects.
Definition: mitkBaseData.h:42
mitk::AbstractClassifier
Definition: mitkAbstractClassifier.h:32
mitkBaseData.h
mitkClassMacro
#define mitkClassMacro(className, SuperClassName)
Definition: mitkCommon.h:36
mitkAbstractClassifier.h