12 #ifndef mitkForest_cpp 13 #define mitkForest_cpp 24 #include <vtkSmartPointer.h> 47 int main(
int argc,
char* argv[])
49 MITK_INFO <<
"Starting MITK_Forest Mini-App";
50 double startTime = time(0);
59 for (
int i = 0; i < argc; ++i )
64 if (argv[i][0] ==
'+')
74 catch (std::exception &e)
82 std::string input = argv[i];
84 ss << input << std::endl;
94 int currentRun = allConfig.
IntValue(
"General",
"Run",0);
95 int doTraining = allConfig.
IntValue(
"General",
"Do Training",1);
96 std::string forestPath = allConfig.
Value(
"General",
"Forest Path");
97 std::string trainingCollectionPath = allConfig.
Value(
"General",
"Patient Collection");
98 std::string testCollectionPath = trainingCollectionPath;
103 std::vector<std::string> trainPatients = allConfig.
Vector(
"Training Group",currentRun);
104 std::vector<std::string> testPatients = allConfig.
Vector(
"Test Group",currentRun);
105 std::vector<std::string> modalities = allConfig.
Vector(
"Modalities",0);
106 std::string trainMask = allConfig.
Value(
"Data",
"Training Mask");
107 std::string completeTrainMask = allConfig.
Value(
"Data",
"Complete Training Mask");
108 std::string testMask = allConfig.
Value(
"Data",
"Test Mask");
109 std::string resultMask = allConfig.
Value(
"Data",
"Result Mask");
110 std::string resultProb = allConfig.
Value(
"Data",
"Result Propability");
111 std::string outputFolder = allConfig.
Value(
"General",
"Output Folder");
113 std::string writeDataFilePath = allConfig.
Value(
"Forest",
"File to write data to");
118 int minimumSplitNodeSize = allConfig.
IntValue(
"Forest",
"Minimum split node size",1);
119 int numberOfTrees = allConfig.
IntValue(
"Forest",
"Number of Trees",255);
120 double samplesPerTree = atof(allConfig.
Value(
"Forest",
"Samples per Tree").c_str());
121 if (samplesPerTree <= 0.0000001)
123 samplesPerTree = 1.0;
125 MITK_INFO <<
"Samples per Tree: " << samplesPerTree;
126 int sampleWithReplacement = allConfig.
IntValue(
"Forest",
"Sample with replacement",1);
127 double trainPrecision = atof(allConfig.
Value(
"Forest",
"Precision").c_str());
128 if (trainPrecision <= 0.0000000001)
130 trainPrecision = 0.0;
132 double weightLambda = atof(allConfig.
Value(
"Forest",
"Weight Lambda").c_str());
133 if (weightLambda <= 0.0000000001)
137 int maximumTreeDepth = allConfig.
IntValue(
"Forest",
"Maximum Tree Depth",10000);
138 int randomSplit = allConfig.
IntValue(
"Forest",
"Use RandomSplit",0);
142 std::string statisticFilePath = allConfig.
Value(
"Evaluation",
"Statistic output file");
143 std::string statisticShortFilePath = allConfig.
Value(
"Evaluation",
"Statistic short output file");
144 std::string statisticShortFileLabel = allConfig.
Value(
"Evaluation",
"Index for short file");
145 std::string statisticGoldStandard = allConfig.
Value(
"Evaluation",
"Gold Standard Name",
"GTV");
146 bool statisticWithHeader = allConfig.
IntValue(
"Evaluation",
"Write header in short file",0);
147 std::vector<std::string> labelGroupA = allConfig.
Vector(
"LabelsA",0);
148 std::vector<std::string> labelGroupB = allConfig.
Vector(
"LabelsB",0);
152 bool useWeightedPoints = allConfig.
IntValue(
"Forest",
"Use point-based weighting",0);
153 bool writePointsToFile = allConfig.
IntValue(
"Forest",
"Write points to file",0);
154 int importanceWeightAlgorithm = allConfig.
IntValue(
"Forest",
"Importance weight Algorithm",0);
155 std::string importanceWeightName = allConfig.
Value(
"Forest",
"Importance weight name",
"");
157 std::ofstream timingFile;
158 timingFile.open((statisticFilePath +
".timing").c_str(), std::ios::app);
159 timingFile << statisticShortFileLabel <<
";";
160 std::time_t lastTimePoint;
161 time(&lastTimePoint);
166 std::vector<std::string> usedModalities;
167 for (
int i = 0; i < modalities.size(); ++i)
169 usedModalities.push_back(modalities[i]);
171 usedModalities.push_back(trainMask);
172 usedModalities.push_back(completeTrainMask);
173 usedModalities.push_back(testMask);
174 usedModalities.push_back(statisticGoldStandard);
175 usedModalities.push_back(importanceWeightName);
185 trainCollection = colReader->
LoadCollection(trainingCollectionPath);
193 double seconds = std::difftime(now, lastTimePoint);
194 timingFile << seconds <<
";";
195 time(&lastTimePoint);
220 forest->SetSamplesPerTree(samplesPerTree);
221 forest->SetMinimumSplitNodeSize(minimumSplitNodeSize);
222 forest->SetTreeCount(numberOfTrees);
223 forest->UseSampleWithReplacement(sampleWithReplacement);
224 forest->SetPrecision(trainPrecision);
225 forest->SetMaximumTreeDepth(maximumTreeDepth);
226 forest->SetWeightLambda(weightLambda);
242 if (useWeightedPoints)
245 MITK_INFO <<
"Activated Point-based weighting...";
247 forest->UsePointWiseWeight(
true);
308 forest->SetPointWiseWeight(trainDataW);
309 forest->UsePointWiseWeight(
true);
311 forest->Train(trainDataX, trainDataY);
319 seconds = std::difftime(now, lastTimePoint);
320 timingFile << seconds <<
";";
321 time(&lastTimePoint);
337 auto testDataNewY = forest->Predict(testDataX);
338 auto testDataNewProb = forest->GetPointWiseProbabilities();
341 std::vector<std::string> names;
342 names.push_back(
"prob-1");
343 names.push_back(
"prob-2");
353 seconds = std::difftime(now, lastTimePoint);
354 timingFile << seconds <<
";";
355 time(&lastTimePoint);
392 std::vector<std::string> outputFilter;
397 outputFolder +
"/result_collection.xml",
404 std::ofstream statisticFile;
405 statisticFile.open(statisticFilePath.c_str(), std::ios::app);
406 std::ofstream sstatisticFile;
407 sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app);
420 stat.
Print(statisticFile,sstatisticFile,
true, statisticShortFileLabel);
421 statisticFile.close();
425 seconds = std::difftime(now, lastTimePoint);
426 timingFile << seconds << std::endl;
427 time(&lastTimePoint);
430 catch (std::string s)
void SetWeightName(std::string name)
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
void ClearDataElementIds()
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SetTestName(std::string name)
void SetTestValueToIndexMapper(const ValueToIndexMapper *mapper)
void SetGroundTruthValueToIndexMapper(const ValueToIndexMapper *mapper)
int IntValue(const std::string §ion, const std::string &entry) const
void SetTrainMask(std::string name)
void SetMaskName(std::string name)
void SetTestMask(std::string name)
void SetGoldName(std::string name)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void SetModalities(std::vector< std::string > modalities)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void ReadStream(std::istream &stream)
void SetCollection(DataCollection::Pointer data)
void AddDataElementIds(std::vector< std::string > dataElemetIds)
std::vector< std::string > Vector(std::string const §ion, unsigned int index) const
static const char * replace[]
This is a dictionary to replace long names of classes, modules, etc. to shorter versions in the conso...
int main(int argc, char *argv[])
void ReadFile(std::string const &filePath)
std::string Value(std::string const §ion, std::string const &entry) const