12 #ifndef mitkForest_cpp 13 #define mitkForest_cpp 24 #include <vtkSmartPointer.h> 47 int main(
int argc,
char* argv[])
49 MITK_INFO <<
"Starting MITK_Forest Mini-App";
58 for (
int i = 0; i < argc; ++i )
63 if (argv[i][0] ==
'+')
73 catch (
const std::exception &e )
81 std::string input = argv[i];
83 ss << input << std::endl;
93 int currentRun = allConfig.
IntValue(
"General",
"Run",0);
94 int doTraining = allConfig.
IntValue(
"General",
"Do Training",1);
95 std::string forestPath = allConfig.
Value(
"General",
"Forest Path");
96 std::string trainingCollectionPath = allConfig.
Value(
"General",
"Patient Collection");
97 std::string testCollectionPath = allConfig.
Value(
"General",
"Patient Test Collection", trainingCollectionPath);
102 std::vector<std::string> trainPatients = allConfig.
Vector(
"Training Group",currentRun);
103 std::vector<std::string> testPatients = allConfig.
Vector(
"Test Group",currentRun);
104 std::vector<std::string> modalities = allConfig.
Vector(
"Modalities", 0);
105 std::vector<std::string> outputFilter = allConfig.
Vector(
"Output Filter", 0);
106 std::string trainMask = allConfig.
Value(
"Data",
"Training Mask");
107 std::string completeTrainMask = allConfig.
Value(
"Data",
"Complete Training Mask");
108 std::string testMask = allConfig.
Value(
"Data",
"Test Mask");
109 std::string resultMask = allConfig.
Value(
"Data",
"Result Mask");
110 std::string resultProb = allConfig.
Value(
"Data",
"Result Propability");
111 std::string outputFolder = allConfig.
Value(
"General",
"Output Folder");
113 std::string writeDataFilePath = allConfig.
Value(
"Forest",
"File to write data to");
118 int testSingleDataset = allConfig.
IntValue(
"Data",
"Test Single Dataset",0);
119 std::string singleDatasetName = allConfig.
Value(
"Data",
"Single Dataset Name",
"none");
120 int trainSingleDataset = allConfig.
IntValue(
"Data",
"Train Single Dataset", 0);
121 std::string singleTrainDatasetName = allConfig.
Value(
"Data",
"Train Single Dataset Name",
"none");
126 int minimumSplitNodeSize = allConfig.
IntValue(
"Forest",
"Minimum split node size",1);
127 int numberOfTrees = allConfig.
IntValue(
"Forest",
"Number of Trees",255);
128 double samplesPerTree = atof(allConfig.
Value(
"Forest",
"Samples per Tree").c_str());
129 if (samplesPerTree <= 0.0000001)
131 samplesPerTree = 1.0;
133 MITK_INFO <<
"Samples per Tree: " << samplesPerTree;
134 int sampleWithReplacement = allConfig.
IntValue(
"Forest",
"Sample with replacement",1);
135 double trainPrecision = atof(allConfig.
Value(
"Forest",
"Precision").c_str());
136 if (trainPrecision <= 0.0000000001)
138 trainPrecision = 0.0;
140 double weightLambda = atof(allConfig.
Value(
"Forest",
"Weight Lambda").c_str());
141 if (weightLambda <= 0.0000000001)
145 int maximumTreeDepth = allConfig.
IntValue(
"Forest",
"Maximum Tree Depth",10000);
150 std::string statisticFilePath = allConfig.
Value(
"Evaluation",
"Statistic output file");
151 std::string statisticShortFilePath = allConfig.
Value(
"Evaluation",
"Statistic short output file");
152 std::string statisticShortFileLabel = allConfig.
Value(
"Evaluation",
"Index for short file");
153 std::string statisticGoldStandard = allConfig.
Value(
"Evaluation",
"Gold Standard Name",
"GTV");
155 std::vector<std::string> labelGroupA = allConfig.
Vector(
"LabelsA",0);
156 std::vector<std::string> labelGroupB = allConfig.
Vector(
"LabelsB",0);
160 bool useWeightedPoints = allConfig.
IntValue(
"Forest",
"Use point-based weighting",0);
163 std::string importanceWeightName = allConfig.
Value(
"Forest",
"Importance weight name",
"");
165 std::ofstream timingFile;
166 timingFile.open((statisticFilePath +
".timing").c_str(), std::ios::app);
167 timingFile << statisticShortFileLabel <<
";";
168 std::time_t lastTimePoint;
169 time(&lastTimePoint);
174 std::vector<std::string> usedModalities;
175 for (std::size_t i = 0; i < modalities.size(); ++i)
177 usedModalities.push_back(modalities[i]);
179 usedModalities.push_back(trainMask);
180 usedModalities.push_back(completeTrainMask);
181 usedModalities.push_back(testMask);
182 usedModalities.push_back(statisticGoldStandard);
183 usedModalities.push_back(importanceWeightName);
185 if (trainSingleDataset > 0)
187 trainPatients.clear();
188 trainPatients.push_back(singleTrainDatasetName);
198 trainCollection = colReader->
LoadCollection(trainingCollectionPath);
201 if (testSingleDataset > 0)
203 testPatients.clear();
204 testPatients.push_back(singleDatasetName);
212 double seconds = std::difftime(now, lastTimePoint);
213 timingFile << seconds <<
";";
214 time(&lastTimePoint);
239 forest->SetSamplesPerTree(samplesPerTree);
240 forest->SetMinimumSplitNodeSize(minimumSplitNodeSize);
241 forest->SetTreeCount(numberOfTrees);
242 forest->UseSampleWithReplacement(sampleWithReplacement);
243 forest->SetPrecision(trainPrecision);
244 forest->SetMaximumTreeDepth(maximumTreeDepth);
245 forest->SetWeightLambda(weightLambda);
261 if (useWeightedPoints)
264 MITK_INFO <<
"Activated Point-based weighting...";
266 forest->UsePointWiseWeight(
true);
327 forest->SetPointWiseWeight(trainDataW);
328 forest->UsePointWiseWeight(
true);
330 MITK_INFO <<
"Start training the forest";
331 forest->Train(trainDataX, trainDataY);
337 forest = mitk::IOUtil::Load<mitk::VigraRandomForestClassifier>(forestPath);
341 seconds = std::difftime(now, lastTimePoint);
342 MITK_INFO <<
"Duration for Training: " << seconds;
343 timingFile << seconds <<
";";
344 time(&lastTimePoint);
362 auto testDataNewY = forest->Predict(testDataX);
363 auto testDataNewProb = forest->GetPointWiseProbabilities();
366 auto maxClassValue = testDataNewProb.cols();
367 std::vector<std::string> names;
368 for (
int i = 0; i < maxClassValue; ++i)
370 std::string name = resultProb + std::to_string(i);
372 names.push_back(name);
386 seconds = std::difftime(now, lastTimePoint);
387 timingFile << seconds <<
";";
388 time(&lastTimePoint);
431 outputFolder +
"/result_collection.xml",
438 std::ofstream statisticFile;
439 statisticFile.open(statisticFilePath.c_str(), std::ios::app);
440 std::ofstream sstatisticFile;
441 sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app);
454 stat.
Print(statisticFile,sstatisticFile,
true, statisticShortFileLabel);
455 statisticFile.close();
458 seconds = std::difftime(now, lastTimePoint);
459 timingFile << seconds << std::endl;
460 time(&lastTimePoint);
463 catch (
const std::string s )
void SetWeightName(std::string name)
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
void ClearDataElementIds()
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SetTestName(std::string name)
void SetTestValueToIndexMapper(const ValueToIndexMapper *mapper)
void SetGroundTruthValueToIndexMapper(const ValueToIndexMapper *mapper)
int IntValue(const std::string §ion, const std::string &entry) const
void SetTrainMask(std::string name)
void SetMaskName(std::string name)
void SetTestMask(std::string name)
void SetGoldName(std::string name)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void SetModalities(std::vector< std::string > modalities)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void ReadStream(std::istream &stream)
int main(int argc, char *argv[])
void SetCollection(DataCollection::Pointer data)
void AddDataElementIds(std::vector< std::string > dataElemetIds)
static void Save(const mitk::BaseData *data, const std::string &path, bool setPathProperty=false)
Save a mitk::BaseData instance.
std::vector< std::string > Vector(std::string const §ion, unsigned int index) const
static const char * replace[]
This is a dictionary to replace long names of classes, modules, etc. to shorter versions in the conso...
void ReadFile(std::string const &filePath)
std::string Value(std::string const §ion, std::string const &entry) const