Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
CLVoxelClassification.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 #ifndef mitkForest_cpp
17 #define mitkForest_cpp
18 
19 #include "time.h"
20 #include <sstream>
21 
22 #include <mitkConfigFileReader.h>
23 #include <mitkDataCollection.h>
24 #include <mitkCollectionReader.h>
25 #include <mitkCollectionWriter.h>
27 #include <mitkCostingStatistic.h>
28 #include <vtkSmartPointer.h>
29 #include <mitkIOUtil.h>
30 
32 #include <mitkRandomForestIO.h>
33 
34 // ----------------------- Forest Handling ----------------------
35 //#include <mitkDecisionForest.h>
37 //#include <mitkThresholdSplit.h>
38 //#include <mitkImpurityLoss.h>
39 //#include <mitkLinearSplitting.h>
40 //#include <mitkVigraConverter.h>
41 // ----------------------- Point weighting ----------------------
42 //#include <mitkForestWeighting.h>
43 //#include <mitkKliepDensityEstimation.h>
44 //#include <mitkExternalWeighting.h>
46 //#include <mitkKNNDensityEstimation.h>
47 //#include <mitkZadroznyWeighting.h>
48 //#include <mitkSpectralDensityEstimation.h>
49 //#include <mitkULSIFDensityEstimation.h>
50 
// Entry point of the MITK_Forest mini-app.
//
// Visible flow: read configuration (files and command-line overrides), load the
// training and test data collections, configure and optionally train a forest
// classifier, predict on the test collection, export the result collection and
// write statistics plus per-phase timings.
//
// NOTE(review): this text is a documentation listing; the leading number on each
// line is the listing's own line number, and some listing lines are absent here
// (e.g. 183, 224, 304, 395, 408), so a few declarations are not visible.
//
// Returns 0 on every path visible here, including the caught-error paths.
51 int main(int argc, char* argv[])
52 {
53  MITK_INFO << "Starting MITK_Forest Mini-App";
  // NOTE(review): startTime is assigned here but not read again in the visible code.
54  double startTime = time(0);
55 
57  // Read Console Input Parameter
  // NOTE(review): argv[2] is used without checking argc first.
59  ConfigFileReader allConfig(argv[2]);
60 
  // Arguments are handled in two phases: until an argument starting with '+' is
  // seen, each argument is passed to ReadFile(); after it, each argument is turned
  // into one line of text (underscores replaced by spaces) and collected in ss.
  // NOTE(review): the loop starts at i = 0, so argv[0] is also passed to ReadFile();
  // a std::exception thrown by ReadFile() is only logged.
61  bool readFile = true;
62  std::stringstream ss;
63  for (int i = 0; i < argc; ++i )
64  {
65  MITK_INFO << "-----"<< argv[i]<<"------";
66  if (readFile)
67  {
68  if (argv[i][0] == '+')
69  {
70  readFile = false;
71  continue;
72  } else
73  {
74  try
75  {
76  allConfig.ReadFile(argv[i]);
77  }
78  catch (std::exception &e)
79  {
80  MITK_INFO << e.what();
81  }
82  }
83  }
84  else
85  {
86  std::string input = argv[i];
87  std::replace(input.begin(), input.end(),'_',' ');
88  ss << input << std::endl;
89  }
90  }
  // The collected command-line lines are read after all files.
91  allConfig.ReadStream(ss);
92 
93  try
94  {
96  // General
98  int currentRun = allConfig.IntValue("General","Run",0);
99  int doTraining = allConfig.IntValue("General","Do Training",1);
100  std::string forestPath = allConfig.Value("General","Forest Path");
101  std::string trainingCollectionPath = allConfig.Value("General","Patient Collection");
  // Training and test data are loaded from the same collection file.
102  std::string testCollectionPath = trainingCollectionPath;
103  MITK_INFO << "Training collection: " << trainingCollectionPath;
104 
106  // Read Default Classification
  // The train/test patient groups are selected by the current run index.
108  std::vector<std::string> trainPatients = allConfig.Vector("Training Group",currentRun);
109  std::vector<std::string> testPatients = allConfig.Vector("Test Group",currentRun);
110  std::vector<std::string> modalities = allConfig.Vector("Modalities",0);
111  std::string trainMask = allConfig.Value("Data","Training Mask");
112  std::string completeTrainMask = allConfig.Value("Data","Complete Training Mask");
113  std::string testMask = allConfig.Value("Data","Test Mask");
114  std::string resultMask = allConfig.Value("Data", "Result Mask");
115  std::string resultProb = allConfig.Value("Data", "Result Propability");
116  std::string outputFolder = allConfig.Value("General","Output Folder");
117 
118  std::string writeDataFilePath = allConfig.Value("Forest","File to write data to");
119 
121  // Read Forest Parameter
123  int minimumSplitNodeSize = allConfig.IntValue("Forest", "Minimum split node size",1);
124  int numberOfTrees = allConfig.IntValue("Forest", "Number of Trees",255);
  // atof() yields 0.0 when the text cannot be converted, so an unparseable or
  // non-positive "Samples per Tree" value falls back to 1.0 below.
125  double samplesPerTree = atof(allConfig.Value("Forest", "Samples per Tree").c_str());
126  if (samplesPerTree <= 0.0000001)
127  {
128  samplesPerTree = 1.0;
129  }
130  MITK_INFO << "Samples per Tree: " << samplesPerTree;
131  int sampleWithReplacement = allConfig.IntValue("Forest", "Sample with replacement",1);
  // Precision and weight lambda are clamped to 0.0 when they are at or below 1e-10.
132  double trainPrecision = atof(allConfig.Value("Forest", "Precision").c_str());
133  if (trainPrecision <= 0.0000000001)
134  {
135  trainPrecision = 0.0;
136  }
137  double weightLambda = atof(allConfig.Value("Forest", "Weight Lambda").c_str());
138  if (weightLambda <= 0.0000000001)
139  {
140  weightLambda = 0.0;
141  }
142  int maximumTreeDepth = allConfig.IntValue("Forest", "Maximum Tree Depth",10000);
  // NOTE(review): randomSplit is read but only referenced by a TODO comment further down.
143  int randomSplit = allConfig.IntValue("Forest","Use RandomSplit",0);
145  // Read Statistic Parameter
147  std::string statisticFilePath = allConfig.Value("Evaluation", "Statistic output file");
148  std::string statisticShortFilePath = allConfig.Value("Evaluation", "Statistic short output file");
149  std::string statisticShortFileLabel = allConfig.Value("Evaluation", "Index for short file");
150  std::string statisticGoldStandard = allConfig.Value("Evaluation", "Gold Standard Name","GTV");
151  bool statisticWithHeader = allConfig.IntValue("Evaluation", "Write header in short file",0);
152  std::vector<std::string> labelGroupA = allConfig.Vector("LabelsA",0);
153  std::vector<std::string> labelGroupB = allConfig.Vector("LabelsB",0);
155  // Read Special Parameter
157  bool useWeightedPoints = allConfig.IntValue("Forest", "Use point-based weighting",0);
158  bool writePointsToFile = allConfig.IntValue("Forest", "Write points to file",0);
159  int importanceWeightAlgorithm = allConfig.IntValue("Forest","Importance weight Algorithm",0);
160  std::string importanceWeightName = allConfig.Value("Forest","Importance weight name","");
161 
  // Timing log: one line is appended to "<statistic output file>.timing", starting
  // with the short-file label followed by the elapsed seconds of each phase,
  // separated by ';'.
162  std::ofstream timingFile;
163  timingFile.open((statisticFilePath + ".timing").c_str(), std::ios::app);
164  timingFile << statisticShortFileLabel << ";";
165  std::time_t lastTimePoint;
166  time(&lastTimePoint);
167 
169  // Read Images
  // Items requested from the reader: all modalities plus the masks, the gold
  // standard and the importance-weight item.
171  std::vector<std::string> usedModalities;
172  for (int i = 0; i < modalities.size(); ++i)
173  {
174  usedModalities.push_back(modalities[i]);
175  }
176  usedModalities.push_back(trainMask);
177  usedModalities.push_back(completeTrainMask);
178  usedModalities.push_back(testMask);
179  usedModalities.push_back(statisticGoldStandard);
180  usedModalities.push_back(importanceWeightName);
181 
182  // vtkSmartPointer<mitk::CollectionReader> colReader = vtkSmartPointer<mitk::CollectionReader>::New();
  // NOTE(review): the declaration of colReader is not part of the visible listing
  // (listing line 183 is absent) - confirm against the original file.
184  colReader->AddDataElementIds(trainPatients);
185  colReader->SetDataItemNames(usedModalities);
186  //colReader->SetNames(usedModalities);
  // The training collection is only loaded when training is requested; otherwise
  // trainCollection keeps its default-constructed value.
187  mitk::DataCollection::Pointer trainCollection;
188  if (doTraining)
189  {
190  trainCollection = colReader->LoadCollection(trainingCollectionPath);
191  }
  // The same reader is reused for the test patients.
192  colReader->ClearDataElementIds();
193  colReader->AddDataElementIds(testPatients);
194  mitk::DataCollection::Pointer testCollection = colReader->LoadCollection(testCollectionPath);
195 
  // Timing entry 1: loading the collections.
196  std::time_t now;
197  time(&now);
198  double seconds = std::difftime(now, lastTimePoint);
199  timingFile << seconds << ";";
200  time(&lastTimePoint);
201 
202  /*
203  if (writePointsToFile)
204  {
205  MITK_INFO << "Use external weights...";
206  mitk::ExternalWeighting weightReader;
207  weightReader.SetModalities(modalities);
208  weightReader.SetTestCollection(testCollection);
209  weightReader.SetTrainCollection(trainCollection);
210  weightReader.SetTestMask(testMask);
211  weightReader.SetTrainMask(trainMask);
212  weightReader.SetWeightsName("weights");
213  weightReader.SetCorrectionFactor(1.0);
214  weightReader.SetWeightFileName(writeDataFilePath);
215  weightReader.WriteData();
216  return 0;
217  }*/
218 
220  // If required do Training....
222  //mitk::DecisionForest forest;
223 
  // NOTE(review): the declaration of forest is not part of the visible listing
  // (listing line 224 is absent); it is used through '->' below.
225  forest->SetSamplesPerTree(samplesPerTree);
226  forest->SetMinimumSplitNodeSize(minimumSplitNodeSize);
227  forest->SetTreeCount(numberOfTrees);
228  forest->UseSampleWithReplacement(sampleWithReplacement);
229  forest->SetPrecision(trainPrecision);
230  forest->SetMaximumTreeDepth(maximumTreeDepth);
231  forest->SetWeightLambda(weightLambda);
232 
233  // TODO forest.UseRandomSplit(randomSplit);
234 
235  if (doTraining)
236  {
237  // 0 = LR-Estimation
238  // 1 = KNN-Estimation
239  // 2 = Kliep
240  // 3 = Extern Image
241  // 4 = Zadrozny
242  // 5 = Spectral
243  // 6 = uLSIF
  // Feature matrix from the modalities and label matrix from the training mask,
  // both restricted by trainMask (trainMask is passed as item name and as mask).
244  auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, modalities, trainMask);
245  auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, trainMask, trainMask);
246 
  // NOTE(review): point-based weighting is hard-disabled by 'if (false)'; the
  // configured useWeightedPoints flag is not consulted, so this branch never runs.
247  //if (useWeightedPoints)
248  if (false)
249  {
250  MITK_INFO << "Activated Point-based weighting...";
251  //forest.UseWeightedPoints(true);
252  forest->UsePointWiseWeight(true);
253  //forest.SetWeightName("calculated_weight");
254  /*if (importanceWeightAlgorithm == 1)
255  {
256  mitk::KNNDensityEstimation est;
257  est.SetCollection(trainCollection);
258  est.SetTrainMask(trainMask);
259  est.SetTestMask(testMask);
260  est.SetModalities(modalities);
261  est.SetWeightName("calculated_weight");
262  est.Update();
263  } else if (importanceWeightAlgorithm == 2)
264  {
265  mitk::KliepDensityEstimation est;
266  est.SetCollection(trainCollection);
267  est.SetTrainMask(trainMask);
268  est.SetTestMask(testMask);
269  est.SetModalities(modalities);
270  est.SetWeightName("calculated_weight");
271  est.Update();
272  } else if (importanceWeightAlgorithm == 3)
273  {
274  forest.SetWeightName(importanceWeightName);
275  } else if (importanceWeightAlgorithm == 4)
276  {
277  mitk::ZadroznyWeighting est;
278  est.SetCollection(trainCollection);
279  est.SetTrainMask(trainMask);
280  est.SetTestMask(testMask);
281  est.SetModalities(modalities);
282  est.SetWeightName("calculated_weight");
283  est.Update();
284  } else if (importanceWeightAlgorithm == 5)
285  {
286  mitk::SpectralDensityEstimation est;
287  est.SetCollection(trainCollection);
288  est.SetTrainMask(trainMask);
289  est.SetTestMask(testMask);
290  est.SetModalities(modalities);
291  est.SetWeightName("calculated_weight");
292  est.Update();
293  } else if (importanceWeightAlgorithm == 6)
294  {
295  mitk::ULSIFDensityEstimation est;
296  est.SetCollection(trainCollection);
297  est.SetTrainMask(trainMask);
298  est.SetTestMask(testMask);
299  est.SetModalities(modalities);
300  est.SetWeightName("calculated_weight");
301  est.Update();
302  } else*/
303  {
  // NOTE(review): the declaration of est for this remaining branch is not part
  // of the visible listing (listing line 304 is absent).
305  est.SetCollection(trainCollection);
306  est.SetTrainMask(trainMask);
307  est.SetTestMask(testMask);
308  est.SetModalities(modalities);
309  est.SetWeightName("calculated_weight");
310  est.Update();
311  }
  // The weights stored under "calculated_weight" are handed to the forest.
312  auto trainDataW = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, "calculated_weight", trainMask);
313  forest->SetPointWiseWeight(trainDataW);
314  forest->UsePointWiseWeight(true);
315  }
316  forest->Train(trainDataX, trainDataY);
317  // TODO forest.Save(forestPath);
318  } else
319  {
  // NOTE(review): loading an existing forest is not implemented; with
  // "Do Training" set to 0 the forest is used for prediction without being trained here.
320  // TODO forest.Load(forestPath);
321  }
322 
  // Timing entry 2: training.
323  time(&now);
324  seconds = std::difftime(now, lastTimePoint);
325  timingFile << seconds << ";";
326  time(&lastTimePoint);
328  // If required do Save Forest....
330 
331  //writer.// (forest);
332  /*
333  auto w = forest->GetTreeWeights();
334  w(0,0) = 10;
335  forest->SetTreeWeights(w);*/
  // NOTE(review): the forest is always written to this hard-coded path; the
  // configured forestPath is not used here.
336  mitk::IOUtil::Save(forest,"d:/tmp/forest.forest");
337 
339  // If required do test
  // Predict on the test features (restricted by testMask) and store the
  // prediction in the test collection under the resultMask name.
341  auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,modalities, testMask);
342  auto testDataNewY = forest->Predict(testDataX);
343  //MITK_INFO << testDataNewY;
344 
345  mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, resultMask, testMask);
346  //forest.SetMaskName(testMask);
347  //forest.SetCollection(testCollection);
348  //forest.Test();
349  //forest.PrintTree(0);
350 
  // Timing entry 3: prediction.
351  time(&now);
352  seconds = std::difftime(now, lastTimePoint);
353  timingFile << seconds << ";";
354  time(&lastTimePoint);
355 
357  // Cost-based analysis
359 
360  // TODO Reactivate
361  //MITK_INFO << "Calculate Cost-based Statistic ";
362  //mitk::CostingStatistic costStat;
363  //costStat.SetCollection(testCollection);
364  //costStat.SetCombinedA("combinedHealty");
365  //costStat.SetCombinedB("combinedTumor");
366  //costStat.SetCombinedLabel("combinedLabel");
367  //costStat.SetMaskName(testMask);
375  //costStat.SetProbabilitiesA(labelGroupA);
376  //costStat.SetProbabilitiesB(labelGroupB);
377 
378  //std::ofstream costStatisticFile;
379  //costStatisticFile.open((statisticFilePath + ".cost").c_str(), std::ios::app);
380  //std::ofstream lcostStatisticFile;
381 
382  //lcostStatisticFile.open((statisticFilePath + ".longcost").c_str(), std::ios::app);
383  //costStat.WriteStatistic(lcostStatisticFile,costStatisticFile,2.5,statisticShortFileLabel);
384  //costStatisticFile.close();
385 
386  //costStat.CalculateClass(50);
387 
389  // Save results to folder
  // outputFilter is passed on empty, since the lines that would fill it are commented out.
391  std::vector<std::string> outputFilter;
392  //outputFilter.push_back(resultMask);
393  //std::vector<std::string> propNames = forest.GetListOfProbabilityNames();
394  //outputFilter.insert(outputFilter.begin(), propNames.begin(), propNames.end());
  // NOTE(review): the start of this call is not part of the visible listing
  // (listing line 395 is absent); presumably an export of testCollection to
  // "<Output Folder>/result_collection.xml" - confirm against the original file.
396  outputFolder + "/result_collection.xml",
397  outputFilter);
398 
399  MITK_INFO << "Calculate Statistic...." ;
401  // Calculate and Print Statistic
  // Both statistic files are opened in append mode.
403  std::ofstream statisticFile;
404  statisticFile.open(statisticFilePath.c_str(), std::ios::app);
405  std::ofstream sstatisticFile;
406  sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app);
407 
  // NOTE(review): the declaration of stat is not part of the visible listing
  // (listing line 408 is absent).
409  stat.SetCollection(testCollection);
410  stat.SetClassCount(2);
411  stat.SetGoldName(statisticGoldStandard);
412  stat.SetTestName(resultMask);
413  stat.SetMaskName(testMask);
414  stat.Update();
  // NOTE(review): the header flag is forced to true; the configured
  // statisticWithHeader value is only used in the commented-out call.
415  //stat.Print(statisticFile,sstatisticFile,statisticWithHeader, statisticShortFileLabel);
416  stat.Print(statisticFile,sstatisticFile,true, statisticShortFileLabel);
417  statisticFile.close();
418 
  // Timing entry 4: export and statistics; this ends the timing line.
419  time(&now);
420  seconds = std::difftime(now, lastTimePoint);
421  timingFile << seconds << std::endl;
422  time(&lastTimePoint);
423  timingFile.close();
424  }
  // Only exceptions thrown as std::string or char* are handled here; they are
  // logged and the program still returns 0.
  // NOTE(review): exceptions derived from std::exception are not caught by these handlers.
425  catch (std::string s)
426  {
427  MITK_INFO << s;
428  return 0;
429  }
430  catch (char* s)
431  {
432  MITK_INFO << s;
433  }
434 
435  return 0;
436 }
437 
438 #endif
std::vector< std::string > Vector(std::string const &section, unsigned int index) const
void SetWeightName(std::string name)
std::string Value(std::string const &section, std::string const &entry) const
static void Save(const mitk::BaseData *data, const std::string &path)
Save a mitk::BaseData instance.
Definition: mitkIOUtil.cpp:824
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
#define MITK_INFO
Definition: mitkLogMacros.h:22
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SetTestName(std::string name)
void SetTrainMask(std::string name)
void SetMaskName(std::string name)
void SetTestMask(std::string name)
void SetGoldName(std::string name)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void SetModalities(std::vector< std::string > modalities)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void ReadStream(std::istream &stream)
int main(int argc, char *argv[])
void SetCollection(DataCollection::Pointer data)
int IntValue(const std::string &section, const std::string &entry) const
void AddDataElementIds(std::vector< std::string > dataElemetIds)
static const char * replace[]
This is a dictionary to replace long names of classes, modules, etc. to shorter versions in the conso...
void ReadFile(std::string const &filePath)
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.