Medical Imaging Interaction Toolkit  2018.4.99-b20efe7f
Medical Imaging Interaction Toolkit
CLPurfVoxelClassification.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 #ifndef mitkForest_cpp
13 #define mitkForest_cpp
14 
15 #include "time.h"
16 #include <sstream>
17 
18 #include <mitkConfigFileReader.h>
19 #include <mitkDataCollection.h>
20 #include <mitkCollectionReader.h>
21 #include <mitkCollectionWriter.h>
23 #include <mitkCostingStatistic.h>
24 #include <vtkSmartPointer.h>
25 #include <mitkIOUtil.h>
26 
28 #include <mitkRandomForestIO.h>
29 
30 // ----------------------- Forest Handling ----------------------
31 //#include <mitkDecisionForest.h>
33 //#include <mitkThresholdSplit.h>
34 //#include <mitkImpurityLoss.h>
35 //#include <mitkLinearSplitting.h>
36 //#include <mitkVigraConverter.h>
37 // ----------------------- Point weighting ----------------------
38 //#include <mitkForestWeighting.h>
39 //#include <mitkKliepDensityEstimation.h>
40 //#include <mitkExternalWeighting.h>
42 //#include <mitkKNNDensityEstimation.h>
43 //#include <mitkZadroznyWeighting.h>
44 //#include <mitkSpectralDensityEstimation.h>
45 //#include <mitkULSIFDensityEstimation.h>
46 
// Command-line entry point of this voxel-classification mini-app.
//
// Workflow, as far as it is visible in this listing:
//   1. Read the configuration: argv[1] via the ConfigFileReader constructor,
//      then further config files and "+"-separated inline entries (see below).
//   2. Load the training collection (only if "Do Training" is set) and the
//      test collection.
//   3. Configure the forest and, if training is enabled, train it
//      (optionally with per-point weights).
//   4. Predict labels and probabilities for the test collection, export the
//      collection, and append statistics and per-phase timings to files.
//
// @param argc  number of command-line arguments
// @param argv  argv[1] is the primary config file; following arguments are
//              more config files up to one starting with '+', after which the
//              remaining arguments are inline config lines ('_' becomes ' ').
// @return 0 on normal completion, and also after a caught std::string or
//         char* exception (see the handlers at the end).
47 int main(int argc, char* argv[])
48 {
49  MITK_INFO << "Starting MITK_Forest Mini-App";
    // NOTE(review): startTime is never used below.
50  double startTime = time(0);
51 
53  // Read Console Input Parameter
    // NOTE(review): argv[1] is dereferenced without checking argc >= 2; with
    // no arguments it is argv[argc], which is a null pointer.
55  ConfigFileReader allConfig(argv[1]);
56 
    // Every argument, starting at argv[0] (the program path), is passed to
    // ReadFile until an argument starting with '+' is reached; std::exception
    // errors from ReadFile are only logged. The '+' argument itself is skipped.
    // Each later argument is buffered as one line with '_' replaced by ' ',
    // and the buffer is parsed with ReadStream after the loop.
    // NOTE(review): argv[1] was already passed to the constructor above, so
    // it is handed to the reader a second time here — confirm this is harmless.
57  bool readFile = true;
58  std::stringstream ss;
59  for (int i = 0; i < argc; ++i )
60  {
61  MITK_INFO << "-----"<< argv[i]<<"------";
62  if (readFile)
63  {
64  if (argv[i][0] == '+')
65  {
66  readFile = false;
67  continue;
68  } else
69  {
70  try
71  {
72  allConfig.ReadFile(argv[i]);
73  }
74  catch (std::exception &e)
75  {
76  MITK_INFO << e.what();
77  }
78  }
79  }
80  else
81  {
82  std::string input = argv[i];
83  std::replace(input.begin(), input.end(),'_',' ');
84  ss << input << std::endl;
85  }
86  }
87  allConfig.ReadStream(ss);
88 
    // Everything below runs inside one try block; only std::string and char*
    // exceptions are handled at its end.
89  try
90  {
92  // General
    // currentRun selects the index into the "Training Group" / "Test Group"
    // vectors; training and test data come from the same collection file.
94  int currentRun = allConfig.IntValue("General","Run",0);
95  int doTraining = allConfig.IntValue("General","Do Training",1);
96  std::string forestPath = allConfig.Value("General","Forest Path");
97  std::string trainingCollectionPath = allConfig.Value("General","Patient Collection");
98  std::string testCollectionPath = trainingCollectionPath;
99 
101  // Read Default Classification
103  std::vector<std::string> trainPatients = allConfig.Vector("Training Group",currentRun);
104  std::vector<std::string> testPatients = allConfig.Vector("Test Group",currentRun);
105  std::vector<std::string> modalities = allConfig.Vector("Modalities",0);
106  std::string trainMask = allConfig.Value("Data","Training Mask");
107  std::string completeTrainMask = allConfig.Value("Data","Complete Training Mask");
108  std::string testMask = allConfig.Value("Data","Test Mask");
109  std::string resultMask = allConfig.Value("Data", "Result Mask");
110  std::string resultProb = allConfig.Value("Data", "Result Propability");
111  std::string outputFolder = allConfig.Value("General","Output Folder");
112 
113  std::string writeDataFilePath = allConfig.Value("Forest","File to write data to");
114 
116  // Read Forest Parameter
118  int minimumSplitNodeSize = allConfig.IntValue("Forest", "Minimum split node size",1);
119  int numberOfTrees = allConfig.IntValue("Forest", "Number of Trees",255);
    // "Samples per Tree" values <= 1e-7 (including the 0.0 that atof returns
    // for non-numeric text) fall back to 1.0.
120  double samplesPerTree = atof(allConfig.Value("Forest", "Samples per Tree").c_str());
121  if (samplesPerTree <= 0.0000001)
122  {
123  samplesPerTree = 1.0;
124  }
125  MITK_INFO << "Samples per Tree: " << samplesPerTree;
126  int sampleWithReplacement = allConfig.IntValue("Forest", "Sample with replacement",1);
    // Precision and Weight Lambda values <= 1e-10 (including negatives) are
    // clamped to 0.0.
127  double trainPrecision = atof(allConfig.Value("Forest", "Precision").c_str());
128  if (trainPrecision <= 0.0000000001)
129  {
130  trainPrecision = 0.0;
131  }
132  double weightLambda = atof(allConfig.Value("Forest", "Weight Lambda").c_str());
133  if (weightLambda <= 0.0000000001)
134  {
135  weightLambda = 0.0;
136  }
137  int maximumTreeDepth = allConfig.IntValue("Forest", "Maximum Tree Depth",10000);
    // NOTE(review): randomSplit is read but never applied (see the TODO below).
138  int randomSplit = allConfig.IntValue("Forest","Use RandomSplit",0);
140  // Read Statistic Parameter
142  std::string statisticFilePath = allConfig.Value("Evaluation", "Statistic output file");
143  std::string statisticShortFilePath = allConfig.Value("Evaluation", "Statistic short output file");
144  std::string statisticShortFileLabel = allConfig.Value("Evaluation", "Index for short file");
145  std::string statisticGoldStandard = allConfig.Value("Evaluation", "Gold Standard Name","GTV");
146  bool statisticWithHeader = allConfig.IntValue("Evaluation", "Write header in short file",0);
    // labelGroupA/B are only referenced by the commented-out cost statistic.
147  std::vector<std::string> labelGroupA = allConfig.Vector("LabelsA",0);
148  std::vector<std::string> labelGroupB = allConfig.Vector("LabelsB",0);
150  // Read Special Parameter
    // writePointsToFile and importanceWeightAlgorithm are only referenced in
    // commented-out code; importanceWeightName is added to the loaded items.
152  bool useWeightedPoints = allConfig.IntValue("Forest", "Use point-based weighting",0);
153  bool writePointsToFile = allConfig.IntValue("Forest", "Write points to file",0);
154  int importanceWeightAlgorithm = allConfig.IntValue("Forest","Importance weight Algorithm",0);
155  std::string importanceWeightName = allConfig.Value("Forest","Importance weight name","");
156 
    // One line per run is appended to "<Statistic output file>.timing":
    // the label, then the elapsed seconds (std::time resolution) of loading,
    // training, prediction and statistics, separated by ';'.
157  std::ofstream timingFile;
158  timingFile.open((statisticFilePath + ".timing").c_str(), std::ios::app);
159  timingFile << statisticShortFileLabel << ";";
160  std::time_t lastTimePoint;
161  time(&lastTimePoint);
162 
164  // Read Images
    // Names of all data items to load: the feature modalities plus the mask,
    // gold-standard and importance-weight items.
166  std::vector<std::string> usedModalities;
167  for (int i = 0; i < modalities.size(); ++i)
168  {
169  usedModalities.push_back(modalities[i]);
170  }
171  usedModalities.push_back(trainMask);
172  usedModalities.push_back(completeTrainMask);
173  usedModalities.push_back(testMask);
174  usedModalities.push_back(statisticGoldStandard);
175  usedModalities.push_back(importanceWeightName);
176 
177  // vtkSmartPointer<mitk::CollectionReader> colReader = vtkSmartPointer<mitk::CollectionReader>::New();
    // NOTE(review): the active declaration of colReader is not visible in this
    // listing. The training collection is loaded only when doTraining is set;
    // the test collection is always loaded.
179  colReader->AddDataElementIds(trainPatients);
180  colReader->SetDataItemNames(usedModalities);
181  //colReader->SetNames(usedModalities);
182  mitk::DataCollection::Pointer trainCollection;
183  if (doTraining)
184  {
185  trainCollection = colReader->LoadCollection(trainingCollectionPath);
186  }
187  colReader->ClearDataElementIds();
188  colReader->AddDataElementIds(testPatients);
189  mitk::DataCollection::Pointer testCollection = colReader->LoadCollection(testCollectionPath);
190 
191  std::time_t now;
192  time(&now);
193  double seconds = std::difftime(now, lastTimePoint);
194  timingFile << seconds << ";";
195  time(&lastTimePoint);
196 
197  /*
198  if (writePointsToFile)
199  {
200  MITK_INFO << "Use external weights...";
201  mitk::ExternalWeighting weightReader;
202  weightReader.SetModalities(modalities);
203  weightReader.SetTestCollection(testCollection);
204  weightReader.SetTrainCollection(trainCollection);
205  weightReader.SetTestMask(testMask);
206  weightReader.SetTrainMask(trainMask);
207  weightReader.SetWeightsName("weights");
208  weightReader.SetCorrectionFactor(1.0);
209  weightReader.SetWeightFileName(writeDataFilePath);
210  weightReader.WriteData();
211  return 0;
212  }*/
213 
215  // If required do Training....
217  //mitk::DecisionForest forest;
218 
    // NOTE(review): the declaration of `forest` is not visible in this listing.
220  forest->SetSamplesPerTree(samplesPerTree);
221  forest->SetMinimumSplitNodeSize(minimumSplitNodeSize);
222  forest->SetTreeCount(numberOfTrees);
223  forest->UseSampleWithReplacement(sampleWithReplacement);
224  forest->SetPrecision(trainPrecision);
225  forest->SetMaximumTreeDepth(maximumTreeDepth);
226  forest->SetWeightLambda(weightLambda);
227 
228  // TODO forest.UseRandomSplit(randomSplit);
229 
230  if (doTraining)
231  {
232  // 0 = LR-Estimation
233  // 1 = KNN-Estimation
234  // 2 = Kliep
235  // 3 = Extern Image
236  // 4 = Zadrozny
237  // 5 = Spectral
238  // 6 = uLSIF
    // Features: the `modalities` items inside trainMask. Labels: the values of
    // the trainMask item itself, inside trainMask.
239  auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, modalities, trainMask);
240  auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, trainMask, trainMask);
241 
242  if (useWeightedPoints)
243  //if (false)
244  {
245  MITK_INFO << "Activated Point-based weighting...";
246  //forest.UseWeightedPoints(true);
247  forest->UsePointWiseWeight(true);
248  //forest.SetWeightName("calculated_weight");
249  /*if (importanceWeightAlgorithm == 1)
250  {
251  mitk::KNNDensityEstimation est;
252  est.SetCollection(trainCollection);
253  est.SetTrainMask(trainMask);
254  est.SetTestMask(testMask);
255  est.SetModalities(modalities);
256  est.SetWeightName("calculated_weight");
257  est.Update();
258  } else if (importanceWeightAlgorithm == 2)
259  {
260  mitk::KliepDensityEstimation est;
261  est.SetCollection(trainCollection);
262  est.SetTrainMask(trainMask);
263  est.SetTestMask(testMask);
264  est.SetModalities(modalities);
265  est.SetWeightName("calculated_weight");
266  est.Update();
267  } else if (importanceWeightAlgorithm == 3)
268  {
269  forest.SetWeightName(importanceWeightName);
270  } else if (importanceWeightAlgorithm == 4)
271  {
272  mitk::ZadroznyWeighting est;
273  est.SetCollection(trainCollection);
274  est.SetTrainMask(trainMask);
275  est.SetTestMask(testMask);
276  est.SetModalities(modalities);
277  est.SetWeightName("calculated_weight");
278  est.Update();
279  } else if (importanceWeightAlgorithm == 5)
280  {
281  mitk::SpectralDensityEstimation est;
282  est.SetCollection(trainCollection);
283  est.SetTrainMask(trainMask);
284  est.SetTestMask(testMask);
285  est.SetModalities(modalities);
286  est.SetWeightName("calculated_weight");
287  est.Update();
288  } else if (importanceWeightAlgorithm == 6)
289  {
290  mitk::ULSIFDensityEstimation est;
291  est.SetCollection(trainCollection);
292  est.SetTrainMask(trainMask);
293  est.SetTestMask(testMask);
294  est.SetModalities(modalities);
295  est.SetWeightName("calculated_weight");
296  est.Update();
297  } else*/
    // Only this branch is active, whatever importanceWeightAlgorithm says.
    // NOTE(review): the declaration of `est` is not visible in this listing.
298  {
300  est.SetCollection(trainCollection);
301  est.SetTrainMask(trainMask);
302  est.SetTestMask(testMask);
303  est.SetModalities(modalities);
304  est.SetWeightName("calculated_weight");
305  est.Update();
306  }
    // Per-point weights are read back from the "calculated_weight" item
    // (presumably written into trainCollection by est.Update() — confirm).
307  auto trainDataW = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, "calculated_weight", trainMask);
308  forest->SetPointWiseWeight(trainDataW);
    // NOTE(review): redundant; UsePointWiseWeight(true) was already called above.
309  forest->UsePointWiseWeight(true);
310  }
311  forest->Train(trainDataX, trainDataY);
312  // TODO forest.Save(forestPath);
313  } else
314  {
    // NOTE(review): nothing is loaded here yet, but `forest` is still used for
    // prediction below when doTraining is 0.
315  // TODO forest.Load(forestPath);
316  }
317 
318  time(&now);
319  seconds = std::difftime(now, lastTimePoint);
320  timingFile << seconds << ";";
321  time(&lastTimePoint);
323  // If required do Save Forest....
325 
326  //writer.// (forest);
327  /*
328  auto w = forest->GetTreeWeights();
329  w(0,0) = 10;
330  forest->SetTreeWeights(w);*/
331  //mitk::IOUtil::Save(forest,"d:/tmp/forest.forest");
332 
334  // If required do test
    // Predict labels and per-point probabilities for the voxels inside
    // testMask, then write them into testCollection: labels as resultMask,
    // probabilities as "prob-1" / "prob-2".
    // NOTE(review): the two probability names are hard-coded; presumably one
    // name per probability column — confirm the classifier yields two classes.
336  auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,modalities, testMask);
337  auto testDataNewY = forest->Predict(testDataX);
338  auto testDataNewProb = forest->GetPointWiseProbabilities();
339  //MITK_INFO << testDataNewY;
340 
341  std::vector<std::string> names;
342  names.push_back("prob-1");
343  names.push_back("prob-2");
344 
345  mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, resultMask, testMask);
346  mitk::DCUtilities::MatrixToDC3d(testDataNewProb, testCollection, names, testMask);
347  //forest.SetMaskName(testMask);
348  //forest.SetCollection(testCollection);
349  //forest.Test();
350  //forest.PrintTree(0);
351 
352  time(&now);
353  seconds = std::difftime(now, lastTimePoint);
354  timingFile << seconds << ";";
355  time(&lastTimePoint);
356 
358  // Cost-based analysis
360 
361  // TODO Reactivate
362  //MITK_INFO << "Calculate Cost-based Statistic ";
363  //mitk::CostingStatistic costStat;
364  //costStat.SetCollection(testCollection);
365  //costStat.SetCombinedA("combinedHealty");
366  //costStat.SetCombinedB("combinedTumor");
367  //costStat.SetCombinedLabel("combinedLabel");
368  //costStat.SetMaskName(testMask);
376  //costStat.SetProbabilitiesA(labelGroupA);
377  //costStat.SetProbabilitiesB(labelGroupB);
378 
379  //std::ofstream costStatisticFile;
380  //costStatisticFile.open((statisticFilePath + ".cost").c_str(), std::ios::app);
381  //std::ofstream lcostStatisticFile;
382 
383  //lcostStatisticFile.open((statisticFilePath + ".longcost").c_str(), std::ios::app);
384  //costStat.WriteStatistic(lcostStatisticFile,costStatisticFile,2.5,statisticShortFileLabel);
385  //costStatisticFile.close();
386 
387  //costStat.CalculateClass(50);
388 
390  // Save results to folder
392  std::vector<std::string> outputFilter;
393  //outputFilter.push_back(resultMask);
394  //std::vector<std::string> propNames = forest.GetListOfProbabilityNames();
395  //outputFilter.insert(outputFilter.begin(), propNames.begin(), propNames.end());
    // NOTE(review): the first line of the following call is missing from this
    // listing; its remaining arguments are the target
    // "<outputFolder>/result_collection.xml" and the (empty) outputFilter.
397  outputFolder + "/result_collection.xml",
398  outputFilter);
399 
400  MITK_INFO << "Calculate Statistic...." ;
402  // Calculate and Print Statistic
404  std::ofstream statisticFile;
405  statisticFile.open(statisticFilePath.c_str(), std::ios::app);
406  std::ofstream sstatisticFile;
407  sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app);
408 
    // NOTE(review): the declarations of `stat` and `mapper` are not visible in
    // this listing. `mapper` is released with a plain delete below, which is
    // skipped (leaking it) if Update() or Print() throws.
410  stat.SetCollection(testCollection);
411  stat.SetClassCount(2);
412  stat.SetGoldName(statisticGoldStandard);
413  stat.SetTestName(resultMask);
414  stat.SetMaskName(testMask);
417  stat.SetTestValueToIndexMapper(mapper);
418  stat.Update();
    // NOTE(review): withHeader is hard-coded to true, so the configured
    // statisticWithHeader value is ignored.
419  //stat.Print(statisticFile,sstatisticFile,statisticWithHeader, statisticShortFileLabel);
420  stat.Print(statisticFile,sstatisticFile,true, statisticShortFileLabel);
    // sstatisticFile has no explicit close(); its destructor closes it when
    // the try block ends.
421  statisticFile.close();
422  delete mapper;
423 
424  time(&now);
425  seconds = std::difftime(now, lastTimePoint);
426  timingFile << seconds << std::endl;
427  time(&lastTimePoint);
428  timingFile.close();
429  }
    // Only std::string and char* exceptions are handled; both are logged and
    // the program still returns 0. Other exception types propagate out of main.
    // NOTE(review): a thrown string literal has type const char*, which does
    // not match the char* handler.
430  catch (std::string s)
431  {
432  MITK_INFO << s;
433  return 0;
434  }
435  catch (char* s)
436  {
437  MITK_INFO << s;
438  }
439 
440  return 0;
441 }
442 
443 #endif
void SetWeightName(std::string name)
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from an XML resource.
#define MITK_INFO
Definition: mitkLogMacros.h:18
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SetTestName(std::string name)
void SetTestValueToIndexMapper(const ValueToIndexMapper *mapper)
void SetGroundTruthValueToIndexMapper(const ValueToIndexMapper *mapper)
int IntValue(const std::string &section, const std::string &entry) const
void SetTrainMask(std::string name)
void SetMaskName(std::string name)
void SetTestMask(std::string name)
void SetGoldName(std::string name)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void SetModalities(std::vector< std::string > modalities)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void ReadStream(std::istream &stream)
void SetCollection(DataCollection::Pointer data)
void AddDataElementIds(std::vector< std::string > dataElemetIds)
std::vector< std::string > Vector(std::string const &section, unsigned int index) const
static const char * replace[]
This is a dictionary to replace long names of classes, modules, etc. to shorter versions in the conso...
int main(int argc, char *argv[])
void ReadFile(std::string const &filePath)
std::string Value(std::string const &section, std::string const &entry) const