Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
CLMultiForestPrediction.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 #ifndef mitkForest_cpp
13 #define mitkForest_cpp
14 
15 #include "time.h"
16 #include <sstream>
17 
18 #include <mitkConfigFileReader.h>
19 #include <mitkDataCollection.h>
20 #include <mitkCollectionReader.h>
21 #include <mitkCollectionWriter.h>
23 #include <mitkCostingStatistic.h>
24 #include <vtkSmartPointer.h>
25 #include <mitkIOUtil.h>
26 
28 #include <mitkRandomForestIO.h>
29 
30 // ----------------------- Forest Handling ----------------------
32 
33 
// Entry point of the MITK_Forest mini-app. From the visible code it: reads a
// configuration (files plus inline arguments), loads the test collection,
// converts it to a feature matrix and then, for every entry of the "Forests"
// config vector, loads a forest, predicts, writes the predictions back into the
// collection and prints statistics. Every path that returns normally returns 0.
//
// NOTE(review): this listing has gaps (its embedded numbers skip e.g. 141, 166,
// 200, 213 and 219). `colReader`, `forest`, `stat` and `mapper` are therefore
// used without a visible declaration, and the call that owns listing lines
// 201-202 is not shown. Confirm against the real source file.
34 int main(int argc, char* argv[])
35 {
36  MITK_INFO << "Starting MITK_Forest Mini-App";
37 
39  // Read Console Input Parameter
 // NOTE(review): argv[1] is used here without checking argc first.
41  ConfigFileReader allConfig(argv[1]);
42 
 // Arguments are treated as config file paths until the first argument that
 // begins with '+'; every argument after that one is collected as an inline
 // config line in `ss`.
 // NOTE(review): the loop starts at index 0, so argv[0] is handed to ReadFile
 // as well, and argv[1] is passed to ReadFile in addition to the constructor.
43  bool readFile = true;
44  std::stringstream ss;
45  for (int i = 0; i < argc; ++i )
46  {
47  MITK_INFO << "-----"<< argv[i]<<"------";
48  if (readFile)
49  {
50  if (argv[i][0] == '+')
51  {
 // The '+' argument itself is skipped; it only switches the mode.
52  readFile = false;
53  continue;
54  } else
55  {
 // A std::exception from ReadFile is logged and otherwise ignored.
56  try
57  {
58  allConfig.ReadFile(argv[i]);
59  }
60  catch ( const std::exception &e )
61  {
62  MITK_INFO << e.what();
63  }
64  }
65  }
66  else
67  {
 // Inline mode: every '_' becomes ' ' and the argument is appended as one line.
68  std::string input = argv[i];
69  std::replace(input.begin(), input.end(),'_',' ');
70  ss << input << std::endl;
71  }
72  }
 // Hand the collected inline lines to the config reader.
73  allConfig.ReadStream(ss);
74 
75  try
76  {
78  // General
 // NOTE(review): forestPath, writeDataFilePath, labelGroupA and labelGroupB
 // are read below but not used anywhere in the visible listing.
80  int currentRun = allConfig.IntValue("General","Run",0);
81  //int doTraining = allConfig.IntValue("General","Do Training",1);
82  std::string forestPath = allConfig.Value("General","Forest Path");
83  std::string trainingCollectionPath = allConfig.Value("General","Patient Collection");
 // The training collection path is passed as the third argument here
 // (presumably the fallback when no test collection is configured — confirm).
84  std::string testCollectionPath = allConfig.Value("General", "Patient Test Collection", trainingCollectionPath);
85 
87  // Read Default Classification
89  std::vector<std::string> trainPatients = allConfig.Vector("Training Group",currentRun);
90  std::vector<std::string> testPatients = allConfig.Vector("Test Group",currentRun);
91  std::vector<std::string> modalities = allConfig.Vector("Modalities", 0);
92  std::vector<std::string> outputFilter = allConfig.Vector("Output Filter", 0);
93  std::string trainMask = allConfig.Value("Data","Training Mask");
94  std::string completeTrainMask = allConfig.Value("Data","Complete Training Mask");
95  std::string testMask = allConfig.Value("Data","Test Mask");
96  std::string resultMask = allConfig.Value("Data", "Result Mask");
97  std::string resultProb = allConfig.Value("Data", "Result Propability");
98  std::string outputFolder = allConfig.Value("General","Output Folder");
99 
100  std::string writeDataFilePath = allConfig.Value("Forest","File to write data to");
101 
103  // Read Data Forest Parameter
105  int testSingleDataset = allConfig.IntValue("Data", "Test Single Dataset",0);
106  std::string singleDatasetName = allConfig.Value("Data", "Single Dataset Name", "none");
107  std::vector<std::string> forestVector = allConfig.Vector("Forests", 0);
108 
110  // Read Statistic Parameter
112  std::string statisticFilePath = allConfig.Value("Evaluation", "Statistic output file");
113  std::string statisticShortFilePath = allConfig.Value("Evaluation", "Statistic short output file");
114  std::string statisticShortFileLabel = allConfig.Value("Evaluation", "Index for short file");
115  std::string statisticGoldStandard = allConfig.Value("Evaluation", "Gold Standard Name","GTV");
116  //bool statisticWithHeader = allConfig.IntValue("Evaluation", "Write header in short file",0);
117  std::vector<std::string> labelGroupA = allConfig.Vector("LabelsA",0);
118  std::vector<std::string> labelGroupB = allConfig.Vector("LabelsB",0);
119 
120 
 // Timing log: appended to "<statisticFilePath>.timing". The row starts with
 // the short-file label; each later value is the number of seconds elapsed
 // since lastTimePoint, which is reset after every measurement.
121  std::ofstream timingFile;
122  timingFile.open((statisticFilePath + ".timing").c_str(), std::ios::app);
123  timingFile << statisticShortFileLabel << ";";
124  std::time_t lastTimePoint;
125  time(&lastTimePoint);
126 
128  // Read Images
 // Item names requested from the reader: all modalities, then the three masks
 // and the gold standard name.
130  std::vector<std::string> usedModalities;
131  for (std::size_t i = 0; i < modalities.size(); ++i)
132  {
133  usedModalities.push_back(modalities[i]);
134  }
135  usedModalities.push_back(trainMask);
136  usedModalities.push_back(completeTrainMask);
137  usedModalities.push_back(testMask);
138  usedModalities.push_back(statisticGoldStandard);
139 
140  // vtkSmartPointer<mitk::CollectionReader> colReader = vtkSmartPointer<mitk::CollectionReader>::New();
 // NOTE(review): the train patient ids added here are followed by
 // ClearDataElementIds() before the only visible LoadCollection call, so
 // presumably only testPatients end up being loaded — confirm.
142  colReader->AddDataElementIds(trainPatients);
143  colReader->SetDataItemNames(usedModalities);
144 
 // Optionally replace the configured test group by one named dataset.
145  if (testSingleDataset > 0)
146  {
147  testPatients.clear();
148  testPatients.push_back(singleDatasetName);
149  }
150  colReader->ClearDataElementIds();
151  colReader->AddDataElementIds(testPatients);
152  mitk::DataCollection::Pointer testCollection = colReader->LoadCollection(testCollectionPath);
153 
 // First timing value: seconds spent up to and including LoadCollection.
154  std::time_t now;
155  time(&now);
156  double seconds = std::difftime(now, lastTimePoint);
157  timingFile << seconds << ";";
158  time(&lastTimePoint);
159 
160 
162  MITK_INFO << "Convert Test data";
 // Converted once and reused for every forest in the loop below.
163  auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection, modalities, testMask);
164 
 // One pass per entry of the "Forests" config vector.
165  for (std::size_t i = 0; i < forestVector.size(); ++i)
166  {
167  forest = mitk::IOUtil::Load<mitk::VigraRandomForestClassifier>(forestVector[i]);
168 
 // NOTE(review): despite the log label, this interval is the time since
 // lastTimePoint was last reset, which includes the Load call above; no
 // training call is visible in this listing.
169  time(&now);
170  seconds = std::difftime(now, lastTimePoint);
171  MITK_INFO << "Duration for Training: " << seconds;
172  timingFile << seconds << ";";
173  time(&lastTimePoint);
174 
175  MITK_INFO << "Predict Test Data";
176  auto testDataNewY = forest->Predict(testDataX);
177  auto testDataNewProb = forest->GetPointWiseProbabilities();
178 
 // One item name per probability column: resultProb followed by the column index.
179  auto maxClassValue = testDataNewProb.cols();
180  std::vector<std::string> names;
181  for (int j = 0; j < maxClassValue; ++j)
182  {
183  std::string name = resultProb + std::to_string(j);
184  names.push_back(name);
185  }
186 
 // Write the predictions and the per-column probabilities back into the test
 // collection. The same item names are used for every forest.
187  mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, resultMask, testMask);
188  mitk::DCUtilities::MatrixToDC3d(testDataNewProb, testCollection, names, testMask);
189  MITK_INFO << "Converted predicted data";
190 
191  time(&now);
192  seconds = std::difftime(now, lastTimePoint);
193  timingFile << seconds << ";";
194  time(&lastTimePoint);
195 
197  // Save results to folder
199  MITK_INFO << "Write Result to HDD";
 // NOTE(review): the start of this call (listing line 200) is missing; only
 // its last two arguments are visible. The target file name does not depend
 // on `i`, so presumably each forest writes to the same file — confirm.
201  outputFolder + "/result_collection.xml",
202  outputFilter);
203 
204  MITK_INFO << "Calculate Statistic....";
206  // Calculate and Print Statistic
 // Both statistic files are opened in append mode on every iteration.
208  std::ofstream statisticFile;
209  statisticFile.open(statisticFilePath.c_str(), std::ios::app);
210  std::ofstream sstatisticFile;
211  sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app);
212 
 // NOTE(review): the class count is fixed to 5 here rather than derived from
 // the data or the configuration.
214  stat.SetCollection(testCollection);
215  stat.SetClassCount(5);
216  stat.SetGoldName(statisticGoldStandard);
217  stat.SetTestName(resultMask);
218  stat.SetMaskName(testMask);
 // The same mapper object is passed for both ground truth and test values.
220  stat.SetGroundTruthValueToIndexMapper(&mapper);
221  stat.SetTestValueToIndexMapper(&mapper);
222  stat.Update();
223  //stat.Print(statisticFile,sstatisticFile,statisticWithHeader, statisticShortFileLabel);
 // The label passed to Print is the short-file label with "_<forest index>" appended.
224  stat.Print(statisticFile, sstatisticFile, true, statisticShortFileLabel + "_"+std::to_string(i));
 // Only statisticFile is closed explicitly; sstatisticFile is closed by its
 // destructor at the end of the iteration.
225  statisticFile.close();
226 
227  time(&now);
228  seconds = std::difftime(now, lastTimePoint);
229  timingFile << seconds << std::endl;
230  time(&lastTimePoint);
 // NOTE(review): timingFile was opened before the loop but is closed here,
 // inside it. From the second forest on, the timing writes above go to a
 // closed stream.
231  timingFile.close();
232  }
233  }
 // Only exceptions of type std::string and char* are handled here. Both
 // handlers log the message and the function then returns 0.
234  catch (std::string s)
235  {
236  MITK_INFO << s;
237  return 0;
238  }
239  catch (char* s)
240  {
241  MITK_INFO << s;
242  }
243 
244  return 0;
245 }
246 
247 #endif
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
#define MITK_INFO
Definition: mitkLogMacros.h:18
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SetTestName(std::string name)
void SetTestValueToIndexMapper(const ValueToIndexMapper *mapper)
int main(int argc, char *argv[])
void SetGroundTruthValueToIndexMapper(const ValueToIndexMapper *mapper)
int IntValue(const std::string &section, const std::string &entry) const
void SetMaskName(std::string name)
void SetGoldName(std::string name)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void ReadStream(std::istream &stream)
void AddDataElementIds(std::vector< std::string > dataElemetIds)
std::vector< std::string > Vector(std::string const &section, unsigned int index) const
static const char * replace[]
This is a dictionary to replace long names of classes, modules, etc. to shorter versions in the conso...
void ReadFile(std::string const &filePath)
std::string Value(std::string const &section, std::string const &entry) const