Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
RandomForestTraining.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
#include <mitkCoreObjectFactory.h>
#include "mitkImage.h"
#include <boost/lexical_cast.hpp>
#include "mitkCommandLineParser.h"
#include <mitkIOUtil.h>
#include <itksys/SystemTools.hxx>
#include <mitkITKImageImport.h>
#include <mitkImageCast.h>
#include <mitkProperties.h>

// ITK
#include <itkImageRegionIterator.h>

// MITK
#include <mitkIOUtil.h>

// Classification
#include <mitkCLUtil.h>
#include <mitkVigraRandomForestClassifier.h>

#include <QDir>
#include <QString>
#include <QStringList>


using namespace mitk;
43 
47 int main(int argc, char* argv[])
48 {
49  mitkCommandLineParser parser;
50  parser.setArgumentPrefix("--", "-");
51 
52  // required params
53  parser.addArgument("inputdir", "i", mitkCommandLineParser::InputDirectory, "Input Directory", "Contains input feature files.", us::Any(), false);
54  parser.addArgument("outputdir", "o", mitkCommandLineParser::OutputDirectory, "Output Directory", "Destination of output files.", us::Any(), false);
55  parser.addArgument("classmask", "m", mitkCommandLineParser::InputFile, "Class mask image", "Contains several classes.", us::Any(), false);
56 
57  // optional params
58  parser.addArgument("select", "s", mitkCommandLineParser::String, "Item selection", "Using Regular expression, seperated by space e.g.: '*.nrrd *.vtk *test*'",std::string("*.nrrd"),true);
59  parser.addArgument("treecount", "tc", mitkCommandLineParser::Int, "Treecount", "Number of trees.",50,true);
60  parser.addArgument("treedepth", "td", mitkCommandLineParser::Int, "Treedepth", "Maximal tree depth.",50,true);
61  parser.addArgument("minsplitnodesize", "min", mitkCommandLineParser::Int, "Minimum split node size.", "Minimum split node size.",2,true);
62  parser.addArgument("precision", "p", mitkCommandLineParser::Float, "Split precision.", "Precision.", mitk::eps,true);
63  parser.addArgument("fraction", "f", mitkCommandLineParser::Float, "Fraction of samples per tree.", "Fraction of samples per tree.", 0.6f,true);
64  parser.addArgument("replacment", "r", mitkCommandLineParser::Bool, "Sample with replacement.", "Sample with replacement.", true,true);
65 
66  // Miniapp Infos
67  parser.setCategory("Classification Tools");
68  parser.setTitle("Random Forest Training");
69  parser.setDescription("Vigra RF impl.");
70  parser.setContributor("MBI");
71 
72  // Params parsing
73  std::map<std::string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
74  if (parsedArgs.size()==0)
75  return EXIT_FAILURE;
76 
77  std::string inputdir = us::any_cast<std::string>(parsedArgs["inputdir"]);
78  std::string outputdir = us::any_cast<std::string>(parsedArgs["outputdir"]);
79  std::string classmask = us::any_cast<std::string>(parsedArgs["classmask"]);
80 
81  int treecount = parsedArgs.count("treecount") ? us::any_cast<int>(parsedArgs["treecount"]) : 50;
82  int treedepth = parsedArgs.count("treedepth") ? us::any_cast<int>(parsedArgs["treedepth"]) : 50;
83  int minsplitnodesize = parsedArgs.count("minsplitnodesize") ? us::any_cast<int>(parsedArgs["minsplitnodesize"]) : 2;
84  float precision = parsedArgs.count("precision") ? us::any_cast<float>(parsedArgs["precision"]) : mitk::eps;
85  float fraction = parsedArgs.count("fraction") ? us::any_cast<float>(parsedArgs["fraction"]) : 0.6;
86  bool withreplacement = parsedArgs.count("replacment") ? us::any_cast<float>(parsedArgs["replacment"]) : true;
87  std::string filt_select =/* parsedArgs.count("select") ? us::any_cast<std::string>(parsedArgs["select"]) :*/ "*.nrrd";
88 
89  QString filter(filt_select.c_str());
90 
91  // **** in principle repeat this block to create a feature matrix X_all for all subjects (in dir)
92  // Get nrrd filepath
93  QDir dir(inputdir.c_str());
94  auto strl = dir.entryList(filter.split(" "),QDir::Files);
95 
96  // load class mask
98  unsigned int num_samples = 0;
99  mitk::CLUtil::CountVoxel(mask,num_samples);
100 
101  // initialize featurematrix [num_samples, num_featureimages]
102  Eigen::MatrixXd X(num_samples, strl.size());
103 
104  for(int i = 0 ; i < strl.size(); i++)
105  {
106  // load feature image
107  mitk::Image::Pointer img = mitk::IOUtil::LoadImage(inputdir + strl[i].toStdString());
108  // transfom it into a [num_samples, 1] vector depending on the classmask
109  Eigen::MatrixXd _x = mitk::CLUtil::Transform<double>(img,mask);
110  // replace i-th (empty) col with feature vector in _x
111  X.block(0,i,num_samples,1) = _x;
112  }
113  // ****
114 
115  // transform classmask into the label-vector [num_samples, 1]
116  Eigen::MatrixXi Y = mitk::CLUtil::Transform<int>(mask,mask);
117 
119  classifier->SetTreeCount(treecount);
120  classifier->SetMaximumTreeDepth(treedepth);
121  classifier->SetMinimumSplitNodeSize(minsplitnodesize);
122  classifier->SetPrecision(precision);
123  classifier->SetSamplesPerTree(fraction);
124  classifier->UseSampleWithReplacement(withreplacement);
125 
126  classifier->PrintParameter();
127  classifier->Train(X,Y);
128 
129  MITK_INFO << classifier->IsEmpty();
130 
131  // no metainformations are saved currently
132  // only the raw vigra rf data
133  mitk::IOUtil::Save(classifier, outputdir + "RandomForest.hdf5");
134 
135  Eigen::MatrixXi Y_pred = classifier->Predict(X);
136  Eigen::MatrixXd Probs = classifier->GetPointWiseProbabilities();
137 
138  MITK_INFO << Y_pred.rows() << " " << Y_pred.cols();
139  MITK_INFO << Probs.rows() << " " << Probs.cols();
140 
141  // mitk::Image::Pointer prediction = mitk::CLUtil::Transform<int>(Y_pred,mask);
142  mitk::Image::Pointer probs_1 = mitk::CLUtil::Transform<double>(Probs.col(0),mask);
143  mitk::Image::Pointer probs_2 = mitk::CLUtil::Transform<double>(Probs.col(1),mask);
144  mitk::Image::Pointer probs_3 = mitk::CLUtil::Transform<double>(Probs.col(2),mask);
145 
146  mitk::IOUtil::Save(probs_1, outputdir + "probs_1.nrrd");
147  mitk::IOUtil::Save(probs_2, outputdir + "probs_2.nrrd");
148  mitk::IOUtil::Save(probs_3, outputdir + "probs_3.nrrd");
149  // mitk::IOUtil::Save(probs_2, outputdir + "test.nrrd");
150 
151  return EXIT_SUCCESS;
152 }
static void Save(const mitk::BaseData *data, const std::string &path)
Save a mitk::BaseData instance.
Definition: mitkIOUtil.cpp:824
#define MITK_INFO
Definition: mitkLogMacros.h:22
void setContributor(std::string contributor)
DataCollection - Class to facilitate loading/accessing structured data.
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false)
Definition: usAny.h:163
void setCategory(std::string category)
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
MITKCORE_EXPORT const ScalarType eps
int main(int argc, char *argv[])
void setTitle(std::string title)
static void CountVoxel(mitk::Image::Pointer image, std::map< unsigned int, unsigned int > &map)
CountVoxel.
Definition: mitkCLUtil.cpp:69
void setDescription(std::string description)
static mitk::Image::Pointer LoadImage(const std::string &path)
LoadImage Convenience method to load an arbitrary mitkImage.
Definition: mitkIOUtil.cpp:597
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.