Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
mitkVigraRandomForestClassifier.h
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 #ifndef mitkVigraRandomForestClassifier_h
18 #define mitkVigraRandomForestClassifier_h
19 
20 #include <MitkCLVigraRandomForestExports.h>
21 #include <mitkAbstractClassifier.h>
22 
23 //#include <vigra/multi_array.hxx>
24 #include <vigra/random_forest.hxx>
25 
26 #include <mitkBaseData.h>
27 
28 namespace mitk
29 {
30  class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier // Classifier that wraps a vigra::RandomForest<int> (m_RandomForest) behind the mitk::AbstractClassifier interface.
31  {
32  public:
33 // NOTE(review): this Doxygen listing drops source line 34 (hyperlinked; presumably the mitkClassMacro(...) invocation whose definition is referenced below the listing) — confirm against the real header.
35  itkFactorylessNewMacro(Self) // ITK-style New() factory; `Self` is not typedef'd in any visible line, presumably it comes from the class macro on the elided line 34.
36  itkCloneMacro(Self) // ITK-style Clone() support.
37 // NOTE(review): source lines 38 and 40 are also elided from this listing (presumably the constructor and destructor, which would manage m_Parameter) — confirm.
39 
41 // --- Training / prediction: features are double-valued (MatrixXd), labels are int (MatrixXi); presumably one row per sample — confirm. ---
42  void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); // Train the wrapped forest on features X with integer labels Y.
43  void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); // Same inputs as Train; presumably updates the existing forest incrementally — confirm semantics in the .cpp.
44  Eigen::MatrixXi Predict(const Eigen::MatrixXd &X); // Returns an integer label matrix for the samples in X.
45  Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X); // Like Predict, but presumably combines trees using the per-tree weights in m_TreeWeights — confirm.
46 // --- Capability queries and parameter handling ---
47 
48  bool SupportsPointWiseWeight(); // Capability flag (non-const; signature presumably dictated by AbstractClassifier).
49  bool SupportsPointWiseProbability(); // Capability flag (non-const; signature presumably dictated by AbstractClassifier).
50  void ConvertParameter(); // Presumably copies generic classifier parameters into *m_Parameter — confirm in the .cpp.
51 
52  void SetRandomForest(const vigra::RandomForest<int> & rf); // Replace the wrapped forest; presumably copies rf into m_RandomForest.
53  const vigra::RandomForest<int> & GetRandomForest() const; // Read-only access; returns a reference (presumably to m_RandomForest), so it must not outlive this object.
54 // --- Hyperparameter setters (values presumably stored in *m_Parameter — confirm) ---
55  void UsePointWiseWeight(bool);
56  void SetMaximumTreeDepth(int);
57  void SetMinimumSplitNodeSize(int);
58  void SetPrecision(double);
59  void SetSamplesPerTree(double); // Takes a double; presumably a fraction of the training set rather than a count — confirm.
60  void UseSampleWithReplacement(bool);
61  void SetTreeCount(int);
62  void SetWeightLambda(double);
63 // --- Per-tree weights (presumably backed by m_TreeWeights) ---
64  void SetTreeWeights(Eigen::MatrixXd weights); // NOTE(review): matrix is taken by value, i.e. copied on every call.
65  void SetTreeWeight(int treeId, double weight); // Set the weight of one tree; no range check on treeId is visible in this header.
66  Eigen::MatrixXd GetTreeWeights() const; // Returns the weights by value (a copy).
67 
68  void PrintParameter(std::ostream &str = std::cout); // Writes the parameters to str (std::cout by default). NOTE(review): <iostream> is not included directly here; std::cout relies on a transitive include.
69 
70  private:
71  // *-------------------
72  // * THREADING
73  // *-------------------
74 // Forward declarations of helper structs; their definitions are not in this header (presumably in the .cpp).
75 
76  struct TrainingData;
77  struct PredictionData;
78  struct EigenToVigraTransform;
79  struct Parameter;
80 
81  Eigen::MatrixXd m_TreeWeights; // Per-tree weights (see SetTreeWeights / GetTreeWeights).
82 
83  Parameter * m_Parameter; // NOTE(review): raw pointer to an incomplete type; allocation/cleanup and behavior under itkCloneMacro must live in the elided ctor/dtor — confirm, or consider std::unique_ptr.
84  vigra::RandomForest<int> m_RandomForest; // The wrapped vigra forest.
85 
86  static ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *); // Static thread entry point with an opaque void* argument (presumably an ITK multithreader callback receiving TrainingData).
87  static ITK_THREAD_RETURN_TYPE PredictCallback(void *); // Static thread entry point (presumably receives PredictionData).
88  static ITK_THREAD_RETURN_TYPE PredictWeightedCallback(void *); // Static thread entry point for weighted prediction (presumably receives PredictionData).
89  static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P); // Weighted prediction helper on 2-D vigra views: X features (double), Y labels (int), P presumably per-class probabilities (double).
90  };
91 } // namespace mitk
92 
93 #endif //mitkVigraRandomForestClassifier_h
Base of all data objects.
Definition: mitkBaseData.h:39
DataCollection - Class to facilitate loading/accessing structured data.
#define mitkClassMacro(className, SuperClassName)
Definition: mitkCommon.h:44