Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
CLSimpleVoxelClassification.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include <mitkConfigFileReader.h>
#include <mitkDataCollection.h>
#include <mitkDataCollectionUtilities.h>
#include <mitkCollectionReader.h>
#include <mitkCollectionWriter.h>
#include <mitkCostingStatistic.h>
#include <vtkSmartPointer.h>
#include <mitkIOUtil.h>

#include <mitkRandomForestIO.h>
#include <mitkVigraRandomForestClassifier.h>

// CTK
#include "mitkCommandLineParser.h"
35 
36 
37 
38 
39 int main(int argc, char* argv[])
40 {
41  // Setup CLI Module parsable interface
42  mitkCommandLineParser parser;
43  parser.setTitle("Simple Random Forest Classifier");
44  parser.setCategory("Classification");
45  parser.setDescription("Learns and predicts classes");
46  parser.setContributor("MBI");
47 
48  parser.setArgumentPrefix("--", "-");
49  // Add command line argument names
50  parser.addArgument("help", "h", mitkCommandLineParser::Bool, "Show options");
51  parser.addArgument("loadFile", "l", mitkCommandLineParser::InputFile,
52  "DataCollection File");
53  parser.addArgument(
54  "colIds", "c", mitkCommandLineParser::String,
55  "Patient Identifiers from DataCollection used for training");
56  parser.addArgument("testId", "t", mitkCommandLineParser::String,
57  "Patient Identifier from DataCollection used for testing");
58  parser.addArgument("features", "b", mitkCommandLineParser::String,
59  "Features");
60  parser.addArgument("stats", "s", mitkCommandLineParser::String,
61  "Output file for stats");
62  parser.addArgument("treeDepth", "d", mitkCommandLineParser::Int,
63  "limits tree depth");
64  parser.addArgument("forestSize", "f", mitkCommandLineParser::Int,
65  "number of trees");
66  parser.addArgument("configName", "n", mitkCommandLineParser::String,
67  "human readable name for configuration");
69  "output folder for results");
70 
71  parser.addArgument("classmap", "m", mitkCommandLineParser::String,
72  "name of class that is to be learnt");
73 
74 
75  std::map<std::string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
76  // Show a help message
77  if (parsedArgs.size()==0 || parsedArgs.count("help") || parsedArgs.count("h")) {
78  std::cout << parser.helpText();
79  return EXIT_SUCCESS;
80  }
81 
82  // Default values
83  bool useStatsFile = false;
84  unsigned int forestSize = 8;
85  unsigned int treeDepth = 10;
86  std::string configName = "";
87  std::string outputFolder = "";
88 
89  std::vector<std::string> features;
90  std::vector<std::string> trainingIds;
91  std::vector<std::string> testingIds;
92  std::vector<std::string> loadIds; // features + masks needed for training and evaluation
93  std::string outputFile;
94  std::string classMap;
95  std::string xmlFile;
96  std::ofstream experimentFS;
97 
98  // Parse input parameters
99  {
100  if (parsedArgs.count("colIds") || parsedArgs.count("c")) {
101  std::istringstream ss(us::any_cast<std::string>(parsedArgs["colIds"]));
102  std::string token;
103 
104  while (std::getline(ss, token, ','))
105  trainingIds.push_back(token);
106  }
107 
108  if (parsedArgs.count("output") || parsedArgs.count("o")) {
109  outputFolder = us::any_cast<std::string>(parsedArgs["output"]);
110  }
111 
112  if (parsedArgs.count("classmap") || parsedArgs.count("m")) {
113  classMap = us::any_cast<std::string>(parsedArgs["classmap"]);
114  }
115 
116  if (parsedArgs.count("configName") || parsedArgs.count("n")) {
117  configName = us::any_cast<std::string>(parsedArgs["configName"]);
118  }
119 
120  if (parsedArgs.count("features") || parsedArgs.count("b")) {
121  std::istringstream ss(us::any_cast<std::string>(parsedArgs["features"]));
122  std::string token;
123 
124  while (std::getline(ss, token, ','))
125  features.push_back(token);
126  }
127 
128  if (parsedArgs.count("treeDepth") || parsedArgs.count("d")) {
129  treeDepth = us::any_cast<int>(parsedArgs["treeDepth"]);
130  }
131 
132 
133  if (parsedArgs.count("forestSize") || parsedArgs.count("f")) {
134  forestSize = us::any_cast<int>(parsedArgs["forestSize"]);
135  }
136 
137  if (parsedArgs.count("stats") || parsedArgs.count("s")) {
138  useStatsFile = true;
139  experimentFS.open(us::any_cast<std::string>(parsedArgs["stats"]).c_str(),
140  std::ios_base::app);
141  }
142 
143 
144  if (parsedArgs.count("testId") || parsedArgs.count("t")) {
145  std::istringstream ss(us::any_cast<std::string>(parsedArgs["testId"]));
146  std::string token;
147 
148  while (std::getline(ss, token, ','))
149  testingIds.push_back(token);
150  }
151 
152  for (unsigned int i = 0; i < features.size(); i++) {
153  loadIds.push_back(features.at(i));
154  }
155  loadIds.push_back(classMap);
156 
157  if (parsedArgs.count("stats") || parsedArgs.count("s")) {
158  outputFile = us::any_cast<std::string>(parsedArgs["stats"]);
159  }
160 
161  if (parsedArgs.count("loadFile") || parsedArgs.count("l")) {
162  xmlFile = us::any_cast<std::string>(parsedArgs["loadFile"]);
163  } else {
164  MITK_ERROR << parser.helpText();
165  return EXIT_FAILURE;
166  }
167  }
168 
169  mitk::DataCollection::Pointer trainCollection;
170  mitk::DataCollection::Pointer testCollection;
171  {
172  mitk::CollectionReader colReader;
173  // Load only relevant images
174  colReader.SetDataItemNames(loadIds);
175  colReader.AddSubColIds(testingIds);
176  testCollection = colReader.LoadCollection(xmlFile);
177  colReader.ClearDataElementIds();
178  colReader.ClearSubColIds();
179  colReader.SetDataItemNames(loadIds);
180  colReader.AddSubColIds(trainingIds);
181  trainCollection = colReader.LoadCollection(xmlFile);
182  }
183 
185  // If required do Training....
187  //mitk::DecisionForest forest;
188 
190 
191  forest->SetTreeCount(forestSize);
192  forest->SetMaximumTreeDepth(treeDepth);
193 
194 
195  // create feature matrix
196  auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, features, classMap);
197  // create label matrix
198  auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, classMap, classMap);
199 
200  forest->Train(trainDataX, trainDataY);
201 
202 
203  // prepare feature matrix for test case
204  auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,features, classMap);
205  auto testDataNewY = forest->Predict(testDataX);
206 
207 
208  mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, "RESULT", classMap);
209 
210  Eigen::MatrixXd Probs = forest->GetPointWiseProbabilities();
211 
212 
213  Eigen::MatrixXd prob0 = Probs.col(0);
214  Eigen::MatrixXd prob1 = Probs.col(1);
215 
216  mitk::DCUtilities::MatrixToDC3d(prob0, testCollection,"prob0", classMap);
217  mitk::DCUtilities::MatrixToDC3d(prob1, testCollection,"prob1", classMap);
218 
219 
220  std::vector<std::string> outputFilter;
222  outputFolder + "/result_collection.xml",
223  outputFilter);
224  return EXIT_SUCCESS;
225 }
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Builds a mitk::DataCollection from an XML resource.
void SetDataItemNames(std::vector< std::string > itemNames)
#define MITK_ERROR
Definition: mitkLogMacros.h:24
void setContributor(std::string contributor)
int main(int argc, char *argv[])
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void setCategory(std::string category)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void AddSubColIds(std::vector< std::string > subColIds)
const char features[]
std::string helpText() const
void setTitle(std::string title)
void setDescription(std::string description)
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.