22 #include <vtkSmartPointer.h> 35 int main(
int argc,
char* argv[])
39 parser.
setTitle(
"Simple Random Forest Classifier");
51 "Patient Identifiers from DataCollection used for training");
53 "Patient Identifier from DataCollection used for testing");
57 "Output file for stats");
63 "human readable name for configuration");
68 "name of class that is to be learnt");
71 std::map<std::string, us::Any> parsedArgs = parser.
parseArguments(argc, argv);
73 if (parsedArgs.size()==0 || parsedArgs.count(
"help") || parsedArgs.count(
"h")) {
79 unsigned int forestSize = 8;
80 unsigned int treeDepth = 10;
81 std::string configName =
"";
82 std::string outputFolder =
"";
85 std::vector<std::string> trainingIds;
86 std::vector<std::string> testingIds;
87 std::vector<std::string> loadIds;
88 std::string outputFile;
91 std::ofstream experimentFS;
95 if (parsedArgs.count(
"colIds") || parsedArgs.count(
"c")) {
96 std::istringstream ss(us::any_cast<std::string>(parsedArgs[
"colIds"]));
99 while (std::getline(ss, token,
','))
100 trainingIds.push_back(token);
103 if (parsedArgs.count(
"output") || parsedArgs.count(
"o")) {
104 outputFolder =
us::any_cast<std::string>(parsedArgs[
"output"]);
107 if (parsedArgs.count(
"classmap") || parsedArgs.count(
"m")) {
108 classMap =
us::any_cast<std::string>(parsedArgs[
"classmap"]);
111 if (parsedArgs.count(
"configName") || parsedArgs.count(
"n")) {
112 configName =
us::any_cast<std::string>(parsedArgs[
"configName"]);
115 if (parsedArgs.count(
"features") || parsedArgs.count(
"b")) {
116 std::istringstream ss(us::any_cast<std::string>(parsedArgs[
"features"]));
119 while (std::getline(ss, token,
','))
120 features.push_back(token);
123 if (parsedArgs.count(
"treeDepth") || parsedArgs.count(
"d")) {
128 if (parsedArgs.count(
"forestSize") || parsedArgs.count(
"f")) {
129 forestSize =
us::any_cast<
int>(parsedArgs[
"forestSize"]);
132 if (parsedArgs.count(
"stats") || parsedArgs.count(
"s")) {
133 experimentFS.open(us::any_cast<std::string>(parsedArgs[
"stats"]).c_str(),
138 if (parsedArgs.count(
"testId") || parsedArgs.count(
"t")) {
139 std::istringstream ss(us::any_cast<std::string>(parsedArgs[
"testId"]));
142 while (std::getline(ss, token,
','))
143 testingIds.push_back(token);
146 for (
unsigned int i = 0; i < features.size(); i++) {
147 loadIds.push_back(features.at(i));
149 loadIds.push_back(classMap);
151 if (parsedArgs.count(
"stats") || parsedArgs.count(
"s")) {
152 outputFile =
us::any_cast<std::string>(parsedArgs[
"stats"]);
155 if (parsedArgs.count(
"loadFile") || parsedArgs.count(
"l")) {
156 xmlFile =
us::any_cast<std::string>(parsedArgs[
"loadFile"]);
185 forest->SetTreeCount(forestSize);
186 forest->SetMaximumTreeDepth(treeDepth);
194 forest->Train(trainDataX, trainDataY);
199 auto testDataNewY = forest->Predict(testDataX);
204 Eigen::MatrixXd Probs = forest->GetPointWiseProbabilities();
207 Eigen::MatrixXd prob0 = Probs.col(0);
208 Eigen::MatrixXd prob1 = Probs.col(1);
214 std::vector<std::string> outputFilter;
216 outputFolder +
"/result_collection.xml",
DataCollection::Pointer LoadCollection(const std::string &xmlFileName)
Build up a mitk::DataCollection from a XML resource.
void ClearDataElementIds()
void SetDataItemNames(std::vector< std::string > itemNames)
void setContributor(std::string contributor)
int main(int argc, char *argv[])
ValueType * any_cast(Any *operand)
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false, mitkCommandLineParser::Channel channel=mitkCommandLineParser::Channel::None)
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
static void MatrixToDC3d(const Eigen::MatrixXd &matrix, mitk::DataCollection::Pointer dc, const std::vector< std::string > &names, std::string mask)
static Eigen::MatrixXi DC3dDToMatrixXi(mitk::DataCollection::Pointer dc, std::string name, std::string mask)
std::string helpText() const
void setCategory(std::string category)
static bool ExportCollectionToFolder(DataCollection *dataCollection, std::string xmlFile, std::vector< std::string > filter)
ExportCollectionToFolder.
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
static Eigen::MatrixXd DC3dDToMatrixXd(mitk::DataCollection::Pointer dc, std::string names, std::string mask)
void AddSubColIds(std::vector< std::string > subColIds)
void setTitle(std::string title)
void setDescription(std::string description)