Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
RandomForestTraining.cpp
Go to the documentation of this file.
/*============================================================================

The Medical Imaging Interaction Toolkit (MITK)

Copyright (c) German Cancer Research Center (DKFZ)
All rights reserved.

Use of this source code is governed by a 3-clause BSD license that can be
found in the LICENSE file.

============================================================================*/

#include <mitkCoreObjectFactory.h>
#include "mitkImage.h"
#include <mitkLexicalCast.h>
#include "mitkCommandLineParser.h"
#include <mitkIOUtil.h>
#include <itksys/SystemTools.hxx>
#include <mitkITKImageImport.h>
#include <mitkImageCast.h>
#include <mitkProperties.h>

// ITK
#include <itkImageRegionIterator.h>

// MITK
#include <mitkIOUtil.h>

// Classification
#include <mitkCLUtil.h>
#include <mitkVigraRandomForestClassifier.h>

// Qt
#include <QDir>
#include <QString>
#include <QStringList>

// STL
#include <cstdlib>
#include <map>
#include <string>


using namespace mitk;

43 int main(int argc, char* argv[])
44 {
45  mitkCommandLineParser parser;
46  parser.setArgumentPrefix("--", "-");
47 
48  // required params
49  parser.addArgument("inputdir", "i", mitkCommandLineParser::Directory, "Input Directory", "Contains input feature files.", us::Any(), false, false, false, mitkCommandLineParser::Input);
50  parser.addArgument("outputdir", "o", mitkCommandLineParser::Directory, "Output Directory", "Destination of output files.", us::Any(), false, false, false, mitkCommandLineParser::Output);
51  parser.addArgument("classmask", "m", mitkCommandLineParser::File, "Class mask image", "Contains several classes.", us::Any(), false, false, false, mitkCommandLineParser::Input);
52 
53  // optional params
54  parser.addArgument("select", "s", mitkCommandLineParser::String, "Item selection", "Using Regular expression, seperated by space e.g.: '*.nrrd *.vtk *test*'",std::string("*.nrrd"),true);
55  parser.addArgument("treecount", "tc", mitkCommandLineParser::Int, "Treecount", "Number of trees.",50,true);
56  parser.addArgument("treedepth", "td", mitkCommandLineParser::Int, "Treedepth", "Maximal tree depth.",50,true);
57  parser.addArgument("minsplitnodesize", "min", mitkCommandLineParser::Int, "Minimum split node size.", "Minimum split node size.",2,true);
58  parser.addArgument("precision", "p", mitkCommandLineParser::Float, "Split precision.", "Precision.", mitk::eps,true);
59  parser.addArgument("fraction", "f", mitkCommandLineParser::Float, "Fraction of samples per tree.", "Fraction of samples per tree.", 0.6f,true);
60  parser.addArgument("replacment", "r", mitkCommandLineParser::Bool, "Sample with replacement.", "Sample with replacement.", true,true);
61 
62  // Miniapp Infos
63  parser.setCategory("Classification Tools");
64  parser.setTitle("Random Forest Training");
65  parser.setDescription("Vigra RF impl.");
66  parser.setContributor("German Cancer Research Center (DKFZ)");
67 
68  // Params parsing
69  std::map<std::string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
70  if (parsedArgs.size()==0)
71  return EXIT_FAILURE;
72 
73  std::string inputdir = us::any_cast<std::string>(parsedArgs["inputdir"]);
74  std::string outputdir = us::any_cast<std::string>(parsedArgs["outputdir"]);
75  std::string classmask = us::any_cast<std::string>(parsedArgs["classmask"]);
76 
77  int treecount = parsedArgs.count("treecount") ? us::any_cast<int>(parsedArgs["treecount"]) : 50;
78  int treedepth = parsedArgs.count("treedepth") ? us::any_cast<int>(parsedArgs["treedepth"]) : 50;
79  int minsplitnodesize = parsedArgs.count("minsplitnodesize") ? us::any_cast<int>(parsedArgs["minsplitnodesize"]) : 2;
80  float precision = parsedArgs.count("precision") ? us::any_cast<float>(parsedArgs["precision"]) : mitk::eps;
81  float fraction = parsedArgs.count("fraction") ? us::any_cast<float>(parsedArgs["fraction"]) : 0.6;
82  bool withreplacement = parsedArgs.count("replacment") ? us::any_cast<float>(parsedArgs["replacment"]) : true;
83  std::string filt_select =/* parsedArgs.count("select") ? us::any_cast<std::string>(parsedArgs["select"]) :*/ "*.nrrd";
84 
85  QString filter(filt_select.c_str());
86 
87  // **** in principle repeat this block to create a feature matrix X_all for all subjects (in dir)
88  // Get nrrd filepath
89  QDir dir(inputdir.c_str());
90  auto strl = dir.entryList(filter.split(" "),QDir::Files);
91 
92  // load class mask
93  mitk::Image::Pointer mask = mitk::IOUtil::Load<mitk::Image>(classmask);
94  unsigned int num_samples = 0;
95  mitk::CLUtil::CountVoxel(mask,num_samples);
96 
97  // initialize featurematrix [num_samples, num_featureimages]
98  Eigen::MatrixXd X(num_samples, strl.size());
99 
100  for(int i = 0 ; i < strl.size(); i++)
101  {
102  // load feature image
103  mitk::Image::Pointer img = mitk::IOUtil::Load<mitk::Image>(inputdir + strl[i].toStdString());
104  // transfom it into a [num_samples, 1] vector depending on the classmask
105  Eigen::MatrixXd _x = mitk::CLUtil::Transform<double>(img,mask);
106  // replace i-th (empty) col with feature vector in _x
107  X.block(0,i,num_samples,1) = _x;
108  }
109  // ****
110 
111  // transform classmask into the label-vector [num_samples, 1]
112  Eigen::MatrixXi Y = mitk::CLUtil::Transform<int>(mask,mask);
113 
115  classifier->SetTreeCount(treecount);
116  classifier->SetMaximumTreeDepth(treedepth);
117  classifier->SetMinimumSplitNodeSize(minsplitnodesize);
118  classifier->SetPrecision(precision);
119  classifier->SetSamplesPerTree(fraction);
120  classifier->UseSampleWithReplacement(withreplacement);
121 
122  classifier->PrintParameter();
123  classifier->Train(X,Y);
124 
125  MITK_INFO << classifier->IsEmpty();
126 
127  // no metainformations are saved currently
128  // only the raw vigra rf data
129  mitk::IOUtil::Save(classifier, outputdir + "RandomForest.hdf5");
130 
131  Eigen::MatrixXi Y_pred = classifier->Predict(X);
132  Eigen::MatrixXd Probs = classifier->GetPointWiseProbabilities();
133 
134  MITK_INFO << Y_pred.rows() << " " << Y_pred.cols();
135  MITK_INFO << Probs.rows() << " " << Probs.cols();
136 
137  // mitk::Image::Pointer prediction = mitk::CLUtil::Transform<int>(Y_pred,mask);
138  mitk::Image::Pointer probs_1 = mitk::CLUtil::Transform<double>(Probs.col(0),mask);
139  mitk::Image::Pointer probs_2 = mitk::CLUtil::Transform<double>(Probs.col(1),mask);
140  mitk::Image::Pointer probs_3 = mitk::CLUtil::Transform<double>(Probs.col(2),mask);
141 
142  mitk::IOUtil::Save(probs_1, outputdir + "probs_1.nrrd");
143  mitk::IOUtil::Save(probs_2, outputdir + "probs_2.nrrd");
144  mitk::IOUtil::Save(probs_3, outputdir + "probs_3.nrrd");
145  // mitk::IOUtil::Save(probs_2, outputdir + "test.nrrd");
146 
147  return EXIT_SUCCESS;
148 }
#define MITK_INFO
Definition: mitkLogMacros.h:18
void setContributor(std::string contributor)
DataCollection - Class to facilitate loading/accessing structured data.
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false, mitkCommandLineParser::Channel channel=mitkCommandLineParser::Channel::None)
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
Definition: usAny.h:163
void setCategory(std::string category)
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
static void Save(const mitk::BaseData *data, const std::string &path, bool setPathProperty=false)
Save a mitk::BaseData instance.
Definition: mitkIOUtil.cpp:774
MITKCORE_EXPORT const ScalarType eps
mitk::Image::Pointer mask
int main(int argc, char *argv[])
void setTitle(std::string title)
static void CountVoxel(mitk::Image::Pointer image, std::map< unsigned int, unsigned int > &map)
CountVoxel.
Definition: mitkCLUtil.cpp:88
void setDescription(std::string description)