Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
TumorInvasionAnalysisTool.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 // MITK - DataCollection
#include <mitkCollectionReader.h>
#include <mitkCollectionWriter.h>
#include <mitkDataCollection.h>
#include <mitkImageCast.h>

// NOTE(review): the listing this file was recovered from skips a few include
// lines here; the two headers below are presumed to declare
// mitk::CollectionStatistic and mitk::TumorInvasionClassification - confirm
// the exact header names against the module.
#include <mitkCollectionStatistic.h>
#include <mitkTumorInvasionClassification.h>

// CTK
#include "mitkCommandLineParser.h"
// ITK
#include <itkImageRegionIterator.h>

// STL
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
30 
31 using namespace std;
32 
33 int main(int argc, char *argv[])
34 {
35  // Setup CLI Module parsable interface
36  mitkCommandLineParser parser;
37  parser.setTitle("Tumor Invasion Analysis");
38  parser.setCategory("Tumor Analysis");
39  parser.setDescription("Learns and predicts Invasion behavior");
40  parser.setContributor("MBI");
41 
42  parser.setArgumentPrefix("--", "-");
43  // Add command line argument names
44  parser.addArgument("help", "h", mitkCommandLineParser::Bool, "Show options");
45  parser.addArgument("loadFile", "l", mitkCommandLineParser::InputFile, "DataCollection File");
46  parser.addArgument(
47  "colIds", "c", mitkCommandLineParser::String, "Patient Identifiers from DataCollection used for training");
48  parser.addArgument(
49  "testId", "t", mitkCommandLineParser::String, "Patient Identifier from DataCollection used for testing");
50  parser.addArgument("features", "b", mitkCommandLineParser::String, "Features");
51  parser.addArgument("stats", "s", mitkCommandLineParser::String, "Output file for stats");
52  parser.addArgument("ratio", "q", mitkCommandLineParser::Float, "ratio of tumor to healthy");
53  parser.addArgument("treeDepth", "d", mitkCommandLineParser::Int, "limits tree depth");
54  parser.addArgument("forestSize", "f", mitkCommandLineParser::Int, "number of trees");
55  parser.addArgument("samplingMode", "m", mitkCommandLineParser::Int, "mode of sample selection");
56  parser.addArgument("configName", "n", mitkCommandLineParser::String, "human readable name for configuration");
57  parser.addArgument("output", "o", mitkCommandLineParser::OutputDirectory, "output folder for results");
58  parser.addArgument("forest", "t", mitkCommandLineParser::OutputFile, "store trained forest to file");
59 
60  map<string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
61  // Show a help message
62  if (parsedArgs.size() == 0)
63  return EXIT_SUCCESS;
64  if (parsedArgs.count("help") || parsedArgs.count("h"))
65  {
66  std::cout << parser.helpText();
67  return EXIT_SUCCESS;
68  }
69 
70  // Default values
71  float ratio = 1.0;
72  bool useStatsFile = false;
73  unsigned int forestSize = 250;
74  unsigned int treeDepth = 0;
75  unsigned int samplingMode = 1;
76  std::string configName = "";
77  std::string outputFolder = "";
78  std::string forestFile = "";
79 
80  std::vector<std::string> features;
81  std::vector<std::string> trainingIds;
82  std::vector<std::string> testingIds;
83  std::vector<std::string> loadIds; // features + masks needed for training and evaluation
84  std::string outputFile;
85  std::string xmlFile;
86  std::ofstream experimentFS;
87 
88  // Parse input parameters
89  {
90  if (parsedArgs.count("colIds") || parsedArgs.count("c"))
91  {
92  std::istringstream ss(us::any_cast<string>(parsedArgs["colIds"]));
93  std::string token;
94 
95  while (std::getline(ss, token, ','))
96  trainingIds.push_back(token);
97  }
98 
99  if (parsedArgs.count("output") || parsedArgs.count("o"))
100  {
101  outputFolder = us::any_cast<string>(parsedArgs["output"]);
102  }
103 
104  if (parsedArgs.count("configName") || parsedArgs.count("n"))
105  {
106  configName = us::any_cast<string>(parsedArgs["configName"]);
107  }
108 
109  if (parsedArgs.count("features") || parsedArgs.count("b"))
110  {
111  std::istringstream ss(us::any_cast<string>(parsedArgs["features"]));
112  std::string token;
113 
114  while (std::getline(ss, token, ','))
115  features.push_back(token);
116  }
117 
118  if (parsedArgs.count("treeDepth") || parsedArgs.count("d"))
119  {
120  treeDepth = us::any_cast<int>(parsedArgs["treeDepth"]);
121  }
122 
123  if (parsedArgs.count("ratio") || parsedArgs.count("q"))
124  {
125  ratio = us::any_cast<float>(parsedArgs["ratio"]);
126  }
127 
128  if (parsedArgs.count("forestSize") || parsedArgs.count("f"))
129  {
130  forestSize = us::any_cast<int>(parsedArgs["forestSize"]);
131  }
132 
133  if (parsedArgs.count("samplingMode") || parsedArgs.count("m"))
134  {
135  samplingMode = us::any_cast<int>(parsedArgs["samplingMode"]);
136  }
137 
138  if (parsedArgs.count("stats") || parsedArgs.count("s"))
139  {
140  useStatsFile = true;
141  experimentFS.open(us::any_cast<string>(parsedArgs["stats"]).c_str(), std::ios_base::app);
142  }
143 
144  if (parsedArgs.count("forest") || parsedArgs.count("t"))
145  {
146  forestFile = us::any_cast<string>(parsedArgs["stats"]);
147  }
148 
149  if (parsedArgs.count("testId") || parsedArgs.count("t"))
150  {
151  std::istringstream ss(us::any_cast<string>(parsedArgs["testId"]));
152  std::string token;
153 
154  while (std::getline(ss, token, ','))
155  testingIds.push_back(token);
156  }
157 
158  for (unsigned int i = 0; i < features.size(); i++)
159  {
160  loadIds.push_back(features.at(i));
161  }
162  loadIds.push_back("GTV");
163  loadIds.push_back("BRAINMASK");
164  loadIds.push_back("TARGET");
165 
166  if (parsedArgs.count("stats") || parsedArgs.count("s"))
167  {
168  outputFile = us::any_cast<string>(parsedArgs["stats"]);
169  }
170 
171  if (parsedArgs.count("loadFile") || parsedArgs.count("l"))
172  {
173  xmlFile = us::any_cast<string>(parsedArgs["loadFile"]);
174  }
175  else
176  {
177  MITK_ERROR << parser.helpText();
178  return EXIT_FAILURE;
179  }
180  }
181 
182  mitk::DataCollection::Pointer trainCollection;
183  mitk::DataCollection::Pointer testCollection;
184  {
185  mitk::CollectionReader colReader;
186  // Load only relevant images
187  colReader.SetDataItemNames(loadIds);
188  colReader.AddSubColIds(testingIds);
189  testCollection = colReader.LoadCollection(xmlFile);
190  colReader.ClearDataElementIds();
191  colReader.ClearSubColIds();
192  colReader.SetDataItemNames(loadIds);
193  colReader.AddSubColIds(trainingIds);
194  trainCollection = colReader.LoadCollection(xmlFile);
195  }
196 
197  std::cout << "Setup Training" << std::endl;
199 
200  classifier.SetClassRatio(ratio);
201  classifier.SetTrainMargin(7, 1);
202  classifier.SamplesWeightingActivated(true);
203  classifier.SelectTrainingSamples(trainCollection, samplingMode);
204  // Learning stage
205  std::cout << "Start Training" << std::endl;
206  classifier.LearnProgressionFeatures(trainCollection, features, forestSize, treeDepth);
207 
208  if (forestFile != "")
209  classifier.SaveRandomForest(forestFile);
210 
211  std::cout << "Start Predict" << std::endl;
212  classifier.PredictInvasion(testCollection, features);
213 
214  if (false && outputFolder != "")
215  {
216  std::cout << "Saving files to " << outputFolder << std::endl;
217  mitk::CollectionWriter::ExportCollectionToFolder(trainCollection, "/tmp/dumple");
218  }
219  classifier.SanitizeResults(testCollection);
220 
221  {
222  mitk::DataCollectionImageIterator<unsigned char, 3> gtvIt(testCollection, "GTV");
223  mitk::DataCollectionImageIterator<unsigned char, 3> result(testCollection, "RESULTOPEN");
224 
225  while (!gtvIt.IsAtEnd())
226  {
227  if (gtvIt.GetVoxel() != 0)
228  {
229  result.SetVoxel(2);
230  }
231 
232  result++;
233  gtvIt++;
234  }
235  }
236 
238  mitk::ProgressionValueToIndexMapper progressionValueToIndexMapper;
239  mitk::BinaryValueToIndexMapper binaryValueToIndexMapper;
240 
241  stat2.SetCollection(testCollection);
242  stat2.SetClassCount(2);
243  stat2.SetGoldName("TARGET");
244  stat2.SetTestName("RESULTOPEN");
245  stat2.SetMaskName("BRAINMASK");
246  stat2.SetGroundTruthValueToIndexMapper(&binaryValueToIndexMapper);
247  stat2.SetTestValueToIndexMapper(&progressionValueToIndexMapper);
248 
249  stat2.Update();
250  stat2.ComputeRMSD();
251 
252  // FIXME: DICE value available after calling Print method
253  std::ostringstream out2;
254  stat2.Print(out2, std::cout, true);
255  std::cout << std::endl << std::endl << out2.str() << std::endl;
256 
257  // Exclude GTV from Statistics by removing it from brain mask,
258  // insert GTV as tumor region, since it is known before, in the result.
259  {
260  mitk::DataCollectionImageIterator<unsigned char, 3> gtvIt(testCollection, "GTV");
261  mitk::DataCollectionImageIterator<unsigned char, 3> brainMaskIter(testCollection, "BRAINMASK");
262  mitk::DataCollectionImageIterator<unsigned char, 3> result(testCollection, "RESULTOPEN");
263 
264  while (!gtvIt.IsAtEnd())
265  {
266  if (gtvIt.GetVoxel() != 0)
267  {
268  brainMaskIter.SetVoxel(0);
269  result.SetVoxel(2);
270  }
271  result++;
272  gtvIt++;
273  brainMaskIter++;
274  }
275  }
276 
278  stat.SetCollection(testCollection);
279  stat.SetClassCount(2);
280  stat.SetGoldName("TARGET");
281  stat.SetTestName("RESULTOPEN");
282  stat.SetMaskName("BRAINMASK");
283  stat.SetGroundTruthValueToIndexMapper(&binaryValueToIndexMapper);
284  stat.SetTestValueToIndexMapper(&progressionValueToIndexMapper);
285 
286  stat.Update();
287  stat.ComputeRMSD();
288 
289  // WARN: DICE value computed within Print method, so values are only available
290  // after
291  // calling Print()
292  std::ostringstream out;
293  stat.Print(out, std::cout, true);
294  std::cout << std::endl << std::endl << out.str() << std::endl;
295 
296  // Statistics for original GTV excluded (Dice,Sensitivity) and for Gold
297  // Standard vs prediction (RMSE)
298  mitk::StatisticData statData = stat.GetStatisticData(1).at(0);
299  mitk::StatisticData statData2 = stat2.GetStatisticData(1).at(0);
300 
301  std::cout << "Writing Stats to file" << std::endl;
302  // one line output
303  if (useStatsFile)
304  {
305  experimentFS << "Tree_Depth " << treeDepth << ',';
306  experimentFS << "Forest_Size " << forestSize << ',';
307  experimentFS << "Tumor/healthy_ratio " << ratio << ',';
308  experimentFS << "Sample_Selection " << samplingMode << ',';
309 
310  experimentFS << "Trainined_on: " << ',';
311  for (unsigned int i = 0; i < trainingIds.size(); i++)
312  {
313  experimentFS << trainingIds.at(i) << "/";
314  }
315  experimentFS << ',';
316 
317  experimentFS << "Tested_on: " << ',';
318  for (unsigned int i = 0; i < testingIds.size(); i++)
319  {
320  experimentFS << testingIds.at(i) << "/";
321  }
322  experimentFS << ',';
323 
324  experimentFS << "Features_used: " << ',';
325  if (configName == "")
326  {
327  for (unsigned int i = 0; i < features.size(); i++)
328  {
329  experimentFS << features.at(i) << "/";
330  }
331  }
332  else
333  experimentFS << configName;
334 
335  experimentFS << ',';
336  experimentFS << "---- STATS ---" << ',';
337  experimentFS << " Sensitivity " << statData.m_Sensitivity << ',';
338  experimentFS << " DICE " << statData.m_DICE << ',';
339  experimentFS << " RMSE " << statData2.m_RMSD << ',';
340  experimentFS << std::endl;
341  }
342 
343  if (outputFolder != "")
344  {
345  std::cout << "Saving files to " << outputFolder << std::endl;
346  mitk::CollectionWriter::ExportCollectionToFolder(testCollection, outputFolder);
347  }
348 
349  return EXIT_SUCCESS;
350 }
void SamplesWeightingActivated(bool isActive)
SamplesWeightingActivated - If activated, a weighted mask for the samples is calculated, weighting samples according to their location and ratio.
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SaveRandomForest(std::string filename)
SaveRandomForest - Saves a trained random forest.
void SetTestName(std::string name)
#define MITK_ERROR
Definition: mitkLogMacros.h:24
std::vector< StatisticData > GetStatisticData(unsigned char c) const
mitk::CollectionStatistic::GetStatisticData
void setContributor(std::string contributor)
void SetTestValueToIndexMapper(const ValueToIndexMapper *mapper)
STL namespace.
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
void SetGroundTruthValueToIndexMapper(const ValueToIndexMapper *mapper)
void SetTrainMargin(vcl_size_t dil2d, vcl_size_t dil3d)
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
void SelectTrainingSamples(DataCollection *collection, unsigned int mode=0)
SelectTrainingSamples.
void SetMaskName(std::string name)
void SetGoldName(std::string name)
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false)
void PredictInvasion(DataCollection *collection, std::vector< std::string > modalitiesList)
PredictInvasion - Classify voxels into remaining healthy / turning into tumor.
void ComputeRMSD()
Computes root-mean-square distance of two binary images.
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void setCategory(std::string category)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
The TumorInvasionAnalysis class - Classifies Tumor progression using RF and predicts on new cases...
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
void LearnProgressionFeatures(DataCollection *collection, std::vector< std::string > modalitiesList, vcl_size_t forestSize=300, vcl_size_t treeDepth=10)
LearnProgressionFeatures.
void SetClassRatio(ScalarType ratio)
SetClassRatio - set ratio of tumor voxels to healthy voxels that is to be used for training...
void AddSubColIds(std::vector< std::string > subColIds)
int main(int argc, char *argv[])
const char features[]
std::string helpText() const
void setTitle(std::string title)
void SanitizeResults(DataCollection *collection, std::string resultID="RESULT")
SanitizeResults - Performs an Opening Operation on the data to remove isolated misclassifications.
void setDescription(std::string description)