Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
DFTraining.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
// _USE_MATH_DEFINES must be defined before <math.h>/<cmath> is first included
// (possibly indirectly through one of the headers below), otherwise MSVC does
// not declare M_PI and the other math constants.
#define _USE_MATH_DEFINES
#include <math.h>

#include <mitkBaseData.h>
#include <mitkImageCast.h>
#include <mitkImageToItk.h>
#include <metaCommand.h>
#include "mitkCommandLineParser.h"
#include <usAny.h>
#include <itkImageFileWriter.h>
#include <mitkIOUtil.h>
#include <itksys/SystemTools.hxx>
#include <mitkCoreObjectFactory.h>
#include <mitkFiberBundle.h>
#include <mitkTrackingForestHandler.h>

#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>
35 
36 using namespace std;
37 
41 int main(int argc, char* argv[])
42 {
43  MITK_INFO << "DFTraining";
44  mitkCommandLineParser parser;
45 
46  parser.setTitle("Training for Machine Learning Based Streamline Tractography");
47  parser.setCategory("Fiber Tracking and Processing Methods");
48  parser.setDescription("Train random forest classifier for machine learning based streamline tractography");
49  parser.setContributor("MBI");
50 
51  parser.setArgumentPrefix("--", "-");
52  parser.addArgument("images", "i", mitkCommandLineParser::StringList, "DWIs:", "input diffusion-weighted images", us::Any(), false);
53  parser.addArgument("tractograms", "t", mitkCommandLineParser::StringList, "Tractograms:", "input training tractograms (.fib, vtk ascii file format)", us::Any(), false);
54  parser.addArgument("forest", "f", mitkCommandLineParser::OutputFile, "Forest:", "output random forest (HDF5)", us::Any(), false);
55 
56  parser.addArgument("masks", "m", mitkCommandLineParser::StringList, "Masks:", "restrict trining using a binary mask image", us::Any());
57  parser.addArgument("wmmasks", "w", mitkCommandLineParser::StringList, "WM-Masks:", "if no binary white matter mask is specified, the envelope of the input tractogram is used", us::Any());
58 
59  parser.addArgument("stepsize", "s", mitkCommandLineParser::Float, "Stepsize:", "resampling parameter for the input tractogram in mm (determines number of white-matter samples)", us::Any());
60  parser.addArgument("gmsamples", "g", mitkCommandLineParser::Int, "Number of gray matter samples per voxel:", "Number of gray matter samples per voxel", us::Any());
61  parser.addArgument("numtrees", "n", mitkCommandLineParser::Int, "Number of trees:", "number of trees", us::Any());
62  parser.addArgument("max_tree_depth", "d", mitkCommandLineParser::Int, "Max. tree depth:", "maximum tree depth", us::Any());
63  parser.addArgument("sample_fraction", "sf", mitkCommandLineParser::Float, "Sample fraction:", "fraction of samples used per tree", us::Any());
64 
65  map<string, us::Any> parsedArgs = parser.parseArguments(argc, argv);
66  if (parsedArgs.size()==0)
67  return EXIT_FAILURE;
68 
71  if (parsedArgs.count("wmmasks"))
72  wmMaskFiles = us::any_cast<mitkCommandLineParser::StringContainerType>(parsedArgs["wmmasks"]);
73 
75  if (parsedArgs.count("masks"))
76  maskFiles = us::any_cast<mitkCommandLineParser::StringContainerType>(parsedArgs["masks"]);
77 
78  string forestFile = us::any_cast<string>(parsedArgs["forest"]);
79 
81  if (parsedArgs.count("tractograms"))
82  tractogramFiles = us::any_cast<mitkCommandLineParser::StringContainerType>(parsedArgs["tractograms"]);
83 
84  int numTrees = 50;
85  if (parsedArgs.count("numtrees"))
86  numTrees = us::any_cast<int>(parsedArgs["numtrees"]);
87 
88  int gmsamples = -1;
89  if (parsedArgs.count("gmsamples"))
90  gmsamples = us::any_cast<int>(parsedArgs["gmsamples"]);
91 
92  float stepsize = -1;
93  if (parsedArgs.count("stepsize"))
94  stepsize = us::any_cast<float>(parsedArgs["stepsize"]);
95 
96  int max_tree_depth = 25;
97  if (parsedArgs.count("max_tree_depth"))
98  max_tree_depth = us::any_cast<int>(parsedArgs["max_tree_depth"]);
99 
100  double sample_fraction = 0.6;
101  if (parsedArgs.count("sample_fraction"))
102  sample_fraction = us::any_cast<float>(parsedArgs["sample_fraction"]);
103 
104 
105  MITK_INFO << "loading diffusion-weighted images";
106  std::vector< mitk::Image::Pointer > rawData;
107  for (auto imgFile : imageFiles)
108  {
109  mitk::Image::Pointer dwi = dynamic_cast<mitk::Image*>(mitk::IOUtil::LoadImage(imgFile).GetPointer());
110  rawData.push_back(dwi);
111  }
112 
113  typedef itk::Image<unsigned char, 3> ItkUcharImgType;
114  MITK_INFO << "loading mask images";
115  std::vector< ItkUcharImgType::Pointer > maskImageVector;
116  for (auto maskFile : maskFiles)
117  {
118  mitk::Image::Pointer img = dynamic_cast<mitk::Image*>(mitk::IOUtil::LoadImage(maskFile).GetPointer());
120  mitk::CastToItkImage(img, mask);
121  maskImageVector.push_back(mask);
122  }
123 
124  MITK_INFO << "loading white matter mask images";
125  std::vector< ItkUcharImgType::Pointer > wmMaskImageVector;
126  for (auto wmFile : wmMaskFiles)
127  {
128  mitk::Image::Pointer img = dynamic_cast<mitk::Image*>(mitk::IOUtil::LoadImage(wmFile).GetPointer());
130  mitk::CastToItkImage(img, wmmask);
131  wmMaskImageVector.push_back(wmmask);
132  }
133 
134  MITK_INFO << "loading tractograms";
135  std::vector< mitk::FiberBundle::Pointer > tractograms;
136  for (auto tractFile : tractogramFiles)
137  {
138  mitk::FiberBundle::Pointer fib = dynamic_cast<mitk::FiberBundle*>(mitk::IOUtil::Load(tractFile).at(0).GetPointer());
139  tractograms.push_back(fib);
140  }
141 
142  mitk::TrackingForestHandler<> forestHandler;
143  forestHandler.SetRawData(rawData);
144  forestHandler.SetMaskImages(maskImageVector);
145  forestHandler.SetWhiteMatterImages(wmMaskImageVector);
146  forestHandler.SetTractograms(tractograms);
147  forestHandler.SetNumTrees(numTrees);
148  forestHandler.SetMaxTreeDepth(max_tree_depth);
149  forestHandler.SetGrayMatterSamplesPerVoxel(gmsamples);
150  forestHandler.SetSampleFraction(sample_fraction);
151  forestHandler.SetStepSize(stepsize);
152  forestHandler.StartTraining();
153  forestHandler.SaveForest(forestFile);
154 
155  return EXIT_SUCCESS;
156 }
itk::SmartPointer< Self > Pointer
int main(int argc, char *argv[])
Train random forest classifier for machine learning based streamline tractography.
Definition: DFTraining.cpp:41
#define MITK_INFO
Definition: mitkLogMacros.h:22
void SetGrayMatterSamplesPerVoxel(int samples)
void SetTractograms(std::vector< FiberBundle::Pointer > tractograms)
void setContributor(std::string contributor)
STL namespace.
ValueType * any_cast(Any *operand)
Definition: usAny.h:377
std::map< std::string, us::Any > parseArguments(const StringContainerType &arguments, bool *ok=nullptr)
void SetSampleFraction(double fraction)
void SetRawData(std::vector< Image::Pointer > images)
Manages random forests for fiber tractography. The preparation of the features from the input data a...
void addArgument(const std::string &longarg, const std::string &shortarg, Type type, const std::string &argLabel, const std::string &argHelp=std::string(), const us::Any &defaultValue=us::Any(), bool optional=true, bool ignoreRest=false, bool deprecated=false)
void SetWhiteMatterImages(std::vector< ItkUcharImgType::Pointer > images)
Image class for storing images.
Definition: mitkImage.h:76
Definition: usAny.h:163
Base class for fiber bundles.
void setCategory(std::string category)
void setArgumentPrefix(const std::string &longPrefix, const std::string &shortPrefix)
void SaveForest(std::string forestFile)
void MITKCORE_EXPORT CastToItkImage(const mitk::Image *mitkImage, itk::SmartPointer< ItkOutputImageType > &itkOutputImage)
Cast an mitk::Image to an itk::Image with a specific type.
std::vector< std::string > StringContainerType
static DataStorage::SetOfObjects::Pointer Load(const std::string &path, DataStorage &storage)
Load a file into the given DataStorage.
Definition: mitkIOUtil.cpp:483
void setTitle(std::string title)
void setDescription(std::string description)
void SetMaskImages(std::vector< ItkUcharImgType::Pointer > images)
static mitk::Image::Pointer LoadImage(const std::string &path)
LoadImage Convenience method to load an arbitrary mitkImage.
Definition: mitkIOUtil.cpp:597
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.