16 #ifndef mitkForest_cpp
17 #define mitkForest_cpp
28 #include <vtkSmartPointer.h>
51 int main(
int argc,
char* argv[])
53 MITK_INFO <<
"Starting MITK_Forest Mini-App";
54 double startTime = time(0);
63 for (
int i = 0; i < argc; ++i )
68 if (argv[i][0] ==
'+')
78 catch (std::exception &e)
86 std::string input = argv[i];
88 ss << input << std::endl;
98 int currentRun = allConfig.
IntValue(
"General",
"Run",0);
99 int doTraining = allConfig.
IntValue(
"General",
"Do Training",1);
100 std::string forestPath = allConfig.
Value(
"General",
"Forest Path");
101 std::string trainingCollectionPath = allConfig.
Value(
"General",
"Patient Collection");
102 std::string testCollectionPath = trainingCollectionPath;
103 MITK_INFO <<
"Training collection: " << trainingCollectionPath;
108 std::vector<std::string> trainPatients = allConfig.
Vector(
"Training Group",currentRun);
109 std::vector<std::string> testPatients = allConfig.
Vector(
"Test Group",currentRun);
110 std::vector<std::string> modalities = allConfig.
Vector(
"Modalities",0);
111 std::string trainMask = allConfig.
Value(
"Data",
"Training Mask");
112 std::string completeTrainMask = allConfig.
Value(
"Data",
"Complete Training Mask");
113 std::string testMask = allConfig.
Value(
"Data",
"Test Mask");
114 std::string resultMask = allConfig.
Value(
"Data",
"Result Mask");
115 std::string resultProb = allConfig.
Value(
"Data",
"Result Propability");
116 std::string outputFolder = allConfig.
Value(
"General",
"Output Folder");
118 std::string writeDataFilePath = allConfig.
Value(
"Forest",
"File to write data to");
123 int minimumSplitNodeSize = allConfig.
IntValue(
"Forest",
"Minimum split node size",1);
124 int numberOfTrees = allConfig.
IntValue(
"Forest",
"Number of Trees",255);
125 double samplesPerTree = atof(allConfig.
Value(
"Forest",
"Samples per Tree").c_str());
126 if (samplesPerTree <= 0.0000001)
128 samplesPerTree = 1.0;
130 MITK_INFO <<
"Samples per Tree: " << samplesPerTree;
131 int sampleWithReplacement = allConfig.
IntValue(
"Forest",
"Sample with replacement",1);
132 double trainPrecision = atof(allConfig.
Value(
"Forest",
"Precision").c_str());
133 if (trainPrecision <= 0.0000000001)
135 trainPrecision = 0.0;
137 double weightLambda = atof(allConfig.
Value(
"Forest",
"Weight Lambda").c_str());
138 if (weightLambda <= 0.0000000001)
142 int maximumTreeDepth = allConfig.
IntValue(
"Forest",
"Maximum Tree Depth",10000);
143 int randomSplit = allConfig.
IntValue(
"Forest",
"Use RandomSplit",0);
147 std::string statisticFilePath = allConfig.
Value(
"Evaluation",
"Statistic output file");
148 std::string statisticShortFilePath = allConfig.
Value(
"Evaluation",
"Statistic short output file");
149 std::string statisticShortFileLabel = allConfig.
Value(
"Evaluation",
"Index for short file");
150 std::string statisticGoldStandard = allConfig.
Value(
"Evaluation",
"Gold Standard Name",
"GTV");
151 bool statisticWithHeader = allConfig.
IntValue(
"Evaluation",
"Write header in short file",0);
152 std::vector<std::string> labelGroupA = allConfig.
Vector(
"LabelsA",0);
153 std::vector<std::string> labelGroupB = allConfig.
Vector(
"LabelsB",0);
157 bool useWeightedPoints = allConfig.
IntValue(
"Forest",
"Use point-based weighting",0);
158 bool writePointsToFile = allConfig.
IntValue(
"Forest",
"Write points to file",0);
159 int importanceWeightAlgorithm = allConfig.
IntValue(
"Forest",
"Importance weight Algorithm",0);
160 std::string importanceWeightName = allConfig.
Value(
"Forest",
"Importance weight name",
"");
162 std::ofstream timingFile;
163 timingFile.open((statisticFilePath +
".timing").c_str(), std::ios::app);
164 timingFile << statisticShortFileLabel <<
";";
165 std::time_t lastTimePoint;
166 time(&lastTimePoint);
171 std::vector<std::string> usedModalities;
172 for (
int i = 0; i < modalities.size(); ++i)
174 usedModalities.push_back(modalities[i]);
176 usedModalities.push_back(trainMask);
177 usedModalities.push_back(completeTrainMask);
178 usedModalities.push_back(testMask);
179 usedModalities.push_back(statisticGoldStandard);
180 usedModalities.push_back(importanceWeightName);
190 trainCollection = colReader->
LoadCollection(trainingCollectionPath);
198 double seconds = std::difftime(now, lastTimePoint);
199 timingFile << seconds <<
";";
200 time(&lastTimePoint);
225 forest->SetSamplesPerTree(samplesPerTree);
226 forest->SetMinimumSplitNodeSize(minimumSplitNodeSize);
227 forest->SetTreeCount(numberOfTrees);
228 forest->UseSampleWithReplacement(sampleWithReplacement);
229 forest->SetPrecision(trainPrecision);
230 forest->SetMaximumTreeDepth(maximumTreeDepth);
231 forest->SetWeightLambda(weightLambda);
250 MITK_INFO <<
"Activated Point-based weighting...";
252 forest->UsePointWiseWeight(
true);
313 forest->SetPointWiseWeight(trainDataW);
314 forest->UsePointWiseWeight(
true);
316 forest->Train(trainDataX, trainDataY);
324 seconds = std::difftime(now, lastTimePoint);
325 timingFile << seconds <<
";";
326 time(&lastTimePoint);
342 auto testDataNewY = forest->Predict(testDataX);
352 seconds = std::difftime(now, lastTimePoint);
353 timingFile << seconds <<
";";
354 time(&lastTimePoint);
391 std::vector<std::string> outputFilter;
396 outputFolder +
"/result_collection.xml",
403 std::ofstream statisticFile;
404 statisticFile.open(statisticFilePath.c_str(), std::ios::app);
405 std::ofstream sstatisticFile;
406 sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app);
416 stat.
Print(statisticFile,sstatisticFile,
true, statisticShortFileLabel);
417 statisticFile.close();
420 seconds = std::difftime(now, lastTimePoint);
421 timingFile << seconds << std::endl;
422 time(&lastTimePoint);
425 catch (std::string s)
std::vector< std::string > Vector(std::string const §ion, unsigned int index) const
void SetWeightName(std::string name)
std::string Value(std::string const §ion, std::string const &entry) const
static void Save(const mitk::BaseData *data, const std::string &path)
Save a mitk::BaseData instance.
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
void ClearDataElementIds()
void SetDataItemNames(std::vector< std::string > itemNames)
void Print(std::ostream &out, std::ostream &sout=std::cout, bool withHeader=false, std::string label="None")
void SetTestName(std::string name)
void SetTrainMask(std::string name)
void SetMaskName(std::string name)
void SetTestMask(std::string name)
void SetGoldName(std::string name)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void SetCollection(DataCollection::Pointer collection)
void SetClassCount(vcl_size_t count)
void SetModalities(std::vector< std::string > modalities)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void ReadStream(std::istream &stream)
int main(int argc, char *argv[])
void SetCollection(DataCollection::Pointer data)
int IntValue(const std::string §ion, const std::string &entry) const
void AddDataElementIds(std::vector< std::string > dataElemetIds)
static const char * replace[]
This is a dictionary to replace long names of classes, modules, etc. to shorter versions in the conso...
void ReadFile(std::string const &filePath)
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.