Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
mitkVigraRandomForestClassifier.h
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 #ifndef mitkVigraRandomForestClassifier_h
18 #define mitkVigraRandomForestClassifier_h
19 
20 #include <MitkCLVigraRandomForestExports.h>
21 #include <mitkAbstractClassifier.h>
22 
23 //#include <vigra/multi_array.hxx>
24 #include <vigra/random_forest.hxx>
25 
26 #include <mitkBaseData.h>
27 
28 namespace mitk
29 {
30  class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier
31  {
32  public:
33 
35  itkFactorylessNewMacro(Self)
36  itkCloneMacro(Self)
37 
39 
41 
42  void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y);
43  void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y);
44  Eigen::MatrixXi Predict(const Eigen::MatrixXd &X);
45  Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X);
46 
47 
48  bool SupportsPointWiseWeight();
49  bool SupportsPointWiseProbability();
50  void ConvertParameter();
51 
52  void SetRandomForest(const vigra::RandomForest<int> & rf);
53  const vigra::RandomForest<int> & GetRandomForest() const;
54 
55  void UsePointWiseWeight(bool);
56  void SetMaximumTreeDepth(int);
57  void SetMinimumSplitNodeSize(int);
58  void SetPrecision(double);
59  void SetSamplesPerTree(double);
60  void UseSampleWithReplacement(bool);
61  void SetTreeCount(int);
62  void SetWeightLambda(double);
63 
64  void SetTreeWeights(Eigen::MatrixXd weights);
65  void SetTreeWeight(int treeId, double weight);
66  Eigen::MatrixXd GetTreeWeights() const;
67 
68  void PrintParameter(std::ostream &str = std::cout);
69 
70  private:
71  // *-------------------
72  // * THREADING
73  // *-------------------
74 
75 
76  struct TrainingData;
77  struct PredictionData;
78  struct EigenToVigraTransform;
79  struct Parameter;
80 
81  Eigen::MatrixXd m_TreeWeights;
82 
83  Parameter * m_Parameter;
84  vigra::RandomForest<int> m_RandomForest;
85 
86  static ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *);
87  static ITK_THREAD_RETURN_TYPE PredictCallback(void *);
88  static ITK_THREAD_RETURN_TYPE PredictWeightedCallback(void *);
89  static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P);
90  };
91 }
92 
93 #endif //mitkVigraRandomForestClassifier_h
Base of all data objects.
Definition: mitkBaseData.h:39
DataCollection - Class to facilitate loading/accessing structured data.
#define mitkClassMacro(className, SuperClassName)
Definition: mitkCommon.h:44