Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
CLVoxelClassification.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 #ifndef mitkForest_cpp
13 #define mitkForest_cpp
14 
15 #include "time.h"
16 #include <sstream>
17 
18 #include <mitkConfigFileReader.h>
19 #include <mitkDataCollection.h>
20 #include <mitkCollectionReader.h>
21 #include <mitkCollectionWriter.h>
23 #include <mitkCostingStatistic.h>
24 #include <vtkSmartPointer.h>
25 #include <mitkIOUtil.h>
26 
28 #include <mitkRandomForestIO.h>
29 
30 // ----------------------- Forest Handling ----------------------
31 //#include <mitkDecisionForest.h>
33 //#include <mitkThresholdSplit.h>
34 //#include <mitkImpurityLoss.h>
35 //#include <mitkLinearSplitting.h>
36 //#include <mitkVigraConverter.h>
37 // ----------------------- Point weighting ----------------------
38 //#include <mitkForestWeighting.h>
39 //#include <mitkKliepDensityEstimation.h>
40 //#include <mitkExternalWeighting.h>
42 //#include <mitkKNNDensityEstimation.h>
43 //#include <mitkZadroznyWeighting.h>
44 //#include <mitkSpectralDensityEstimation.h>
45 //#include <mitkULSIFDensityEstimation.h>
46 
// Entry point of the voxel-classification mini-app.
// NOTE(review): this is a Doxygen-rendered listing; several source lines are
// missing from this view (the declarations of colReader, forest, est, stat and
// mapper, and the first line of the result-export call). Confirm against the
// actual source file before editing the code.
//
// Visible flow: read the configuration from argv, load the training and test
// DataCollections, train the forest (or load it from "Forest Path" when
// "Do Training" is 0), predict the test data, export the result collection,
// compute statistics, and append per-phase durations to
// "<Statistic output file>.timing".
//
// Command line: arguments are read as config files until one starting with
// '+' is seen (that argument itself is skipped); each later argument has '_'
// replaced by ' ' and is passed to ConfigFileReader::ReadStream as one line.
47 int main(int argc, char* argv[])
48 {
49  MITK_INFO << "Starting MITK_Forest Mini-App";
50 
52  // Read Console Input Parameter
// NOTE(review): argv[1] is used without checking argc >= 2; starting the
// app without arguments is undefined behavior.
54  ConfigFileReader allConfig(argv[1]);
55 
56  bool readFile = true;
57  std::stringstream ss;
// NOTE(review): the loop starts at i = 0, so argv[0] (the program path) is
// also handed to ReadFile, and argv[1] is presumably read a second time after
// the constructor above. Read failures derived from std::exception are only
// logged, not treated as fatal.
58  for (int i = 0; i < argc; ++i )
59  {
60  MITK_INFO << "-----"<< argv[i]<<"------";
61  if (readFile)
62  {
63  if (argv[i][0] == '+')
64  {
65  readFile = false;
66  continue;
67  } else
68  {
69  try
70  {
71  allConfig.ReadFile(argv[i]);
72  }
73  catch ( const std::exception &e )
74  {
75  MITK_INFO << e.what();
76  }
77  }
78  }
79  else
80  {
// Inline config entry: '_' stands in for ' ' so that an entry can be passed
// as a single shell argument.
81  std::string input = argv[i];
82  std::replace(input.begin(), input.end(),'_',' ');
83  ss << input << std::endl;
84  }
85  }
// Parse the inline entries collected after the '+' marker.
86  allConfig.ReadStream(ss);
87 
// Everything below runs in one try block; only std::string and char*
// exceptions are handled at the end of main.
88  try
89  {
91  // General
93  int currentRun = allConfig.IntValue("General","Run",0);
94  int doTraining = allConfig.IntValue("General","Do Training",1);
95  std::string forestPath = allConfig.Value("General","Forest Path");
96  std::string trainingCollectionPath = allConfig.Value("General","Patient Collection");
97  std::string testCollectionPath = allConfig.Value("General", "Patient Test Collection", trainingCollectionPath);
98 
100  // Read Default Classification
// Train/test patient ids come from entry index `currentRun` of the
// "Training Group" / "Test Group" sections.
102  std::vector<std::string> trainPatients = allConfig.Vector("Training Group",currentRun);
103  std::vector<std::string> testPatients = allConfig.Vector("Test Group",currentRun);
104  std::vector<std::string> modalities = allConfig.Vector("Modalities", 0);
105  std::vector<std::string> outputFilter = allConfig.Vector("Output Filter", 0);
106  std::string trainMask = allConfig.Value("Data","Training Mask");
107  std::string completeTrainMask = allConfig.Value("Data","Complete Training Mask");
108  std::string testMask = allConfig.Value("Data","Test Mask");
109  std::string resultMask = allConfig.Value("Data", "Result Mask");
// NOTE(review): the key is spelled "Result Propability"; config files must use
// exactly that spelling.
110  std::string resultProb = allConfig.Value("Data", "Result Propability");
111  std::string outputFolder = allConfig.Value("General","Output Folder");
112 
113  std::string writeDataFilePath = allConfig.Value("Forest","File to write data to");
114 
116  // Read Data Forest Parameter
118  int testSingleDataset = allConfig.IntValue("Data", "Test Single Dataset",0);
119  std::string singleDatasetName = allConfig.Value("Data", "Single Dataset Name", "none");
120  int trainSingleDataset = allConfig.IntValue("Data", "Train Single Dataset", 0);
121  std::string singleTrainDatasetName = allConfig.Value("Data", "Train Single Dataset Name", "none");
122 
124  // Read Forest Parameter
126  int minimumSplitNodeSize = allConfig.IntValue("Forest", "Minimum split node size",1);
127  int numberOfTrees = allConfig.IntValue("Forest", "Number of Trees",255);
// atof yields 0.0 for a missing/unparsable value, so that case (and any value
// <= 1e-7) falls back to 1.0.
128  double samplesPerTree = atof(allConfig.Value("Forest", "Samples per Tree").c_str());
129  if (samplesPerTree <= 0.0000001)
130  {
131  samplesPerTree = 1.0;
132  }
133  MITK_INFO << "Samples per Tree: " << samplesPerTree;
134  int sampleWithReplacement = allConfig.IntValue("Forest", "Sample with replacement",1);
// Values <= 1e-10 (including negatives) are clamped to 0.0.
135  double trainPrecision = atof(allConfig.Value("Forest", "Precision").c_str());
136  if (trainPrecision <= 0.0000000001)
137  {
138  trainPrecision = 0.0;
139  }
// Values <= 1e-10 (including negatives) are clamped to 0.0.
140  double weightLambda = atof(allConfig.Value("Forest", "Weight Lambda").c_str());
141  if (weightLambda <= 0.0000000001)
142  {
143  weightLambda = 0.0;
144  }
145  int maximumTreeDepth = allConfig.IntValue("Forest", "Maximum Tree Depth",10000);
146  // TODO int randomSplit = allConfig.IntValue("Forest","Use RandomSplit",0);
148  // Read Statistic Parameter
150  std::string statisticFilePath = allConfig.Value("Evaluation", "Statistic output file");
151  std::string statisticShortFilePath = allConfig.Value("Evaluation", "Statistic short output file");
152  std::string statisticShortFileLabel = allConfig.Value("Evaluation", "Index for short file");
153  std::string statisticGoldStandard = allConfig.Value("Evaluation", "Gold Standard Name","GTV");
154  // TODO bool statisticWithHeader = allConfig.IntValue("Evaluation", "Write header in short file",0);
155  std::vector<std::string> labelGroupA = allConfig.Vector("LabelsA",0);
156  std::vector<std::string> labelGroupB = allConfig.Vector("LabelsB",0);
158  // Read Special Parameter
160  bool useWeightedPoints = allConfig.IntValue("Forest", "Use point-based weighting",0);
161  // TODO bool writePointsToFile = allConfig.IntValue("Forest", "Write points to file",0);
162  // TODO int importanceWeightAlgorithm = allConfig.IntValue("Forest","Importance weight Algorithm",0);
163  std::string importanceWeightName = allConfig.Value("Forest","Importance weight name","");
164 
// Timing file (appended): one line per run, "<label>;" followed by the
// durations of loading, training, prediction and statistics, ';'-separated.
// Durations come from std::time/std::difftime (typically whole seconds).
165  std::ofstream timingFile;
166  timingFile.open((statisticFilePath + ".timing").c_str(), std::ios::app);
167  timingFile << statisticShortFileLabel << ";";
168  std::time_t lastTimePoint;
169  time(&lastTimePoint);
170 
172  // Read Images
// Data items to load per patient: every modality plus the masks, the gold
// standard and the importance-weight image.
174  std::vector<std::string> usedModalities;
175  for (std::size_t i = 0; i < modalities.size(); ++i)
176  {
177  usedModalities.push_back(modalities[i]);
178  }
179  usedModalities.push_back(trainMask);
180  usedModalities.push_back(completeTrainMask);
181  usedModalities.push_back(testMask);
182  usedModalities.push_back(statisticGoldStandard);
183  usedModalities.push_back(importanceWeightName);
184 
185  if (trainSingleDataset > 0)
186  {
187  trainPatients.clear();
188  trainPatients.push_back(singleTrainDatasetName);
189  }
190 
192  colReader->AddDataElementIds(trainPatients);
193  colReader->SetDataItemNames(usedModalities);
194  //colReader->SetNames(usedModalities);
// The training collection is only loaded when training is requested;
// otherwise trainCollection is left unset.
195  mitk::DataCollection::Pointer trainCollection;
196  if (doTraining)
197  {
198  trainCollection = colReader->LoadCollection(trainingCollectionPath);
199  }
200 
201  if (testSingleDataset > 0)
202  {
203  testPatients.clear();
204  testPatients.push_back(singleDatasetName);
205  }
// Reuse the same reader (and item names) for the test patients.
206  colReader->ClearDataElementIds();
207  colReader->AddDataElementIds(testPatients);
208  mitk::DataCollection::Pointer testCollection = colReader->LoadCollection(testCollectionPath);
209 
210  std::time_t now;
211  time(&now);
212  double seconds = std::difftime(now, lastTimePoint);
213  timingFile << seconds << ";";
214  time(&lastTimePoint);
215 
216  /*
217  if (writePointsToFile)
218  {
219  MITK_INFO << "Use external weights...";
220  mitk::ExternalWeighting weightReader;
221  weightReader.SetModalities(modalities);
222  weightReader.SetTestCollection(testCollection);
223  weightReader.SetTrainCollection(trainCollection);
224  weightReader.SetTestMask(testMask);
225  weightReader.SetTrainMask(trainMask);
226  weightReader.SetWeightsName("weights");
227  weightReader.SetCorrectionFactor(1.0);
228  weightReader.SetWeightFileName(writeDataFilePath);
229  weightReader.WriteData();
230  return 0;
231  }*/
232 
234  // If required do Training....
236  //mitk::DecisionForest forest;
237 
// Apply the [Forest] parameters to the classifier.
239  forest->SetSamplesPerTree(samplesPerTree);
240  forest->SetMinimumSplitNodeSize(minimumSplitNodeSize);
241  forest->SetTreeCount(numberOfTrees);
242  forest->UseSampleWithReplacement(sampleWithReplacement);
243  forest->SetPrecision(trainPrecision);
244  forest->SetMaximumTreeDepth(maximumTreeDepth);
245  forest->SetWeightLambda(weightLambda);
246 
247  // TODO forest.UseRandomSplit(randomSplit);
248 
249  if (doTraining)
250  {
251  // 0 = LR-Estimation
252  // 1 = KNN-Estimation
253  // 2 = Kliep
254  // 3 = External Image
255  // 4 = Zadrozny
256  // 5 = Spectral
257  // 6 = uLSIF
// Features: one column per modality, sampled under trainMask. Labels: the
// values of the trainMask image itself.
258  auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, modalities, trainMask);
259  auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, trainMask, trainMask);
260 
261  if (useWeightedPoints)
262  //if (false)
263  {
264  MITK_INFO << "Activated Point-based weighting...";
265  //forest.UseWeightedPoints(true);
// NOTE(review): UsePointWiseWeight(true) is called again below; one of the
// two calls is redundant.
266  forest->UsePointWiseWeight(true);
267  //forest.SetWeightName("calculated_weight");
// The algorithm selection below is commented out; only the final block using
// `est` (declared on a line missing from this listing) runs. It stores the
// weights as the "calculated_weight" data item, which is read back afterwards.
268  /*if (importanceWeightAlgorithm == 1)
269  {
270  mitk::KNNDensityEstimation est;
271  est.SetCollection(trainCollection);
272  est.SetTrainMask(trainMask);
273  est.SetTestMask(testMask);
274  est.SetModalities(modalities);
275  est.SetWeightName("calculated_weight");
276  est.Update();
277  } else if (importanceWeightAlgorithm == 2)
278  {
279  mitk::KliepDensityEstimation est;
280  est.SetCollection(trainCollection);
281  est.SetTrainMask(trainMask);
282  est.SetTestMask(testMask);
283  est.SetModalities(modalities);
284  est.SetWeightName("calculated_weight");
285  est.Update();
286  } else if (importanceWeightAlgorithm == 3)
287  {
288  forest.SetWeightName(importanceWeightName);
289  } else if (importanceWeightAlgorithm == 4)
290  {
291  mitk::ZadroznyWeighting est;
292  est.SetCollection(trainCollection);
293  est.SetTrainMask(trainMask);
294  est.SetTestMask(testMask);
295  est.SetModalities(modalities);
296  est.SetWeightName("calculated_weight");
297  est.Update();
298  } else if (importanceWeightAlgorithm == 5)
299  {
300  mitk::SpectralDensityEstimation est;
301  est.SetCollection(trainCollection);
302  est.SetTrainMask(trainMask);
303  est.SetTestMask(testMask);
304  est.SetModalities(modalities);
305  est.SetWeightName("calculated_weight");
306  est.Update();
307  } else if (importanceWeightAlgorithm == 6)
308  {
309  mitk::ULSIFDensityEstimation est;
310  est.SetCollection(trainCollection);
311  est.SetTrainMask(trainMask);
312  est.SetTestMask(testMask);
313  est.SetModalities(modalities);
314  est.SetWeightName("calculated_weight");
315  est.Update();
316  } else*/
317  {
319  est.SetCollection(trainCollection);
320  est.SetTrainMask(trainMask);
321  est.SetTestMask(testMask);
322  est.SetModalities(modalities);
323  est.SetWeightName("calculated_weight");
324  est.Update();
325  }
326  auto trainDataW = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, "calculated_weight", trainMask);
327  forest->SetPointWiseWeight(trainDataW);
328  forest->UsePointWiseWeight(true);
329  }
330  MITK_INFO << "Start training the forest";
331  forest->Train(trainDataX, trainDataY);
332 
333  MITK_INFO << "Save Forest";
334  mitk::IOUtil::Save(forest, forestPath);
335  } else
336  {
// NOTE(review): reassigning `forest` replaces the object configured above,
// so the [Forest] parameters set earlier presumably do not apply to the loaded
// classifier.
337  forest = mitk::IOUtil::Load<mitk::VigraRandomForestClassifier>(forestPath);// TODO forest.Load(forestPath);
338  }
339 
340  time(&now);
341  seconds = std::difftime(now, lastTimePoint);
342  MITK_INFO << "Duration for Training: " << seconds;
343  timingFile << seconds << ";";
344  time(&lastTimePoint);
346  // If required do Save Forest....
348 
349  //writer.// (forest);
350  /*
351  auto w = forest->GetTreeWeights();
352  w(0,0) = 10;
353  forest->SetTreeWeights(w);*/
354 
356  // If required do test
358  MITK_INFO << "Convert Test data";
359  auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,modalities, testMask);
360 
361  MITK_INFO << "Predict Test Data";
362  auto testDataNewY = forest->Predict(testDataX);
363  auto testDataNewProb = forest->GetPointWiseProbabilities();
364  //MITK_INFO << testDataNewY;
365 
// One probability item per class column, named resultProb + 0-based index.
366  auto maxClassValue = testDataNewProb.cols();
367  std::vector<std::string> names;
368  for (int i = 0; i < maxClassValue; ++i)
369  {
370  std::string name = resultProb + std::to_string(i);
371  MITK_INFO << name;
372  names.push_back(name);
373  }
374  //names.push_back("prob-1");
375  //names.push_back("prob-2");
376 
// Write the predicted labels and probabilities back into the test collection
// as new data items.
377  mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, resultMask, testMask);
378  mitk::DCUtilities::MatrixToDC3d(testDataNewProb, testCollection, names, testMask);
379  MITK_INFO << "Converted predicted data";
380  //forest.SetMaskName(testMask);
381  //forest.SetCollection(testCollection);
382  //forest.Test();
383  //forest.PrintTree(0);
384 
385  time(&now);
386  seconds = std::difftime(now, lastTimePoint);
387  timingFile << seconds << ";";
388  time(&lastTimePoint);
389 
391  // Cost-based analysis
393 
394  // TODO Reactivate
395  //MITK_INFO << "Calculate Cost-based Statistic ";
396  //mitk::CostingStatistic costStat;
397  //costStat.SetCollection(testCollection);
398  //costStat.SetCombinedA("combinedHealty");
399  //costStat.SetCombinedB("combinedTumor");
400  //costStat.SetCombinedLabel("combinedLabel");
401  //costStat.SetMaskName(testMask);
409  //costStat.SetProbabilitiesA(labelGroupA);
410  //costStat.SetProbabilitiesB(labelGroupB);
411 
412  //std::ofstream costStatisticFile;
413  //costStatisticFile.open((statisticFilePath + ".cost").c_str(), std::ios::app);
414  //std::ofstream lcostStatisticFile;
415 
416  //lcostStatisticFile.open((statisticFilePath + ".longcost").c_str(), std::ios::app);
417  //costStat.WriteStatistic(lcostStatisticFile,costStatisticFile,2.5,statisticShortFileLabel);
418  //costStatisticFile.close();
419 
420  //costStat.CalculateClass(50);
421 
423  // Save results to folder
426  //outputFilter.push_back(resultMask);
427  //std::vector<std::string> propNames = forest.GetListOfProbabilityNames();
428  //outputFilter.insert(outputFilter.begin(), propNames.begin(), propNames.end());
// NOTE(review): the head of the export call is missing from this listing;
// the remaining arguments match ExportCollectionToFolder(collection, xmlFile,
// filter), which returns bool. That result does not appear to be checked.
429  MITK_INFO << "Write Result to HDD";
431  outputFolder + "/result_collection.xml",
432  outputFilter);
433 
434  MITK_INFO << "Calculate Statistic...." ;
436  // Calculate and Print Statistic
// Both statistic files are opened in append mode. Only statisticFile is
// closed explicitly; sstatisticFile is closed by its destructor.
438  std::ofstream statisticFile;
439  statisticFile.open(statisticFilePath.c_str(), std::ios::app);
440  std::ofstream sstatisticFile;
441  sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app);
442 
// NOTE(review): `stat` and `mapper` are declared on lines missing from this
// listing. The class count is hard-coded to 5, and the same mapper is used for
// ground truth and test values.
444  stat.SetCollection(testCollection);
445  stat.SetClassCount(5);
446  stat.SetGoldName(statisticGoldStandard);
447  stat.SetTestName(resultMask);
448  stat.SetMaskName(testMask);
450  stat.SetGroundTruthValueToIndexMapper(&mapper);
451  stat.SetTestValueToIndexMapper(&mapper);
452  stat.Update();
// The header is always written; the "Write header in short file" option is
// disabled (see the TODO where the Evaluation entries are read).
453  //stat.Print(statisticFile,sstatisticFile,statisticWithHeader, statisticShortFileLabel);
454  stat.Print(statisticFile,sstatisticFile,true, statisticShortFileLabel);
455  statisticFile.close();
456 
457  time(&now);
458  seconds = std::difftime(now, lastTimePoint);
459  timingFile << seconds << std::endl;
460  time(&lastTimePoint);
461  timingFile.close();
462  }
// NOTE(review): std::string is caught by value (a copy) instead of by const&.
// catch (char*) does not match thrown string literals (type const char*).
// std::exception is not caught here at all. Both handlers end in return 0,
// so failures are not reflected in the exit code.
463  catch ( const std::string s )
464  {
465  MITK_INFO << s;
466  return 0;
467  }
468  catch (char* s)
469  {
470  MITK_INFO << s;
471  }
472 
473  return 0;
474 }
475 
476 #endif
void SetWeightName(std::string name)
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from an XML resource.
#define MITK_INFO
Definition: mitkLogMacros.h:18
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SetTestName(std::string name)
void SetTestValueToIndexMapper(const ValueToIndexMapper *mapper)
void SetGroundTruthValueToIndexMapper(const ValueToIndexMapper *mapper)
int IntValue(const std::string &section, const std::string &entry) const
void SetTrainMask(std::string name)
void SetMaskName(std::string name)
void SetTestMask(std::string name)
void SetGoldName(std::string name)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void SetModalities(std::vector< std::string > modalities)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void ReadStream(std::istream &stream)
int main(int argc, char *argv[])
void SetCollection(DataCollection::Pointer data)
void AddDataElementIds(std::vector< std::string > dataElemetIds)
static void Save(const mitk::BaseData *data, const std::string &path, bool setPathProperty=false)
Save a mitk::BaseData instance.
Definition: mitkIOUtil.cpp:774
std::vector< std::string > Vector(std::string const &section, unsigned int index) const
static const char * replace[]
This is a dictionary to replace long names of classes, modules, etc. with shorter versions in the conso...
void ReadFile(std::string const &filePath)
std::string Value(std::string const &section, std::string const &entry) const