17 #ifndef _TrackingForestHandler_cpp
18 #define _TrackingForestHandler_cpp
26 template<
int ShOrder,
int NumberOfSignalFeatures >
28 : m_WmSampleDistance(-1)
31 , m_SampleFraction(1.0)
32 , m_NumberOfSamples(0)
33 , m_GmSamplesPerVoxel(50)
37 template<
int ShOrder,
int NumberOfSignalFeatures >
42 template<
int ShOrder,
int NumberOfSignalFeatures >
47 itk::ContinuousIndex< double, 3> cIdx;
48 image->TransformPhysicalPointToIndex(itkP, idx);
49 image->TransformPhysicalPointToContinuousIndex(itkP, cIdx);
52 if ( image->GetLargestPossibleRegion().IsInside(idx) )
53 pix = image->GetPixel(idx);
57 double frac_x = cIdx[0] - idx[0];
58 double frac_y = cIdx[1] - idx[1];
59 double frac_z = cIdx[2] - idx[2];
80 if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 &&
81 idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 &&
82 idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1)
85 vnl_vector_fixed<double, 8> interpWeights;
86 interpWeights[0] = ( frac_x)*( frac_y)*( frac_z);
87 interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z);
88 interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z);
89 interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z);
90 interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z);
91 interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z);
92 interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z);
93 interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z);
95 pix = image->GetPixel(idx) * interpWeights[0];
96 typename InterpolatedRawImageType::IndexType tmpIdx = idx; tmpIdx[0]++;
97 pix += image->GetPixel(tmpIdx) * interpWeights[1];
98 tmpIdx = idx; tmpIdx[1]++;
99 pix += image->GetPixel(tmpIdx) * interpWeights[2];
100 tmpIdx = idx; tmpIdx[2]++;
101 pix += image->GetPixel(tmpIdx) * interpWeights[3];
102 tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++;
103 pix += image->GetPixel(tmpIdx) * interpWeights[4];
104 tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++;
105 pix += image->GetPixel(tmpIdx) * interpWeights[5];
106 tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++;
107 pix += image->GetPixel(tmpIdx) * interpWeights[6];
108 tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++;
109 pix += image->GetPixel(tmpIdx) * interpWeights[7];
115 template<
int ShOrder,
int NumberOfSignalFeatures >
118 if (m_RawData.empty())
119 mitkThrow() <<
"No diffusion-weighted images set!";
120 if (!IsForestValid())
121 mitkThrow() <<
"No or invalid random forest detected!";
124 template<
int ShOrder,
int NumberOfSignalFeatures >
127 InputDataValidForTracking();
129 MITK_INFO <<
"Spherically interpolating raw data and creating feature image ...";
135 filter->SetLambda(0.006);
136 filter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL);
139 vnl_vector_fixed<double,3> ref; ref.fill(0); ref[0]=1;
141 m_DirectionIndices.clear();
142 for (
unsigned int f=0; f<NumberOfSignalFeatures*2; f++)
145 m_DirectionIndices.push_back(f);
149 m_FeatureImage->SetSpacing(filter->GetOutput()->GetSpacing());
150 m_FeatureImage->SetOrigin(filter->GetOutput()->GetOrigin());
151 m_FeatureImage->SetDirection(filter->GetOutput()->GetDirection());
152 m_FeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion());
153 m_FeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion());
154 m_FeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion());
155 m_FeatureImage->Allocate();
158 itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion());
162 for (
unsigned int f=0; f<NumberOfSignalFeatures; f++)
163 pix[f] = it.Get()[m_DirectionIndices.at(f)];
164 m_FeatureImage->SetPixel(it.GetIndex(), pix);
171 template<
int ShOrder,
int NumberOfSignalFeatures >
176 itk::ContinuousIndex< double, 3> cIdx;
177 m_FeatureImage->TransformPhysicalPointToIndex(itkP, idx);
178 m_FeatureImage->TransformPhysicalPointToContinuousIndex(itkP, cIdx);
181 if ( m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) )
182 pix = m_FeatureImage->GetPixel(idx);
186 double frac_x = cIdx[0] - idx[0];
187 double frac_y = cIdx[1] - idx[1];
188 double frac_z = cIdx[2] - idx[2];
209 if (idx[0] >= 0 && idx[0] < m_FeatureImage->GetLargestPossibleRegion().GetSize(0)-1 &&
210 idx[1] >= 0 && idx[1] < m_FeatureImage->GetLargestPossibleRegion().GetSize(1)-1 &&
211 idx[2] >= 0 && idx[2] < m_FeatureImage->GetLargestPossibleRegion().GetSize(2)-1)
214 vnl_vector_fixed<double, 8> interpWeights;
215 interpWeights[0] = ( frac_x)*( frac_y)*( frac_z);
216 interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z);
217 interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z);
218 interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z);
219 interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z);
220 interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z);
221 interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z);
222 interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z);
224 pix = m_FeatureImage->GetPixel(idx) * interpWeights[0];
225 typename FeatureImageType::IndexType tmpIdx = idx; tmpIdx[0]++;
226 pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[1];
227 tmpIdx = idx; tmpIdx[1]++;
228 pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[2];
229 tmpIdx = idx; tmpIdx[2]++;
230 pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[3];
231 tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++;
232 pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[4];
233 tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++;
234 pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[5];
235 tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++;
236 pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[6];
237 tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++;
238 pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[7];
244 template<
int ShOrder,
int NumberOfSignalFeatures >
247 vnl_vector_fixed<double,3> direction; direction.fill(0);
250 m_FeatureImage->TransformPhysicalPointToIndex(pos, idx);
251 if (mask.IsNotNull() && ((mask->GetLargestPossibleRegion().IsInside(idx) && mask->GetPixel(idx)<=0) || !mask->GetLargestPossibleRegion().IsInside(idx)) )
255 vigra::MultiArray<2, double> featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,NumberOfSignalFeatures+3) );
257 for (
unsigned int f=0; f<NumberOfSignalFeatures; f++)
258 featureData(0,f) = featurePixel[f];
262 vnl_vector_fixed<double,3> ref; ref.fill(0); ref[0]=1;
263 for (
unsigned int f=NumberOfSignalFeatures; f<NumberOfSignalFeatures+3; f++)
265 if (dot_product(ref, olddir)<0)
266 featureData(0,f) = -olddir[c];
268 featureData(0,f) = olddir[c];
273 vigra::MultiArray<2, double> probs(vigra::Shape2(1, m_Forest->class_count()));
274 m_Forest->predictProbabilities(featureData, probs);
279 for (
int i=0; i<m_Forest->class_count(); i++)
285 m_Forest->ext_param_.to_classlabel(i, classLabel);
287 if (classLabel<m_DirectionIndices.size())
290 vnl_vector_fixed<double,3> d = m_DirContainer.GetDirection(m_DirectionIndices.at(classLabel));
292 if (olddir.magnitude()>0)
298 double dot = dot_product(d, olddir);
299 if (fabs(dot)>angularThreshold)
303 double w_i = probs(0,i)*fabs(dot);
310 direction += probs(0,i)*d;
315 pNonFib += probs(0,i);
320 if (pNonFib>w && w>0)
330 template<
int ShOrder,
int NumberOfSignalFeatures >
333 m_StartTime = std::chrono::system_clock::now();
334 InputDataValidForTraining();
335 PreprocessInputDataForTraining();
336 CalculateFeaturesForTraining();
338 m_EndTime = std::chrono::system_clock::now();
339 std::chrono::hours hh = std::chrono::duration_cast<std::chrono::hours>(m_EndTime - m_StartTime);
340 std::chrono::minutes mm = std::chrono::duration_cast<std::chrono::minutes>(m_EndTime - m_StartTime);
342 MITK_INFO <<
"Training took " << hh.count() <<
"h and " << mm.count() <<
"m";
345 template<
int ShOrder,
int NumberOfSignalFeatures >
348 if (m_RawData.empty())
349 mitkThrow() <<
"No diffusion-weighted images set!";
350 if (m_Tractograms.empty())
352 if (m_RawData.size()!=m_Tractograms.size())
353 mitkThrow() <<
"Unequal number of diffusion-weighted images and tractograms detected!";
356 template<
int ShOrder,
int NumberOfSignalFeatures >
359 if(m_Forest && m_Forest->tree_count()>0 && m_Forest->feature_count()==(NumberOfSignalFeatures+3))
364 template<
int ShOrder,
int NumberOfSignalFeatures >
369 MITK_INFO <<
"Spherical signal interpolation and sampling ...";
370 for (
unsigned int i=0; i<m_RawData.size(); i++)
375 qballfilter->SetLambda(0.006);
376 qballfilter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL);
377 qballfilter->Update();
379 m_InterpolatedRawImages.push_back(qballfilter->GetOutput());
381 if (i>=m_MaskImages.size())
384 newMask->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() );
385 newMask->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() );
386 newMask->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() );
387 newMask->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
388 newMask->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
389 newMask->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
391 newMask->FillBuffer(1);
392 m_MaskImages.push_back(newMask);
395 if (m_MaskImages.at(i)==
nullptr)
398 m_MaskImages.at(i)->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() );
399 m_MaskImages.at(i)->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() );
400 m_MaskImages.at(i)->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() );
401 m_MaskImages.at(i)->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
402 m_MaskImages.at(i)->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
403 m_MaskImages.at(i)->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
404 m_MaskImages.at(i)->Allocate();
405 m_MaskImages.at(i)->FillBuffer(1);
409 MITK_INFO <<
"Resampling fibers and calculating number of samples ...";
410 m_NumberOfSamples = 0;
411 for (
unsigned int t=0; t<m_Tractograms.size(); t++)
415 if (t<m_WhiteMatterImages.size() && m_WhiteMatterImages.at(t)!=
nullptr)
416 wmmask = m_WhiteMatterImages.at(t);
419 MITK_INFO <<
"No white-matter mask found. Using fiber envelope.";
421 env->SetFiberBundle(m_Tractograms.at(t));
422 env->SetInputImage(mask);
423 env->SetBinaryOutput(
true);
424 env->SetUseImageGeometry(
true);
426 wmmask = env->GetOutput();
427 if (t>=m_WhiteMatterImages.size())
428 m_WhiteMatterImages.push_back(wmmask);
430 m_WhiteMatterImages.at(t) = wmmask;
434 if (m_WmSampleDistance<0)
437 float minSpacing = 1;
438 if(image->GetSpacing()[0]<image->GetSpacing()[1] && image->GetSpacing()[0]<image->GetSpacing()[2])
439 minSpacing = image->GetSpacing()[0];
440 else if (image->GetSpacing()[1] < image->GetSpacing()[2])
441 minSpacing = image->GetSpacing()[1];
443 minSpacing = image->GetSpacing()[2];
444 m_WmSampleDistance = minSpacing*0.5;
447 m_Tractograms.at(t)->ResampleSpline(m_WmSampleDistance);
448 unsigned int wmSamples = m_Tractograms.at(t)->GetNumberOfPoints()-2*m_Tractograms.at(t)->GetNumFibers();
449 m_NumberOfSamples += wmSamples;
450 MITK_INFO <<
"Samples inside of WM: " << wmSamples;
453 itk::ImageRegionConstIterator<ItkUcharImgType> it(wmmask, wmmask->GetLargestPossibleRegion());
457 if (it.Get()==0 && mask->GetPixel(it.GetIndex())>0)
461 MITK_INFO <<
"Non-white matter voxels: " << OUTOFWM;
463 if (m_GmSamplesPerVoxel>0)
465 m_GmSamples.push_back(m_GmSamplesPerVoxel);
466 m_NumberOfSamples += m_GmSamplesPerVoxel*OUTOFWM;
470 m_GmSamples.push_back(0.5+(
double)wmSamples/(
double)OUTOFWM);
471 m_NumberOfSamples += m_GmSamples.back()*OUTOFWM;
472 MITK_INFO <<
"Non-white matter samples per voxel: " << m_GmSamples.back();
476 m_GmSamples.push_back(0);
478 MITK_INFO <<
"Samples outside of WM: " << m_GmSamples.back()*OUTOFWM;
480 MITK_INFO <<
"Number of samples: " << m_NumberOfSamples;
483 template<
int ShOrder,
int NumberOfSignalFeatures >
486 vnl_vector_fixed<double,3> ref; ref.fill(0); ref[0]=1;
488 std::vector< int > directionIndices;
489 for (
unsigned int f=0; f<2*NumberOfSignalFeatures; f++)
492 directionIndices.push_back(f);
495 int numDirectionFeatures = 3;
497 m_FeatureData.reshape( vigra::Shape2(m_NumberOfSamples, NumberOfSignalFeatures+numDirectionFeatures) );
498 m_LabelData.reshape( vigra::Shape2(m_NumberOfSamples,1) );
499 MITK_INFO <<
"Number of features: " << m_FeatureData.shape(1);
502 m_RandGen->SetSeed();
503 MITK_INFO <<
"Creating training data ...";
504 int sampleCounter = 0;
505 for (
unsigned int t=0; t<m_Tractograms.size(); t++)
510 if (t<m_MaskImages.size())
511 mask = m_MaskImages.at(t);
514 itk::ImageRegionConstIterator<ItkUcharImgType> it(wmMask, wmMask->GetLargestPossibleRegion());
517 if (it.Get()==0 && (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0)))
522 for (
unsigned int f=0; f<NumberOfSignalFeatures; f++)
524 m_FeatureData(sampleCounter,f) = pix[directionIndices[f]];
525 if(m_FeatureData(sampleCounter,f)!=m_FeatureData(sampleCounter,f))
526 m_FeatureData(sampleCounter,f) = 0;
528 m_FeatureData(sampleCounter,NumberOfSignalFeatures) = 0;
529 m_FeatureData(sampleCounter,NumberOfSignalFeatures+1) = 0;
530 m_FeatureData(sampleCounter,NumberOfSignalFeatures+2) = 0;
531 m_LabelData(sampleCounter,0) = directionIndices.size();
535 for (
int i=1; i<m_GmSamples.at(t); i++)
537 for (
unsigned int f=0; f<NumberOfSignalFeatures; f++)
539 m_FeatureData(sampleCounter,f) = pix[directionIndices[f]];
540 if(m_FeatureData(sampleCounter,f)!=m_FeatureData(sampleCounter,f))
541 m_FeatureData(sampleCounter,f) = 0;
544 vnl_vector_fixed<double,3> probe;
545 probe[0] = m_RandGen->GetVariate()*2-1;
546 probe[1] = m_RandGen->GetVariate()*2-1;
547 probe[2] = m_RandGen->GetVariate()*2-1;
549 if (dot_product(ref, probe)<0)
551 for (
unsigned int f=NumberOfSignalFeatures; f<NumberOfSignalFeatures+3; f++)
553 m_FeatureData(sampleCounter,f) = probe[c];
556 m_LabelData(sampleCounter,0) = directionIndices.size();
565 vtkSmartPointer< vtkPolyData > polyData = fib->GetFiberPolyData();
566 for (
int i=0; i<fib->GetNumFibers(); i++)
568 vtkCell* cell = polyData->GetCell(i);
569 int numPoints = cell->GetNumberOfPoints();
570 vtkPoints* points = cell->GetPoints();
572 vnl_vector_fixed<double,3> dirOld; dirOld.fill(0.0);
574 for (
int j=0; j<numPoints-1; j++)
577 double* p1 = points->GetPoint(j);
578 itk::Point<float, 3> itkP1;
579 itkP1[0] = p1[0]; itkP1[1] = p1[1]; itkP1[2] = p1[2];
581 vnl_vector_fixed<double,3> dir; dir.fill(0.0);
583 itk::Point<float, 3> itkP2;
584 double* p2 = points->GetPoint(j+1);
585 itkP2[0] = p2[0]; itkP2[1] = p2[1]; itkP2[2] = p2[2];
586 dir[0]=itkP2[0]-itkP1[0];
587 dir[1]=itkP2[1]-itkP1[1];
588 dir[2]=itkP2[2]-itkP1[2];
590 if (dir.magnitude()<0.0001)
596 if (dir[0]!=dir[0] || dir[1]!=dir[1] || dir[2]!=dir[2])
610 for (
unsigned int f=0; f<NumberOfSignalFeatures; f++)
611 m_FeatureData(sampleCounter,f) = pix[directionIndices[f]];
615 if (dot_product(ref, dirOld)<0)
618 for (
unsigned int f=NumberOfSignalFeatures; f<NumberOfSignalFeatures+3; f++)
620 m_FeatureData(sampleCounter,f) = dirOld[c];
626 double m = dir.magnitude();
629 for (
unsigned int f=0; f<NumberOfSignalFeatures; f++)
631 double a = fabs(dot_product(dir, directions.
GetDirection(directionIndices[f])));
634 m_LabelData(sampleCounter,0) = f;
647 template<
int ShOrder,
int NumberOfSignalFeatures >
650 MITK_INFO <<
"Maximum tree depths: " << m_MaxTreeDepth;
651 MITK_INFO <<
"Sample fraction per tree: " << m_SampleFraction;
652 MITK_INFO <<
"Number of trees: " << m_NumTrees;
654 std::vector< std::shared_ptr< vigra::RandomForest<int> > > trees;
656 #pragma omp parallel for
657 for (
int i = 0; i < m_NumTrees; ++i)
659 std::shared_ptr< vigra::RandomForest<int> > lrf = std::make_shared< vigra::RandomForest<int> >();
660 lrf->set_options().use_stratification(vigra::RF_NONE);
661 lrf->set_options().sample_with_replacement(
true);
662 lrf->set_options().samples_per_tree(m_SampleFraction);
663 lrf->set_options().tree_count(1);
664 lrf->set_options().min_split_node_size(5);
665 lrf->ext_param_.max_tree_depth = m_MaxTreeDepth;
667 lrf->learn(m_FeatureData, m_LabelData);
671 MITK_INFO <<
"Tree " << count <<
" finished training.";
672 trees.push_back(lrf);
676 for (
int i = 1; i < m_NumTrees; ++i)
677 trees.at(0)->trees_.push_back(trees.at(i)->trees_[0]);
679 m_Forest = trees.at(0);
680 m_Forest->options_.tree_count_ = m_NumTrees;
684 template<
int ShOrder,
int NumberOfSignalFeatures >
687 MITK_INFO <<
"Saving forest to " << forestFile;
689 vigra::rf_export_HDF5( *m_Forest, forestFile,
"" );
691 MITK_INFO <<
"Forest invalid! Could not be saved.";
692 MITK_INFO <<
"Forest saved successfully.";
695 template<
int ShOrder,
int NumberOfSignalFeatures >
698 MITK_INFO <<
"Loading forest from " << forestFile;
699 m_Forest = std::make_shared< vigra::RandomForest<int> >();
700 vigra::rf_import_HDF5( *m_Forest, forestFile);
itk::SmartPointer< Self > Pointer
void PreprocessInputDataForTraining()
Generate masks if necessary, resample fibers, spherically interpolate raw DWIs.
DataCollection - Class to facilitate loading/accessing structured data.
void LoadForest(std::string forestFile)
Manages random forests for fiber tractography. The preparation of the features from the input data a...
static ImageType::Pointer GetItkVectorImage(Image *image)
void InputDataValidForTracking()
check if raw data is set and tracking forest is valid
This class takes as input one or more reference image (acquired in the absence of diffusion sensitizi...
static vnl_vector_fixed< double, 3 > GetDirection(int i)
vnl_vector_fixed< double, 3 > Classify(itk::Point< double, 3 > &pos, int &candidates, vnl_vector_fixed< double, 3 > &olddir, double angularThreshold, double &w, ItkUcharImgType::Pointer mask=nullptr)
predicts next progression direction at the given position
void CalculateFeaturesForTraining()
Calculate GM and WM features using the interpolated raw data, the WM masks and the fibers...
void SaveForest(std::string forestFile)
float GetReferenceBValue() const
InterpolatedRawImageType::PixelType GetImageValues(itk::Point< float, 3 > itkP, typename InterpolatedRawImageType::Pointer image)
get trilinearly interpolated raw image values at given world position
void InputDataValidForTraining()
Check if everything is there for training (raw datasets, fiber tracts)
void InitForTracking()
calls InputDataValidForTracking() and creates feature images from the raw input DWI ...
GradientDirectionsContainerType::Pointer GetGradientContainer() const
FeatureImageType::PixelType GetFeatureValues(itk::Point< float, 3 > itkP)
get trilinearly interpolated feature values at given world position
void TrainForest()
start training process
bool IsForestValid()
true if forest is not null, has more than 0 trees and the correct number of features (NumberOfSignalF...
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.