Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
TumorInvasionAnalysisTool.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
// STL
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
// MITK - DataCollection
#include <mitkCollectionReader.h>
#include <mitkCollectionStatistic.h>
#include <mitkCollectionWriter.h>
#include <mitkDataCollection.h>
#include <mitkDataCollectionImageIterator.h>
#include <mitkImageCast.h>
#include <mitkTumorInvasionClassification.h>
// CTK
#include "mitkCommandLineParser.h"
// ITK
#include <itkImageRegionIterator.h>
30 
31 using namespace std;
32 
33 int main(int argc, char *argv[])
34 {
35  // Setup CLI Module parsable interface
36  mitkCommandLineParser parser;
37  parser.setTitle("Tumor Invasion Analysis");
38  parser.setCategory("Tumor Analysis");
39  parser.setDescription("Learns and predicts Invasion behavior");
40  parser.setContributor("MBI");
41 
42  parser.setArgumentPrefix("--", "-");
43  // Add command line argument names
44  parser.addArgument("help", "h", mitkCommandLineParser::Bool, "Show options");
45  parser.addArgument("loadFile", "l", mitkCommandLineParser::InputFile, "DataCollection File");
46  parser.addArgument(
47  "colIds", "c", mitkCommandLineParser::String, "Patient Identifiers from DataCollection used for training");
48  parser.addArgument(
49  "testId", "t", mitkCommandLineParser::String, "Patient Identifier from DataCollection used for testing");
50  parser.addArgument("features", "b", mitkCommandLineParser::String, "Features");
51  parser.addArgument("stats", "s", mitkCommandLineParser::String, "Output file for stats");
52  parser.addArgument("ratio", "q", mitkCommandLineParser::Float, "ratio of tumor to healthy");
53  parser.addArgument("treeDepth", "d", mitkCommandLineParser::Int, "limits tree depth");
54  parser.addArgument("forestSize", "f", mitkCommandLineParser::Int, "number of trees");
55  parser.addArgument("samplingMode", "m", mitkCommandLineParser::Int, "mode of sample selection");
56  parser.addArgument("configName", "n", mitkCommandLineParser::String, "human readable name for configuration");
57  parser.addArgument("output", "o", mitkCommandLineParser::OutputDirectory, "output folder for results");
58  parser.addArgument("forest", "t", mitkCommandLineParser::OutputFile, "store trained forest to file");
59 
60  map<string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
61  // Show a help message
62  if (parsedArgs.size() == 0)
63  return EXIT_SUCCESS;
64  if (parsedArgs.count("help") || parsedArgs.count("h"))
65  {
66  std::cout << parser.helpText();
67  return EXIT_SUCCESS;
68  }
69 
70  // Default values
71  float ratio = 1.0;
72  bool useStatsFile = false;
73  unsigned int forestSize = 250;
74  unsigned int treeDepth = 0;
75  unsigned int samplingMode = 1;
76  std::string configName = "";
77  std::string outputFolder = "";
78  std::string forestFile = "";
79 
80  std::vector<std::string> features;
81  std::vector<std::string> trainingIds;
82  std::vector<std::string> testingIds;
83  std::vector<std::string> loadIds; // features + masks needed for training and evaluation
84  std::string outputFile;
85  std::string xmlFile;
86  std::ofstream experimentFS;
87 
88  // Parse input parameters
89  {
90  if (parsedArgs.count("colIds") || parsedArgs.count("c"))
91  {
92  std::istringstream ss(us::any_cast<string>(parsedArgs["colIds"]));
93  std::string token;
94 
95  while (std::getline(ss, token, ','))
96  trainingIds.push_back(token);
97  }
98 
99  if (parsedArgs.count("output") || parsedArgs.count("o"))
100  {
101  outputFolder = us::any_cast<string>(parsedArgs["output"]);
102  }
103 
104  if (parsedArgs.count("configName") || parsedArgs.count("n"))
105  {
106  configName = us::any_cast<string>(parsedArgs["configName"]);
107  }
108 
109  if (parsedArgs.count("features") || parsedArgs.count("b"))
110  {
111  std::istringstream ss(us::any_cast<string>(parsedArgs["features"]));
112  std::string token;
113 
114  while (std::getline(ss, token, ','))
115  features.push_back(token);
116  }
117 
118  if (parsedArgs.count("treeDepth") || parsedArgs.count("d"))
119  {
120  treeDepth = us::any_cast<int>(parsedArgs["treeDepth"]);
121  }
122 
123  if (parsedArgs.count("ratio") || parsedArgs.count("q"))
124  {
125  ratio = us::any_cast<float>(parsedArgs["ratio"]);
126  }
127 
128  if (parsedArgs.count("forestSize") || parsedArgs.count("f"))
129  {
130  forestSize = us::any_cast<int>(parsedArgs["forestSize"]);
131  }
132 
133  if (parsedArgs.count("samplingMode") || parsedArgs.count("m"))
134  {
135  samplingMode = us::any_cast<int>(parsedArgs["samplingMode"]);
136  }
137 
138  if (parsedArgs.count("stats") || parsedArgs.count("s"))
139  {
140  useStatsFile = true;
141  experimentFS.open(us::any_cast<string>(parsedArgs["stats"]).c_str(), std::ios_base::app);
142  }
143 
144  if (parsedArgs.count("forest") || parsedArgs.count("t"))
145  {
146  forestFile = us::any_cast<string>(parsedArgs["stats"]);
147  }
148 
149  if (parsedArgs.count("testId") || parsedArgs.count("t"))
150  {
151  std::istringstream ss(us::any_cast<string>(parsedArgs["testId"]));
152  std::string token;
153 
154  while (std::getline(ss, token, ','))
155  testingIds.push_back(token);
156  }
157 
158  for (unsigned int i = 0; i < features.size(); i++)
159  {
160  loadIds.push_back(features.at(i));
161  }
162  loadIds.push_back("GTV");
163  loadIds.push_back("BRAINMASK");
164  loadIds.push_back("TARGET");
165 
166  if (parsedArgs.count("stats") || parsedArgs.count("s"))
167  {
168  outputFile = us::any_cast<string>(parsedArgs["stats"]);
169  }
170 
171  if (parsedArgs.count("loadFile") || parsedArgs.count("l"))
172  {
173  xmlFile = us::any_cast<string>(parsedArgs["loadFile"]);
174  }
175  else
176  {
177  MITK_ERROR << parser.helpText();
178  return EXIT_FAILURE;
179  }
180  }
181 
182  mitk::DataCollection::Pointer trainCollection;
183  mitk::DataCollection::Pointer testCollection;
184  {
185  mitk::CollectionReader colReader;
186  // Load only relevant images
187  colReader.SetDataItemNames(loadIds);
188  colReader.AddSubColIds(testingIds);
189  testCollection = colReader.LoadCollection(xmlFile);
190  colReader.ClearDataElementIds();
191  colReader.ClearSubColIds();
192  colReader.SetDataItemNames(loadIds);
193  colReader.AddSubColIds(trainingIds);
194  trainCollection = colReader.LoadCollection(xmlFile);
195  }
196 
197  std::cout << "Setup Training" << std::endl;
199 
200  classifier.SetClassRatio(ratio);
201  classifier.SetTrainMargin(7, 1);
202  classifier.SamplesWeightingActivated(true);
203  classifier.SelectTrainingSamples(trainCollection, samplingMode);
204  // Learning stage
205  std::cout << "Start Training" << std::endl;
206  classifier.LearnProgressionFeatures(trainCollection, features, forestSize, treeDepth);
207 
208  if (forestFile != "")
209  classifier.SaveRandomForest(forestFile);
210 
211  std::cout << "Start Predict" << std::endl;
212  classifier.PredictInvasion(testCollection, features);
213 
214  if (false && outputFolder != "")
215  {
216  std::cout << "Saving files to " << outputFolder << std::endl;
217  mitk::CollectionWriter::ExportCollectionToFolder(trainCollection, "/tmp/dumple");
218  }
219  classifier.SanitizeResults(testCollection);
220 
221  {
222  mitk::DataCollectionImageIterator<unsigned char, 3> gtvIt(testCollection, "GTV");
223  mitk::DataCollectionImageIterator<unsigned char, 3> result(testCollection, "RESULTOPEN");
224 
225  while (!gtvIt.IsAtEnd())
226  {
227  if (gtvIt.GetVoxel() != 0)
228  {
229  result.SetVoxel(2);
230  }
231 
232  result++;
233  gtvIt++;
234  }
235  }
236 
238  mitk::ProgressionValueToIndexMapper progressionValueToIndexMapper;
239  mitk::BinaryValueToIndexMapper binaryValueToIndexMapper;
240 
241  stat2.SetCollection(testCollection);
242  stat2.SetClassCount(2);
243  stat2.SetGoldName("TARGET");
244  stat2.SetTestName("RESULTOPEN");
245  stat2.SetMaskName("BRAINMASK");
246  stat2.SetGroundTruthValueToIndexMapper(&binaryValueToIndexMapper);
247  stat2.SetTestValueToIndexMapper(&progressionValueToIndexMapper);
248 
249  stat2.Update();
250  stat2.ComputeRMSD();
251 
252  // FIXME: DICE value available after calling Print method
253  std::ostringstream out2;
254  stat2.Print(out2, std::cout, true);
255  std::cout << std::endl << std::endl << out2.str() << std::endl;
256 
257  // Exclude GTV from Statistics by removing it from brain mask,
258  // insert GTV as tumor region, since it is known before, in the result.
259  {
260  mitk::DataCollectionImageIterator<unsigned char, 3> gtvIt(testCollection, "GTV");
261  mitk::DataCollectionImageIterator<unsigned char, 3> brainMaskIter(testCollection, "BRAINMASK");
262  mitk::DataCollectionImageIterator<unsigned char, 3> result(testCollection, "RESULTOPEN");
263 
264  while (!gtvIt.IsAtEnd())
265  {
266  if (gtvIt.GetVoxel() != 0)
267  {
268  brainMaskIter.SetVoxel(0);
269  result.SetVoxel(2);
270  }
271  result++;
272  gtvIt++;
273  brainMaskIter++;
274  }
275  }
276 
278  stat.SetCollection(testCollection);
279  stat.SetClassCount(2);
280  stat.SetGoldName("TARGET");
281  stat.SetTestName("RESULTOPEN");
282  stat.SetMaskName("BRAINMASK");
283  stat.SetGroundTruthValueToIndexMapper(&binaryValueToIndexMapper);
284  stat.SetTestValueToIndexMapper(&progressionValueToIndexMapper);
285 
286  stat.Update();
287  stat.ComputeRMSD();
288 
289  // WARN: DICE value computed within Print method, so values are only available
290  // after
291  // calling Print()
292  std::ostringstream out;
293  stat.Print(out, std::cout, true);
294  std::cout << std::endl << std::endl << out.str() << std::endl;
295 
296  // Statistics for original GTV excluded (Dice,Sensitivity) and for Gold
297  // Standard vs prediction (RMSE)
298  mitk::StatisticData statData = stat.GetStatisticData(1).at(0);
299  mitk::StatisticData statData2 = stat2.GetStatisticData(1).at(0);
300 
301  std::cout << "Writing Stats to file" << std::endl;
302  // one line output
303  if (useStatsFile)
304  {
305  experimentFS << "Tree_Depth " << treeDepth << ',';
306  experimentFS << "Forest_Size " << forestSize << ',';
307  experimentFS << "Tumor/healthy_ratio " << ratio << ',';
308  experimentFS << "Sample_Selection " << samplingMode << ',';
309 
310  experimentFS << "Trainined_on: " << ',';
311  for (unsigned int i = 0; i < trainingIds.size(); i++)
312  {
313  experimentFS << trainingIds.at(i) << "/";
314  }
315  experimentFS << ',';
316 
317  experimentFS << "Tested_on: " << ',';
318  for (unsigned int i = 0; i < testingIds.size(); i++)
319  {
320  experimentFS << testingIds.at(i) << "/";
321  }
322  experimentFS << ',';
323 
324  experimentFS << "Features_used: " << ',';
325  if (configName == "")
326  {
327  for (unsigned int i = 0; i < features.size(); i++)
328  {
329  experimentFS << features.at(i) << "/";
330  }
331  }
332  else
333  experimentFS << configName;
334 
335  experimentFS << ',';
336  experimentFS << "---- STATS ---" << ',';
337  experimentFS << " Sensitivity " << statData.m_Sensitivity << ',';
338  experimentFS << " DICE " << statData.m_DICE << ',';
339  experimentFS << " RMSE " << statData2.m_RMSD << ',';
340  experimentFS << std::endl;
341  }
342 
343  if (outputFolder != "")
344  {
345  std::cout << "Saving files to " << outputFolder << std::endl;
346  mitk::CollectionWriter::ExportCollectionToFolder(testCollection, outputFolder);
347  }
348 
349  return EXIT_SUCCESS;
350 }
void SamplesWeightingActivated(bool isActive)
SamplesWeightingActivated - If activated, a weighted mask for the samples is calculated, weighting samples according to their location and ratio.
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SaveRandomForest(std::string filename)
SaveRandomForest - Saves a trained random forest.
void SetTestName(std::string name)
#define MITK_ERROR
Definition: mitkLogMacros.h:24
std::vector< StatisticData > GetStatisticData(unsigned char c) const
mitk::CollectionStatistic::GetStatisticData
void setContributor(std::string contributor)
void SetTestValueToIndexMapper(const ValueToIndexMapper *mapper)
STL namespace.
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
void SetGroundTruthValueToIndexMapper(const ValueToIndexMapper *mapper)
void SetTrainMargin(vcl_size_t dil2d, vcl_size_t dil3d)
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
void SelectTrainingSamples(DataCollection *collection, unsigned int mode=0)
SelectTrainingSamples.
void SetMaskName(std::string name)
void SetGoldName(std::string name)
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false)
void PredictInvasion(DataCollection *collection, std::vector< std::string > modalitiesList)
PredictInvasion - Classify voxels into remaining healthy / turning into tumor.
void ComputeRMSD()
Computes root-mean-square distance of two binary images.
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void setCategory(std::string category)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
The TumorInvasionAnalysis class - Classifies Tumor progression using RF and predicts on new cases...
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
void LearnProgressionFeatures(DataCollection *collection, std::vector< std::string > modalitiesList, vcl_size_t forestSize=300, vcl_size_t treeDepth=10)
LearnProgressionFeatures.
void SetClassRatio(ScalarType ratio)
SetClassRatio - set ratio of tumor voxels to healthy voxels that is to be used for training...
void AddSubColIds(std::vector< std::string > subColIds)
int main(int argc, char *argv[])
const char features[]
std::string helpText() const
void setTitle(std::string title)
void SanitizeResults(DataCollection *collection, std::string resultID="RESULT")
SanitizeResults - Performs an opening operation on the data to remove isolated misclassifications.
void setDescription(std::string description)