Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
CLSimpleVoxelClassification.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include <mitkConfigFileReader.h>
#include <mitkDataCollection.h>
#include <mitkCollectionReader.h>
#include <mitkCollectionWriter.h>
#include <mitkCostingStatistic.h>
#include <vtkSmartPointer.h>
#include <mitkIOUtil.h>

#include <mitkRandomForestIO.h>
#include <mitkVigraRandomForestClassifier.h>

// CTK
#include "mitkCommandLineParser.h"
31 
32 
33 
34 
35 int main(int argc, char* argv[])
36 {
37  // Setup CLI Module parsable interface
38  mitkCommandLineParser parser;
39  parser.setTitle("Simple Random Forest Classifier");
40  parser.setCategory("Classification");
41  parser.setDescription("Learns and predicts classes");
42  parser.setContributor("German Cancer Research Center (DKFZ)");
43 
44  parser.setArgumentPrefix("--", "-");
45  // Add command line argument names
46  parser.addArgument("help", "h", mitkCommandLineParser::Bool, "Show options");
47  parser.addArgument("loadFile", "l", mitkCommandLineParser::File,
48  "DataCollection File", "", us::Any(), true, false, false, mitkCommandLineParser::Input);
49  parser.addArgument(
50  "colIds", "c", mitkCommandLineParser::String,
51  "Patient Identifiers from DataCollection used for training");
52  parser.addArgument("testId", "t", mitkCommandLineParser::String,
53  "Patient Identifier from DataCollection used for testing");
54  parser.addArgument("features", "b", mitkCommandLineParser::String,
55  "Features");
56  parser.addArgument("stats", "s", mitkCommandLineParser::String,
57  "Output file for stats");
58  parser.addArgument("treeDepth", "d", mitkCommandLineParser::Int,
59  "limits tree depth");
60  parser.addArgument("forestSize", "f", mitkCommandLineParser::Int,
61  "number of trees");
62  parser.addArgument("configName", "n", mitkCommandLineParser::String,
63  "human readable name for configuration");
64  parser.addArgument("output", "o", mitkCommandLineParser::Directory,
65  "output folder for results", "", us::Any(), true, false, false, mitkCommandLineParser::Input);
66 
67  parser.addArgument("classmap", "m", mitkCommandLineParser::String,
68  "name of class that is to be learnt");
69 
70 
71  std::map<std::string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
72  // Show a help message
73  if (parsedArgs.size()==0 || parsedArgs.count("help") || parsedArgs.count("h")) {
74  std::cout << parser.helpText();
75  return EXIT_SUCCESS;
76  }
77 
78  // Default values
79  unsigned int forestSize = 8;
80  unsigned int treeDepth = 10;
81  std::string configName = "";
82  std::string outputFolder = "";
83 
84  std::vector<std::string> features;
85  std::vector<std::string> trainingIds;
86  std::vector<std::string> testingIds;
87  std::vector<std::string> loadIds; // features + masks needed for training and evaluation
88  std::string outputFile;
89  std::string classMap;
90  std::string xmlFile;
91  std::ofstream experimentFS;
92 
93  // Parse input parameters
94  {
95  if (parsedArgs.count("colIds") || parsedArgs.count("c")) {
96  std::istringstream ss(us::any_cast<std::string>(parsedArgs["colIds"]));
97  std::string token;
98 
99  while (std::getline(ss, token, ','))
100  trainingIds.push_back(token);
101  }
102 
103  if (parsedArgs.count("output") || parsedArgs.count("o")) {
104  outputFolder = us::any_cast<std::string>(parsedArgs["output"]);
105  }
106 
107  if (parsedArgs.count("classmap") || parsedArgs.count("m")) {
108  classMap = us::any_cast<std::string>(parsedArgs["classmap"]);
109  }
110 
111  if (parsedArgs.count("configName") || parsedArgs.count("n")) {
112  configName = us::any_cast<std::string>(parsedArgs["configName"]);
113  }
114 
115  if (parsedArgs.count("features") || parsedArgs.count("b")) {
116  std::istringstream ss(us::any_cast<std::string>(parsedArgs["features"]));
117  std::string token;
118 
119  while (std::getline(ss, token, ','))
120  features.push_back(token);
121  }
122 
123  if (parsedArgs.count("treeDepth") || parsedArgs.count("d")) {
124  treeDepth = us::any_cast<int>(parsedArgs["treeDepth"]);
125  }
126 
127 
128  if (parsedArgs.count("forestSize") || parsedArgs.count("f")) {
129  forestSize = us::any_cast<int>(parsedArgs["forestSize"]);
130  }
131 
132  if (parsedArgs.count("stats") || parsedArgs.count("s")) {
133  experimentFS.open(us::any_cast<std::string>(parsedArgs["stats"]).c_str(),
134  std::ios_base::app);
135  }
136 
137 
138  if (parsedArgs.count("testId") || parsedArgs.count("t")) {
139  std::istringstream ss(us::any_cast<std::string>(parsedArgs["testId"]));
140  std::string token;
141 
142  while (std::getline(ss, token, ','))
143  testingIds.push_back(token);
144  }
145 
146  for (unsigned int i = 0; i < features.size(); i++) {
147  loadIds.push_back(features.at(i));
148  }
149  loadIds.push_back(classMap);
150 
151  if (parsedArgs.count("stats") || parsedArgs.count("s")) {
152  outputFile = us::any_cast<std::string>(parsedArgs["stats"]);
153  }
154 
155  if (parsedArgs.count("loadFile") || parsedArgs.count("l")) {
156  xmlFile = us::any_cast<std::string>(parsedArgs["loadFile"]);
157  } else {
158  MITK_ERROR << parser.helpText();
159  return EXIT_FAILURE;
160  }
161  }
162 
163  mitk::DataCollection::Pointer trainCollection;
164  mitk::DataCollection::Pointer testCollection;
165  {
166  mitk::CollectionReader colReader;
167  // Load only relevant images
168  colReader.SetDataItemNames(loadIds);
169  colReader.AddSubColIds(testingIds);
170  testCollection = colReader.LoadCollection(xmlFile);
171  colReader.ClearDataElementIds();
172  colReader.ClearSubColIds();
173  colReader.SetDataItemNames(loadIds);
174  colReader.AddSubColIds(trainingIds);
175  trainCollection = colReader.LoadCollection(xmlFile);
176  }
177 
179  // If required do Training....
181  //mitk::DecisionForest forest;
182 
184 
185  forest->SetTreeCount(forestSize);
186  forest->SetMaximumTreeDepth(treeDepth);
187 
188 
189  // create feature matrix
190  auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, features, classMap);
191  // create label matrix
192  auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, classMap, classMap);
193 
194  forest->Train(trainDataX, trainDataY);
195 
196 
197  // prepare feature matrix for test case
198  auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,features, classMap);
199  auto testDataNewY = forest->Predict(testDataX);
200 
201 
202  mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, "RESULT", classMap);
203 
204  Eigen::MatrixXd Probs = forest->GetPointWiseProbabilities();
205 
206 
207  Eigen::MatrixXd prob0 = Probs.col(0);
208  Eigen::MatrixXd prob1 = Probs.col(1);
209 
210  mitk::DCUtilities::MatrixToDC3d(prob0, testCollection,"prob0", classMap);
211  mitk::DCUtilities::MatrixToDC3d(prob1, testCollection,"prob1", classMap);
212 
213 
214  std::vector<std::string> outputFilter;
216  outputFolder + "/result_collection.xml",
217  outputFilter);
218  return EXIT_SUCCESS;
219 }
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from an XML resource.
void SetDataItemNames(std::vector< std::string > itemNames)
#define MITK_ERROR
Definition: mitkLogMacros.h:20
void setContributor(std::string contributor)
int main(int argc, char *argv[])
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false, mitkCommandLineParser::Channel channel=mitkCommandLineParser::Channel::None)
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
Definition: usAny.h:163
std::string helpText() const
void setCategory(std::string category)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void AddSubColIds(std::vector< std::string > subColIds)
const char features[]
void setTitle(std::string title)
void setDescription(std::string description)