26 #include <vtkSmartPointer.h>
39 int main(
int argc,
char* argv[])
43 parser.
setTitle(
"Simple Random Forest Classifier");
52 "DataCollection File");
55 "Patient Identifiers from DataCollection used for training");
57 "Patient Identifier from DataCollection used for testing");
61 "Output file for stats");
67 "human readable name for configuration");
69 "output folder for results");
72 "name of class that is to be learnt");
75 std::map<std::string, us::Any> parsedArgs = parser.
parseArguments(argc, argv);
77 if (parsedArgs.size()==0 || parsedArgs.count(
"help") || parsedArgs.count(
"h")) {
83 bool useStatsFile =
false;
84 unsigned int forestSize = 8;
85 unsigned int treeDepth = 10;
86 std::string configName =
"";
87 std::string outputFolder =
"";
90 std::vector<std::string> trainingIds;
91 std::vector<std::string> testingIds;
92 std::vector<std::string> loadIds;
93 std::string outputFile;
96 std::ofstream experimentFS;
100 if (parsedArgs.count(
"colIds") || parsedArgs.count(
"c")) {
101 std::istringstream ss(us::any_cast<std::string>(parsedArgs[
"colIds"]));
104 while (std::getline(ss, token,
','))
105 trainingIds.push_back(token);
108 if (parsedArgs.count(
"output") || parsedArgs.count(
"o")) {
109 outputFolder =
us::any_cast<std::string>(parsedArgs[
"output"]);
112 if (parsedArgs.count(
"classmap") || parsedArgs.count(
"m")) {
113 classMap =
us::any_cast<std::string>(parsedArgs[
"classmap"]);
116 if (parsedArgs.count(
"configName") || parsedArgs.count(
"n")) {
117 configName =
us::any_cast<std::string>(parsedArgs[
"configName"]);
120 if (parsedArgs.count(
"features") || parsedArgs.count(
"b")) {
121 std::istringstream ss(us::any_cast<std::string>(parsedArgs[
"features"]));
124 while (std::getline(ss, token,
','))
125 features.push_back(token);
128 if (parsedArgs.count(
"treeDepth") || parsedArgs.count(
"d")) {
133 if (parsedArgs.count(
"forestSize") || parsedArgs.count(
"f")) {
134 forestSize =
us::any_cast<
int>(parsedArgs[
"forestSize"]);
137 if (parsedArgs.count(
"stats") || parsedArgs.count(
"s")) {
139 experimentFS.open(us::any_cast<std::string>(parsedArgs[
"stats"]).c_str(),
144 if (parsedArgs.count(
"testId") || parsedArgs.count(
"t")) {
145 std::istringstream ss(us::any_cast<std::string>(parsedArgs[
"testId"]));
148 while (std::getline(ss, token,
','))
149 testingIds.push_back(token);
152 for (
unsigned int i = 0; i < features.size(); i++) {
153 loadIds.push_back(features.at(i));
155 loadIds.push_back(classMap);
157 if (parsedArgs.count(
"stats") || parsedArgs.count(
"s")) {
158 outputFile =
us::any_cast<std::string>(parsedArgs[
"stats"]);
161 if (parsedArgs.count(
"loadFile") || parsedArgs.count(
"l")) {
162 xmlFile =
us::any_cast<std::string>(parsedArgs[
"loadFile"]);
191 forest->SetTreeCount(forestSize);
192 forest->SetMaximumTreeDepth(treeDepth);
200 forest->Train(trainDataX, trainDataY);
205 auto testDataNewY = forest->Predict(testDataX);
210 Eigen::MatrixXd Probs = forest->GetPointWiseProbabilities();
213 Eigen::MatrixXd prob0 = Probs.col(0);
214 Eigen::MatrixXd prob1 = Probs.col(1);
220 std::vector<std::string> outputFilter;
222 outputFolder +
"/result_collection.xml",
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
void ClearDataElementIds()
void SetDataItemNames(std::vector< std::string > itemNames)
void setContributor(std::string contributor)
int main(int argc, char *argv[])
ValueType * any_cast(Any *operand)
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
void setCategory(std::string category)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void AddSubColIds(std::vector< std::string > subColIds)
std::string helpText() const
void setTitle(std::string title)
void setDescription(std::string description)
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.