Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
CLSimpleVoxelClassification.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 
18 #include <sstream>
19 
20 #include <mitkConfigFileReader.h>
21 #include <mitkDataCollection.h>
22 #include <mitkCollectionReader.h>
23 #include <mitkCollectionWriter.h>
25 #include <mitkCostingStatistic.h>
26 #include <vtkSmartPointer.h>
27 #include <mitkIOUtil.h>
28 
30 #include <mitkRandomForestIO.h>
32 
33 // CTK
34 #include "mitkCommandLineParser.h"
35 
36 
37 
38 
39 int main(int argc, char* argv[])
40 {
41  // Setup CLI Module parsable interface
42  mitkCommandLineParser parser;
43  parser.setTitle("Simple Random Forest Classifier");
44  parser.setCategory("Classification");
45  parser.setDescription("Learns and predicts classes");
46  parser.setContributor("MBI");
47 
48  parser.setArgumentPrefix("--", "-");
49  // Add command line argument names
50  parser.addArgument("help", "h", mitkCommandLineParser::Bool, "Show options");
51  parser.addArgument("loadFile", "l", mitkCommandLineParser::InputFile,
52  "DataCollection File");
53  parser.addArgument(
54  "colIds", "c", mitkCommandLineParser::String,
55  "Patient Identifiers from DataCollection used for training");
56  parser.addArgument("testId", "t", mitkCommandLineParser::String,
57  "Patient Identifier from DataCollection used for testing");
58  parser.addArgument("features", "b", mitkCommandLineParser::String,
59  "Features");
60  parser.addArgument("stats", "s", mitkCommandLineParser::String,
61  "Output file for stats");
62  parser.addArgument("treeDepth", "d", mitkCommandLineParser::Int,
63  "limits tree depth");
64  parser.addArgument("forestSize", "f", mitkCommandLineParser::Int,
65  "number of trees");
66  parser.addArgument("configName", "n", mitkCommandLineParser::String,
67  "human readable name for configuration");
69  "output folder for results");
70 
71  parser.addArgument("classmap", "m", mitkCommandLineParser::String,
72  "name of class that is to be learnt");
73 
74 
75  std::map<std::string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
76  // Show a help message
77  if (parsedArgs.size()==0 || parsedArgs.count("help") || parsedArgs.count("h")) {
78  std::cout << parser.helpText();
79  return EXIT_SUCCESS;
80  }
81 
82  // Default values
83  bool useStatsFile = false;
84  unsigned int forestSize = 8;
85  unsigned int treeDepth = 10;
86  std::string configName = "";
87  std::string outputFolder = "";
88 
89  std::vector<std::string> features;
90  std::vector<std::string> trainingIds;
91  std::vector<std::string> testingIds;
92  std::vector<std::string> loadIds; // features + masks needed for training and evaluation
93  std::string outputFile;
94  std::string classMap;
95  std::string xmlFile;
96  std::ofstream experimentFS;
97 
98  // Parse input parameters
99  {
100  if (parsedArgs.count("colIds") || parsedArgs.count("c")) {
101  std::istringstream ss(us::any_cast<std::string>(parsedArgs["colIds"]));
102  std::string token;
103 
104  while (std::getline(ss, token, ','))
105  trainingIds.push_back(token);
106  }
107 
108  if (parsedArgs.count("output") || parsedArgs.count("o")) {
109  outputFolder = us::any_cast<std::string>(parsedArgs["output"]);
110  }
111 
112  if (parsedArgs.count("classmap") || parsedArgs.count("m")) {
113  classMap = us::any_cast<std::string>(parsedArgs["classmap"]);
114  }
115 
116  if (parsedArgs.count("configName") || parsedArgs.count("n")) {
117  configName = us::any_cast<std::string>(parsedArgs["configName"]);
118  }
119 
120  if (parsedArgs.count("features") || parsedArgs.count("b")) {
121  std::istringstream ss(us::any_cast<std::string>(parsedArgs["features"]));
122  std::string token;
123 
124  while (std::getline(ss, token, ','))
125  features.push_back(token);
126  }
127 
128  if (parsedArgs.count("treeDepth") || parsedArgs.count("d")) {
129  treeDepth = us::any_cast<int>(parsedArgs["treeDepth"]);
130  }
131 
132 
133  if (parsedArgs.count("forestSize") || parsedArgs.count("f")) {
134  forestSize = us::any_cast<int>(parsedArgs["forestSize"]);
135  }
136 
137  if (parsedArgs.count("stats") || parsedArgs.count("s")) {
138  useStatsFile = true;
139  experimentFS.open(us::any_cast<std::string>(parsedArgs["stats"]).c_str(),
140  std::ios_base::app);
141  }
142 
143 
144  if (parsedArgs.count("testId") || parsedArgs.count("t")) {
145  std::istringstream ss(us::any_cast<std::string>(parsedArgs["testId"]));
146  std::string token;
147 
148  while (std::getline(ss, token, ','))
149  testingIds.push_back(token);
150  }
151 
152  for (unsigned int i = 0; i < features.size(); i++) {
153  loadIds.push_back(features.at(i));
154  }
155  loadIds.push_back(classMap);
156 
157  if (parsedArgs.count("stats") || parsedArgs.count("s")) {
158  outputFile = us::any_cast<std::string>(parsedArgs["stats"]);
159  }
160 
161  if (parsedArgs.count("loadFile") || parsedArgs.count("l")) {
162  xmlFile = us::any_cast<std::string>(parsedArgs["loadFile"]);
163  } else {
164  MITK_ERROR << parser.helpText();
165  return EXIT_FAILURE;
166  }
167  }
168 
169  mitk::DataCollection::Pointer trainCollection;
170  mitk::DataCollection::Pointer testCollection;
171  {
172  mitk::CollectionReader colReader;
173  // Load only relevant images
174  colReader.SetDataItemNames(loadIds);
175  colReader.AddSubColIds(testingIds);
176  testCollection = colReader.LoadCollection(xmlFile);
177  colReader.ClearDataElementIds();
178  colReader.ClearSubColIds();
179  colReader.SetDataItemNames(loadIds);
180  colReader.AddSubColIds(trainingIds);
181  trainCollection = colReader.LoadCollection(xmlFile);
182  }
183 
185  // If required do Training....
187  //mitk::DecisionForest forest;
188 
190 
191  forest->SetTreeCount(forestSize);
192  forest->SetMaximumTreeDepth(treeDepth);
193 
194 
195  // create feature matrix
196  auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, features, classMap);
197  // create label matrix
198  auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, classMap, classMap);
199 
200  forest->Train(trainDataX, trainDataY);
201 
202 
203  // prepare feature matrix for test case
204  auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,features, classMap);
205  auto testDataNewY = forest->Predict(testDataX);
206 
207 
208  mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, "RESULT", classMap);
209 
210  Eigen::MatrixXd Probs = forest->GetPointWiseProbabilities();
211 
212 
213  Eigen::MatrixXd prob0 = Probs.col(0);
214  Eigen::MatrixXd prob1 = Probs.col(1);
215 
216  mitk::DCUtilities::MatrixToDC3d(prob0, testCollection,"prob0", classMap);
217  mitk::DCUtilities::MatrixToDC3d(prob1, testCollection,"prob1", classMap);
218 
219 
220  std::vector<std::string> outputFilter;
222  outputFolder + "/result_collection.xml",
223  outputFilter);
224  return EXIT_SUCCESS;
225 }
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
void SetDataItemNames(std::vector< std::string > itemNames)
#define MITK_ERROR
Definition: mitkLogMacros.h:24
void setContributor(std::string contributor)
int main(int argc, char *argv[])
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void setCategory(std::string category)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void AddSubColIds(std::vector< std::string > subColIds)
const char features[]
std::string helpText() const
void setTitle(std::string title)
void setDescription(std::string description)
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.