Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
RandomForestTraining.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 #include <mitkCoreObjectFactory.h>
18 #include "mitkImage.h"
19 #include <boost/lexical_cast.hpp>
20 #include "mitkCommandLineParser.h"
21 #include <mitkIOUtil.h>
22 #include <itksys/SystemTools.hxx>
23 #include <mitkITKImageImport.h>
24 #include <mitkImageCast.h>
25 #include <mitkProperties.h>
26 
27 // ITK
28 #include <itkImageRegionIterator.h>
29 
30 // MITK
31 #include <mitkIOUtil.h>
32 
33 // Classification
34 #include <mitkCLUtil.h>
35 #include <mitkVigraRandomForestClassifier.h>
36 
37 #include <QDir>
38 #include <QString>
39 #include <QStringList>
40 
41 
42 using namespace mitk;
43 
47 int main(int argc, char* argv[])
48 {
49  mitkCommandLineParser parser;
50  parser.setArgumentPrefix("--", "-");
51 
52  // required params
53  parser.addArgument("inputdir", "i", mitkCommandLineParser::InputDirectory, "Input Directory", "Contains input feature files.", us::Any(), false);
54  parser.addArgument("outputdir", "o", mitkCommandLineParser::OutputDirectory, "Output Directory", "Destination of output files.", us::Any(), false);
55  parser.addArgument("classmask", "m", mitkCommandLineParser::InputFile, "Class mask image", "Contains several classes.", us::Any(), false);
56 
57  // optional params
58  parser.addArgument("select", "s", mitkCommandLineParser::String, "Item selection", "Using Regular expression, seperated by space e.g.: '*.nrrd *.vtk *test*'",std::string("*.nrrd"),true);
59  parser.addArgument("treecount", "tc", mitkCommandLineParser::Int, "Treecount", "Number of trees.",50,true);
60  parser.addArgument("treedepth", "td", mitkCommandLineParser::Int, "Treedepth", "Maximal tree depth.",50,true);
61  parser.addArgument("minsplitnodesize", "min", mitkCommandLineParser::Int, "Minimum split node size.", "Minimum split node size.",2,true);
62  parser.addArgument("precision", "p", mitkCommandLineParser::Float, "Split precision.", "Precision.", mitk::eps,true);
63  parser.addArgument("fraction", "f", mitkCommandLineParser::Float, "Fraction of samples per tree.", "Fraction of samples per tree.", 0.6f,true);
64  parser.addArgument("replacment", "r", mitkCommandLineParser::Bool, "Sample with replacement.", "Sample with replacement.", true,true);
65 
66  // Miniapp Infos
67  parser.setCategory("Classification Tools");
68  parser.setTitle("Random Forest Training");
69  parser.setDescription("Vigra RF impl.");
70  parser.setContributor("MBI");
71 
72  // Params parsing
73  std::map<std::string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
74  if (parsedArgs.size()==0)
75  return EXIT_FAILURE;
76 
77  std::string inputdir = us::any_cast<std::string>(parsedArgs["inputdir"]);
78  std::string outputdir = us::any_cast<std::string>(parsedArgs["outputdir"]);
79  std::string classmask = us::any_cast<std::string>(parsedArgs["classmask"]);
80 
81  int treecount = parsedArgs.count("treecount") ? us::any_cast<int>(parsedArgs["treecount"]) : 50;
82  int treedepth = parsedArgs.count("treedepth") ? us::any_cast<int>(parsedArgs["treedepth"]) : 50;
83  int minsplitnodesize = parsedArgs.count("minsplitnodesize") ? us::any_cast<int>(parsedArgs["minsplitnodesize"]) : 2;
84  float precision = parsedArgs.count("precision") ? us::any_cast<float>(parsedArgs["precision"]) : mitk::eps;
85  float fraction = parsedArgs.count("fraction") ? us::any_cast<float>(parsedArgs["fraction"]) : 0.6;
86  bool withreplacement = parsedArgs.count("replacment") ? us::any_cast<float>(parsedArgs["replacment"]) : true;
87  std::string filt_select =/* parsedArgs.count("select") ? us::any_cast<std::string>(parsedArgs["select"]) :*/ "*.nrrd";
88 
89  QString filter(filt_select.c_str());
90 
91  // **** in principle repeat this block to create a feature matrix X_all for all subjects (in dir)
92  // Get nrrd filepath
93  QDir dir(inputdir.c_str());
94  auto strl = dir.entryList(filter.split(" "),QDir::Files);
95 
96  // load class mask
98  unsigned int num_samples = 0;
99  mitk::CLUtil::CountVoxel(mask,num_samples);
100 
101  // initialize featurematrix [num_samples, num_featureimages]
102  Eigen::MatrixXd X(num_samples, strl.size());
103 
104  for(int i = 0 ; i < strl.size(); i++)
105  {
106  // load feature image
107  mitk::Image::Pointer img = mitk::IOUtil::LoadImage(inputdir + strl[i].toStdString());
108  // transfom it into a [num_samples, 1] vector depending on the classmask
109  Eigen::MatrixXd _x = mitk::CLUtil::Transform<double>(img,mask);
110  // replace i-th (empty) col with feature vector in _x
111  X.block(0,i,num_samples,1) = _x;
112  }
113  // ****
114 
115  // transform classmask into the label-vector [num_samples, 1]
116  Eigen::MatrixXi Y = mitk::CLUtil::Transform<int>(mask,mask);
117 
119  classifier->SetTreeCount(treecount);
120  classifier->SetMaximumTreeDepth(treedepth);
121  classifier->SetMinimumSplitNodeSize(minsplitnodesize);
122  classifier->SetPrecision(precision);
123  classifier->SetSamplesPerTree(fraction);
124  classifier->UseSampleWithReplacement(withreplacement);
125 
126  classifier->PrintParameter();
127  classifier->Train(X,Y);
128 
129  MITK_INFO << classifier->IsEmpty();
130 
131  // no metainformations are saved currently
132  // only the raw vigra rf data
133  mitk::IOUtil::Save(classifier, outputdir + "RandomForest.hdf5");
134 
135  Eigen::MatrixXi Y_pred = classifier->Predict(X);
136  Eigen::MatrixXd Probs = classifier->GetPointWiseProbabilities();
137 
138  MITK_INFO << Y_pred.rows() << " " << Y_pred.cols();
139  MITK_INFO << Probs.rows() << " " << Probs.cols();
140 
141  // mitk::Image::Pointer prediction = mitk::CLUtil::Transform<int>(Y_pred,mask);
142  mitk::Image::Pointer probs_1 = mitk::CLUtil::Transform<double>(Probs.col(0),mask);
143  mitk::Image::Pointer probs_2 = mitk::CLUtil::Transform<double>(Probs.col(1),mask);
144  mitk::Image::Pointer probs_3 = mitk::CLUtil::Transform<double>(Probs.col(2),mask);
145 
146  mitk::IOUtil::Save(probs_1, outputdir + "probs_1.nrrd");
147  mitk::IOUtil::Save(probs_2, outputdir + "probs_2.nrrd");
148  mitk::IOUtil::Save(probs_3, outputdir + "probs_3.nrrd");
149  // mitk::IOUtil::Save(probs_2, outputdir + "test.nrrd");
150 
151  return EXIT_SUCCESS;
152 }
static void Save(const mitk::BaseData *data, const std::string &path)
Save a mitk::BaseData instance.
Definition: mitkIOUtil.cpp:824
#define MITK_INFO
Definition: mitkLogMacros.h:22
void setContributor(std::string contributor)
DataCollection - Class to facilitate loading/accessing structured data.
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false)
Definition: usAny.h:163
void setCategory(std::string category)
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
MITKCORE_EXPORT const ScalarType eps
int main(int argc, char *argv[])
void setTitle(std::string title)
static void CountVoxel(mitk::Image::Pointer image, std::map< unsigned int, unsigned int > &map)
CountVoxel.
Definition: mitkCLUtil.cpp:69
void setDescription(std::string description)
static mitk::Image::Pointer LoadImage(const std::string &path)
LoadImage Convenience method to load an arbitrary mitkImage.
Definition: mitkIOUtil.cpp:597
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.