Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
CLVoxelClassification.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 #ifndef mitkForest_cpp
17 #define mitkForest_cpp
18 
19 #include "time.h"
20 #include <sstream>
21 
22 #include <mitkConfigFileReader.h>
23 #include <mitkDataCollection.h>
24 #include <mitkCollectionReader.h>
25 #include <mitkCollectionWriter.h>
27 #include <mitkCostingStatistic.h>
28 #include <vtkSmartPointer.h>
29 #include <mitkIOUtil.h>
30 
32 #include <mitkRandomForestIO.h>
33 
34 // ----------------------- Forest Handling ----------------------
35 //#include <mitkDecisionForest.h>
37 //#include <mitkThresholdSplit.h>
38 //#include <mitkImpurityLoss.h>
39 //#include <mitkLinearSplitting.h>
40 //#include <mitkVigraConverter.h>
41 // ----------------------- Point weighting ----------------------
42 //#include <mitkForestWeighting.h>
43 //#include <mitkKliepDensityEstimation.h>
44 //#include <mitkExternalWeighting.h>
46 //#include <mitkKNNDensityEstimation.h>
47 //#include <mitkZadroznyWeighting.h>
48 //#include <mitkSpectralDensityEstimation.h>
49 //#include <mitkULSIFDensityEstimation.h>
50 
// Entry point of the CLVoxelClassification mini-app.
//
// Reads its configuration from the command line, loads a training and a test
// patient collection, trains (or, once implemented, loads) a forest
// classifier, predicts labels inside the test mask, exports the resulting
// collection and appends evaluation statistics and per-phase timings.
//
// Command line: every argument before the first one that starts with '+' is
// passed to ConfigFileReader::ReadFile as a config file; every argument after
// it becomes one config line ('_' replaced by ' ') that is passed to
// ReadStream. argv[2] is additionally passed to the ConfigFileReader
// constructor.
//
// Returns 0 on every path that returns. Inside the main try block only
// std::string and char* exceptions are caught; anything else escapes main.
//
// NOTE(review): this text is an extraction of a documentation listing and is
// missing source lines, among them the declarations of colReader, forest,
// est and stat and the first line of the call that exports the result
// collection; it cannot be compiled as shown.
51 int main(int argc, char* argv[])
52 {
53  MITK_INFO << "Starting MITK_Forest Mini-App";
    // NOTE(review): startTime is never used below.
54  double startTime = time(0);
55 
57  // Read Console Input Parameter
    // NOTE(review): argv[2] is used without checking argc; running with fewer
    // than two arguments is undefined behaviour. argv[2] is also handed to
    // ReadFile by the loop below unless a '+' argument precedes it.
59  ConfigFileReader allConfig(argv[2]);
60 
    // Split argv into config files (before the first '+') and inline config
    // lines (after it). NOTE(review): the loop starts at i = 0, so the program
    // path argv[0] is also handed to ReadFile; a std::exception from any
    // ReadFile call is only logged.
61  bool readFile = true;
62  std::stringstream ss;
63  for (int i = 0; i < argc; ++i )
64  {
65  MITK_INFO << "-----"<< argv[i]<<"------";
66  if (readFile)
67  {
68  if (argv[i][0] == '+')
69  {
70  readFile = false;
71  continue;
72  } else
73  {
74  try
75  {
76  allConfig.ReadFile(argv[i]);
77  }
78  catch (std::exception &e)
79  {
80  MITK_INFO << e.what();
81  }
82  }
83  }
84  else
85  {
    // Inline config lines use '_' instead of spaces (presumably so that one
    // config line fits into a single shell argument); one line per argument.
86  std::string input = argv[i];
87  std::replace(input.begin(), input.end(),'_',' ');
88  ss << input << std::endl;
89  }
90  }
91  allConfig.ReadStream(ss);
92 
93  try
94  {
96  // General
98  int currentRun = allConfig.IntValue("General","Run",0);
99  int doTraining = allConfig.IntValue("General","Do Training",1);
100  std::string forestPath = allConfig.Value("General","Forest Path");
101  std::string trainingCollectionPath = allConfig.Value("General","Patient Collection");
    // Training and test patients are read from the same collection file.
102  std::string testCollectionPath = trainingCollectionPath;
103  MITK_INFO << "Training collection: " << trainingCollectionPath;
104 
106  // Read Default Classification
    // The training/test patient lists are selected by the 'Run' index.
108  std::vector<std::string> trainPatients = allConfig.Vector("Training Group",currentRun);
109  std::vector<std::string> testPatients = allConfig.Vector("Test Group",currentRun);
110  std::vector<std::string> modalities = allConfig.Vector("Modalities",0);
111  std::string trainMask = allConfig.Value("Data","Training Mask");
112  std::string completeTrainMask = allConfig.Value("Data","Complete Training Mask");
113  std::string testMask = allConfig.Value("Data","Test Mask");
114  std::string resultMask = allConfig.Value("Data", "Result Mask");
    // NOTE(review): resultProb is never used below; note that the config key
    // actually looked up is spelled "Result Propability".
115  std::string resultProb = allConfig.Value("Data", "Result Propability");
116  std::string outputFolder = allConfig.Value("General","Output Folder");
117 
118  std::string writeDataFilePath = allConfig.Value("Forest","File to write data to");
119 
121  // Read Forest Parameter
123  int minimumSplitNodeSize = allConfig.IntValue("Forest", "Minimum split node size",1);
124  int numberOfTrees = allConfig.IntValue("Forest", "Number of Trees",255);
    // atof yields 0.0 for missing or unparsable values; anything <= 1e-7
    // falls back to 1.0.
125  double samplesPerTree = atof(allConfig.Value("Forest", "Samples per Tree").c_str());
126  if (samplesPerTree <= 0.0000001)
127  {
128  samplesPerTree = 1.0;
129  }
130  MITK_INFO << "Samples per Tree: " << samplesPerTree;
131  int sampleWithReplacement = allConfig.IntValue("Forest", "Sample with replacement",1);
    // Precision and Weight Lambda: values <= 1e-10 (including negative ones)
    // are clamped to 0.0.
132  double trainPrecision = atof(allConfig.Value("Forest", "Precision").c_str());
133  if (trainPrecision <= 0.0000000001)
134  {
135  trainPrecision = 0.0;
136  }
137  double weightLambda = atof(allConfig.Value("Forest", "Weight Lambda").c_str());
138  if (weightLambda <= 0.0000000001)
139  {
140  weightLambda = 0.0;
141  }
142  int maximumTreeDepth = allConfig.IntValue("Forest", "Maximum Tree Depth",10000);
143  int randomSplit = allConfig.IntValue("Forest","Use RandomSplit",0);
145  // Read Statistic Parameter
147  std::string statisticFilePath = allConfig.Value("Evaluation", "Statistic output file");
148  std::string statisticShortFilePath = allConfig.Value("Evaluation", "Statistic short output file");
149  std::string statisticShortFileLabel = allConfig.Value("Evaluation", "Index for short file");
150  std::string statisticGoldStandard = allConfig.Value("Evaluation", "Gold Standard Name","GTV");
    // NOTE(review): statisticWithHeader is ignored; Print() below is always
    // called with withHeader = true.
151  bool statisticWithHeader = allConfig.IntValue("Evaluation", "Write header in short file",0);
    // NOTE(review): labelGroupA/B are only referenced by the commented-out
    // cost-based statistic below.
152  std::vector<std::string> labelGroupA = allConfig.Vector("LabelsA",0);
153  std::vector<std::string> labelGroupB = allConfig.Vector("LabelsB",0);
155  // Read Special Parameter
    // NOTE(review): of these, only importanceWeightName is used (as an extra
    // item name for the collection reader); the others appear only in
    // commented-out code, and the weighting branch below is disabled.
157  bool useWeightedPoints = allConfig.IntValue("Forest", "Use point-based weighting",0);
158  bool writePointsToFile = allConfig.IntValue("Forest", "Write points to file",0);
159  int importanceWeightAlgorithm = allConfig.IntValue("Forest","Importance weight Algorithm",0);
160  std::string importanceWeightName = allConfig.Value("Forest","Importance weight name","");
161 
    // Timing log "<Statistic output file>.timing": "<label>;" followed by the
    // elapsed seconds of each phase (loading; training; saving + prediction;
    // export + statistics), ';'-separated, one line per run.
162  std::ofstream timingFile;
163  timingFile.open((statisticFilePath + ".timing").c_str(), std::ios::app);
164  timingFile << statisticShortFileLabel << ";";
165  std::time_t lastTimePoint;
166  time(&lastTimePoint);
167 
169  // Read Images
    // Item names the collection reader should load: all modalities, the masks,
    // the gold standard and the (possibly empty) importance-weight name.
171  std::vector<std::string> usedModalities;
172  for (int i = 0; i < modalities.size(); ++i)
173  {
174  usedModalities.push_back(modalities[i]);
175  }
176  usedModalities.push_back(trainMask);
177  usedModalities.push_back(completeTrainMask);
178  usedModalities.push_back(testMask);
179  usedModalities.push_back(statisticGoldStandard);
180  usedModalities.push_back(importanceWeightName);
181 
182  // vtkSmartPointer<mitk::CollectionReader> colReader = vtkSmartPointer<mitk::CollectionReader>::New();
    // NOTE(review): the actual declaration of colReader is missing from this listing.
184  colReader->AddDataElementIds(trainPatients);
185  colReader->SetDataItemNames(usedModalities);
186  //colReader->SetNames(usedModalities);
    // With Do Training = 0 the training collection is never loaded and stays
    // default-constructed (presumably a null pointer).
187  mitk::DataCollection::Pointer trainCollection;
188  if (doTraining)
189  {
190  trainCollection = colReader->LoadCollection(trainingCollectionPath);
191  }
192  colReader->ClearDataElementIds();
193  colReader->AddDataElementIds(testPatients);
194  mitk::DataCollection::Pointer testCollection = colReader->LoadCollection(testCollectionPath);
195 
196  std::time_t now;
197  time(&now);
198  double seconds = std::difftime(now, lastTimePoint);
199  timingFile << seconds << ";";
200  time(&lastTimePoint);
201 
    // Disabled: export of external point weights.
202  /*
203  if (writePointsToFile)
204  {
205  MITK_INFO << "Use external weights...";
206  mitk::ExternalWeighting weightReader;
207  weightReader.SetModalities(modalities);
208  weightReader.SetTestCollection(testCollection);
209  weightReader.SetTrainCollection(trainCollection);
210  weightReader.SetTestMask(testMask);
211  weightReader.SetTrainMask(trainMask);
212  weightReader.SetWeightsName("weights");
213  weightReader.SetCorrectionFactor(1.0);
214  weightReader.SetWeightFileName(writeDataFilePath);
215  weightReader.WriteData();
216  return 0;
217  }*/
218 
220  // If required do Training....
222  //mitk::DecisionForest forest;
223 
    // Configure the forest from the [Forest] settings read above.
    // NOTE(review): the declaration of forest is missing from this listing.
225  forest->SetSamplesPerTree(samplesPerTree);
226  forest->SetMinimumSplitNodeSize(minimumSplitNodeSize);
227  forest->SetTreeCount(numberOfTrees);
228  forest->UseSampleWithReplacement(sampleWithReplacement);
229  forest->SetPrecision(trainPrecision);
230  forest->SetMaximumTreeDepth(maximumTreeDepth);
231  forest->SetWeightLambda(weightLambda);
232 
    // randomSplit is read from the config but not applied yet.
233  // TODO forest.UseRandomSplit(randomSplit);
234 
235  if (doTraining)
236  {
237  // 0 = LR-Estimation
238  // 1 = KNN-Estimation
239  // 2 = Kliep
240  // 3 = Extern Image
241  // 4 = Zadrozny
242  // 5 = Spectral
243  // 6 = uLSIF
        // Features come from the modalities; labels are read from the training
        // mask image itself. Both use trainMask as the mask argument.
244  auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, modalities, trainMask);
245  auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, trainMask, trainMask);
246 
        // NOTE(review): point-based weighting is hard-disabled by 'if (false)',
        // so useWeightedPoints has no effect. 'est' below is declared on a line
        // missing from this listing.
247  //if (useWeightedPoints)
248  if (false)
249  {
250  MITK_INFO << "Activated Point-based weighting...";
251  //forest.UseWeightedPoints(true);
252  forest->UsePointWiseWeight(true);
253  //forest.SetWeightName("calculated_weight");
254  /*if (importanceWeightAlgorithm == 1)
255  {
256  mitk::KNNDensityEstimation est;
257  est.SetCollection(trainCollection);
258  est.SetTrainMask(trainMask);
259  est.SetTestMask(testMask);
260  est.SetModalities(modalities);
261  est.SetWeightName("calculated_weight");
262  est.Update();
263  } else if (importanceWeightAlgorithm == 2)
264  {
265  mitk::KliepDensityEstimation est;
266  est.SetCollection(trainCollection);
267  est.SetTrainMask(trainMask);
268  est.SetTestMask(testMask);
269  est.SetModalities(modalities);
270  est.SetWeightName("calculated_weight");
271  est.Update();
272  } else if (importanceWeightAlgorithm == 3)
273  {
274  forest.SetWeightName(importanceWeightName);
275  } else if (importanceWeightAlgorithm == 4)
276  {
277  mitk::ZadroznyWeighting est;
278  est.SetCollection(trainCollection);
279  est.SetTrainMask(trainMask);
280  est.SetTestMask(testMask);
281  est.SetModalities(modalities);
282  est.SetWeightName("calculated_weight");
283  est.Update();
284  } else if (importanceWeightAlgorithm == 5)
285  {
286  mitk::SpectralDensityEstimation est;
287  est.SetCollection(trainCollection);
288  est.SetTrainMask(trainMask);
289  est.SetTestMask(testMask);
290  est.SetModalities(modalities);
291  est.SetWeightName("calculated_weight");
292  est.Update();
293  } else if (importanceWeightAlgorithm == 6)
294  {
295  mitk::ULSIFDensityEstimation est;
296  est.SetCollection(trainCollection);
297  est.SetTrainMask(trainMask);
298  est.SetTestMask(testMask);
299  est.SetModalities(modalities);
300  est.SetWeightName("calculated_weight");
301  est.Update();
302  } else*/
303  {
305  est.SetCollection(trainCollection);
306  est.SetTrainMask(trainMask);
307  est.SetTestMask(testMask);
308  est.SetModalities(modalities);
309  est.SetWeightName("calculated_weight");
310  est.Update();
311  }
312  auto trainDataW = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, "calculated_weight", trainMask);
313  forest->SetPointWiseWeight(trainDataW);
314  forest->UsePointWiseWeight(true);
315  }
316  forest->Train(trainDataX, trainDataY);
317  // TODO forest.Save(forestPath);
318  } else
319  {
        // NOTE(review): loading is not implemented, so with Do Training = 0 the
        // forest is neither trained nor loaded, yet it is still saved and used
        // for prediction below.
320  // TODO forest.Load(forestPath);
321  }
322 
323  time(&now);
324  seconds = std::difftime(now, lastTimePoint);
325  timingFile << seconds << ";";
326  time(&lastTimePoint);
328  // If required do Save Forest....
330 
331  //writer.// (forest);
332  /*
333  auto w = forest->GetTreeWeights();
334  w(0,0) = 10;
335  forest->SetTreeWeights(w);*/
    // NOTE(review): the forest is always written to this hard-coded path;
    // forestPath and outputFolder are not used for it.
336  mitk::IOUtil::Save(forest,"d:/tmp/forest.forest");
337 
339  // If required do test
    // Predict inside testMask and store the predictions in the test collection
    // under the name resultMask.
341  auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,modalities, testMask);
342  auto testDataNewY = forest->Predict(testDataX);
343  //MITK_INFO << testDataNewY;
344 
345  mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, resultMask, testMask);
346  //forest.SetMaskName(testMask);
347  //forest.SetCollection(testCollection);
348  //forest.Test();
349  //forest.PrintTree(0);
350 
351  time(&now);
352  seconds = std::difftime(now, lastTimePoint);
353  timingFile << seconds << ";";
354  time(&lastTimePoint);
355 
357  // Cost-based analysis
359 
360  // TODO Reactivate
361  //MITK_INFO << "Calculate Cost-based Statistic ";
362  //mitk::CostingStatistic costStat;
363  //costStat.SetCollection(testCollection);
364  //costStat.SetCombinedA("combinedHealty");
365  //costStat.SetCombinedB("combinedTumor");
366  //costStat.SetCombinedLabel("combinedLabel");
367  //costStat.SetMaskName(testMask);
375  //costStat.SetProbabilitiesA(labelGroupA);
376  //costStat.SetProbabilitiesB(labelGroupB);
377 
378  //std::ofstream costStatisticFile;
379  //costStatisticFile.open((statisticFilePath + ".cost").c_str(), std::ios::app);
380  //std::ofstream lcostStatisticFile;
381 
382  //lcostStatisticFile.open((statisticFilePath + ".longcost").c_str(), std::ios::app);
383  //costStat.WriteStatistic(lcostStatisticFile,costStatisticFile,2.5,statisticShortFileLabel);
384  //costStatisticFile.close();
385 
386  //costStat.CalculateClass(50);
387 
389  // Save results to folder
    // Export the test collection to <Output Folder>/result_collection.xml with
    // an empty outputFilter. The first line of this call is missing from the
    // listing; presumably ExportCollectionToFolder(testCollection, ...), based
    // on the remaining arguments — confirm against the original source.
391  std::vector<std::string> outputFilter;
392  //outputFilter.push_back(resultMask);
393  //std::vector<std::string> propNames = forest.GetListOfProbabilityNames();
394  //outputFilter.insert(outputFilter.begin(), propNames.begin(), propNames.end());
396  outputFolder + "/result_collection.xml",
397  outputFilter);
398 
399  MITK_INFO << "Calculate Statistic...." ;
401  // Calculate and Print Statistic
    // Two-class evaluation of resultMask against the gold standard inside
    // testMask, appended to the long and the short statistic file.
    // NOTE(review): the declaration of stat is missing from this listing.
403  std::ofstream statisticFile;
404  statisticFile.open(statisticFilePath.c_str(), std::ios::app);
405  std::ofstream sstatisticFile;
406  sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app);
407 
409  stat.SetCollection(testCollection);
410  stat.SetClassCount(2);
411  stat.SetGoldName(statisticGoldStandard);
412  stat.SetTestName(resultMask);
413  stat.SetMaskName(testMask);
414  stat.Update();
415  //stat.Print(statisticFile,sstatisticFile,statisticWithHeader, statisticShortFileLabel);
416  stat.Print(statisticFile,sstatisticFile,true, statisticShortFileLabel);
417  statisticFile.close();
    // sstatisticFile is closed by its destructor at the end of the try block.
418 
419  time(&now);
420  seconds = std::difftime(now, lastTimePoint);
421  timingFile << seconds << std::endl;
422  time(&lastTimePoint);
423  timingFile.close();
424  }
    // NOTE(review): exceptions are caught by value. A thrown string literal has
    // type const char* and does not match 'catch (char* s)', and any other
    // exception type (e.g. std::exception) escapes main.
425  catch (std::string s)
426  {
427  MITK_INFO << s;
428  return 0;
429  }
430  catch (char* s)
431  {
432  MITK_INFO << s;
433  }
434 
435  return 0;
436 }
437 
438 #endif
std::vector< std::string > Vector(std::string const &section, unsigned int index) const
void SetWeightName(std::string name)
std::string Value(std::string const &section, std::string const &entry) const
static void Save(const mitk::BaseData *data, const std::string &path)
Save a mitk::BaseData instance.
Definition: mitkIOUtil.cpp:824
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
#define MITK_INFO
Definition: mitkLogMacros.h:22
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SetTestName(std::string name)
void SetTrainMask(std::string name)
void SetMaskName(std::string name)
void SetTestMask(std::string name)
void SetGoldName(std::string name)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void SetModalities(std::vector< std::string > modalities)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void ReadStream(std::istream &stream)
int main(int argc, char *argv[])
void SetCollection(DataCollection::Pointer data)
int IntValue(const std::string &section, const std::string &entry) const
void AddDataElementIds(std::vector< std::string > dataElemetIds)
static const char * replace[]
This is a dictionary to replace long names of classes, modules, etc. to shorter versions in the conso...
void ReadFile(std::string const &filePath)
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.