Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
mitkTrackingForestHandler.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 #ifndef _TrackingForestHandler_cpp
18 #define _TrackingForestHandler_cpp
19 
23 
24 namespace mitk
25 {
26  template< int ShOrder, int NumberOfSignalFeatures >
// Default constructor: initializes all tunable parameters to their defaults.
// A negative m_WmSampleDistance means "derive the sampling distance from the
// voxel spacing later" (see PreprocessInputDataForTraining, embedded lines 434-445).
// NOTE(review): the constructor signature (original line 27) is missing from
// this extraction — presumably TrackingForestHandler<...>::TrackingForestHandler().
28  : m_WmSampleDistance(-1)
29  , m_NumTrees(30)
30  , m_MaxTreeDepth(50)
31  , m_SampleFraction(1.0)
32  , m_NumberOfSamples(0)
33  , m_GmSamplesPerVoxel(50)
34  {
35  }
36 
37  template< int ShOrder, int NumberOfSignalFeatures >
// Destructor — intentionally empty; members are smart pointers / value types.
// NOTE(review): the signature line (original line 38) is missing from this extraction.
39  {
40  }
41 
42  template< int ShOrder, int NumberOfSignalFeatures >
// Returns the trilinearly interpolated pixel of 'image' at world position
// 'itkP'. Falls back to nearest-neighbor (the pixel at 'idx') when the
// position lies on the image border, and to a zero-filled pixel when the
// position is outside the image entirely.
// NOTE(review): the signature line (original line 43) is missing from this
// extraction; per the class member list it is
// GetImageValues(itk::Point<float,3> itkP, typename InterpolatedRawImageType::Pointer image).
44  {
45  // transform physical point to index coordinates
46  itk::Index<3> idx;
47  itk::ContinuousIndex< double, 3> cIdx;
48  image->TransformPhysicalPointToIndex(itkP, idx);
49  image->TransformPhysicalPointToContinuousIndex(itkP, cIdx);
50 
51  typename InterpolatedRawImageType::PixelType pix; pix.Fill(0.0);
52  if ( image->GetLargestPossibleRegion().IsInside(idx) )
53  pix = image->GetPixel(idx);
54  else
55  return pix;
56 
// fractional offsets of the continuous index within the voxel at 'idx';
// shifted so that idx always addresses the lower corner of the interpolation cell
57  double frac_x = cIdx[0] - idx[0];
58  double frac_y = cIdx[1] - idx[1];
59  double frac_z = cIdx[2] - idx[2];
60  if (frac_x<0)
61  {
62  idx[0] -= 1;
63  frac_x += 1;
64  }
65  if (frac_y<0)
66  {
67  idx[1] -= 1;
68  frac_y += 1;
69  }
70  if (frac_z<0)
71  {
72  idx[2] -= 1;
73  frac_z += 1;
74  }
// invert fractions so that weight [0] belongs to the lower-corner voxel
75  frac_x = 1-frac_x;
76  frac_y = 1-frac_y;
77  frac_z = 1-frac_z;
78 
79  // int coordinates inside image?
// NOTE(review): idx is signed while GetSize() is unsigned — the comparison
// mixes signedness; safe here only because idx >= 0 is checked first.
80  if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 &&
81  idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 &&
82  idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1)
83  {
84  // trilinear interpolation
85  vnl_vector_fixed<double, 8> interpWeights;
86  interpWeights[0] = ( frac_x)*( frac_y)*( frac_z);
87  interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z);
88  interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z);
89  interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z);
90  interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z);
91  interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z);
92  interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z);
93  interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z);
94 
// accumulate the weighted contributions of the 8 corner voxels of the cell
95  pix = image->GetPixel(idx) * interpWeights[0];
96  typename InterpolatedRawImageType::IndexType tmpIdx = idx; tmpIdx[0]++;
97  pix += image->GetPixel(tmpIdx) * interpWeights[1];
98  tmpIdx = idx; tmpIdx[1]++;
99  pix += image->GetPixel(tmpIdx) * interpWeights[2];
100  tmpIdx = idx; tmpIdx[2]++;
101  pix += image->GetPixel(tmpIdx) * interpWeights[3];
102  tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++;
103  pix += image->GetPixel(tmpIdx) * interpWeights[4];
104  tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++;
105  pix += image->GetPixel(tmpIdx) * interpWeights[5];
106  tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++;
107  pix += image->GetPixel(tmpIdx) * interpWeights[6];
108  tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++;
109  pix += image->GetPixel(tmpIdx) * interpWeights[7];
110  }
111 
112  return pix;
113  }
114 
115  template< int ShOrder, int NumberOfSignalFeatures >
// Precondition check for tracking: raw DWI data must be set and the random
// forest must be loaded/trained and structurally valid (see IsForestValid()).
// Throws a mitk::Exception on failure.
// NOTE(review): the signature line (original line 116) is missing from this extraction.
117  {
118  if (m_RawData.empty())
119  mitkThrow() << "No diffusion-weighted images set!";
120  if (!IsForestValid())
121  mitkThrow() << "No or invalid random forest detected!";
122  }
123 
124  template< int ShOrder, int NumberOfSignalFeatures >
// Prepares tracking: validates the input, spherically interpolates the raw
// DWI signal of the first dataset, selects the odf directions lying on one
// hemisphere (m_DirectionIndices) and copies the corresponding signal values
// into m_FeatureImage.
// NOTE(review): original lines 125 (signature), 130/132 (declaration of
// 'filter') and 140 (declaration of 'odf') are missing from this extraction.
126  {
127  InputDataValidForTracking();
128 
129  MITK_INFO << "Spherically interpolating raw data and creating feature image ...";
131 
133  filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_RawData.at(0)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_RawData.at(0)) );
134  filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_RawData.at(0)));
135  filter->SetLambda(0.006);
136  filter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL);
137  filter->Update();
138 
// reference direction (+x); used to pick exactly one representative of each
// antipodal direction pair
139  vnl_vector_fixed<double,3> ref; ref.fill(0); ref[0]=1;
141  m_DirectionIndices.clear();
142  for (unsigned int f=0; f<NumberOfSignalFeatures*2; f++)
143  {
144  if (dot_product(ref, odf.GetDirection(f))>0) // only used directions on one hemisphere
145  m_DirectionIndices.push_back(f); // store indices for later mapping the classifier output to the actual direction
146  }
147 
// allocate the feature image with the same geometry as the interpolated signal
148  m_FeatureImage = FeatureImageType::New();
149  m_FeatureImage->SetSpacing(filter->GetOutput()->GetSpacing());
150  m_FeatureImage->SetOrigin(filter->GetOutput()->GetOrigin());
151  m_FeatureImage->SetDirection(filter->GetOutput()->GetDirection());
152  m_FeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion());
153  m_FeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion());
154  m_FeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion());
155  m_FeatureImage->Allocate();
156 
157  // get signal values and store them in the feature image
158  itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion());
159  while(!it.IsAtEnd())
160  {
161  typename FeatureImageType::PixelType pix;
162  for (unsigned int f=0; f<NumberOfSignalFeatures; f++)
163  pix[f] = it.Get()[m_DirectionIndices.at(f)];
164  m_FeatureImage->SetPixel(it.GetIndex(), pix);
165  ++it;
166  }
167 
168  //m_Forest->multithreadPrediction = false;
169  }
170 
171  template< int ShOrder, int NumberOfSignalFeatures >
// Returns the trilinearly interpolated feature-image pixel at world position
// 'itkP'. Same algorithm as GetImageValues() but applied to m_FeatureImage.
// TODO(review): this duplicates the interpolation logic of GetImageValues();
// consider sharing one templated helper.
// NOTE(review): the signature line (original line 172) is missing from this
// extraction; per the class member list it is
// FeatureImageType::PixelType GetFeatureValues(itk::Point<float,3> itkP).
173  {
174  // transform physical point to index coordinates
175  itk::Index<3> idx;
176  itk::ContinuousIndex< double, 3> cIdx;
177  m_FeatureImage->TransformPhysicalPointToIndex(itkP, idx);
178  m_FeatureImage->TransformPhysicalPointToContinuousIndex(itkP, cIdx);
179 
180  typename FeatureImageType::PixelType pix; pix.Fill(0.0);
181  if ( m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) )
182  pix = m_FeatureImage->GetPixel(idx);
183  else
184  return pix;
185 
// fractional position within the voxel; shift idx to the lower cell corner
186  double frac_x = cIdx[0] - idx[0];
187  double frac_y = cIdx[1] - idx[1];
188  double frac_z = cIdx[2] - idx[2];
189  if (frac_x<0)
190  {
191  idx[0] -= 1;
192  frac_x += 1;
193  }
194  if (frac_y<0)
195  {
196  idx[1] -= 1;
197  frac_y += 1;
198  }
199  if (frac_z<0)
200  {
201  idx[2] -= 1;
202  frac_z += 1;
203  }
204  frac_x = 1-frac_x;
205  frac_y = 1-frac_y;
206  frac_z = 1-frac_z;
207 
208  // int coordinates inside image?
209  if (idx[0] >= 0 && idx[0] < m_FeatureImage->GetLargestPossibleRegion().GetSize(0)-1 &&
210  idx[1] >= 0 && idx[1] < m_FeatureImage->GetLargestPossibleRegion().GetSize(1)-1 &&
211  idx[2] >= 0 && idx[2] < m_FeatureImage->GetLargestPossibleRegion().GetSize(2)-1)
212  {
213  // trilinear interpolation
214  vnl_vector_fixed<double, 8> interpWeights;
215  interpWeights[0] = ( frac_x)*( frac_y)*( frac_z);
216  interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z);
217  interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z);
218  interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z);
219  interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z);
220  interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z);
221  interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z);
222  interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z);
223 
// weighted sum over the 8 corner voxels of the interpolation cell
224  pix = m_FeatureImage->GetPixel(idx) * interpWeights[0];
225  typename FeatureImageType::IndexType tmpIdx = idx; tmpIdx[0]++;
226  pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[1];
227  tmpIdx = idx; tmpIdx[1]++;
228  pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[2];
229  tmpIdx = idx; tmpIdx[2]++;
230  pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[3];
231  tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++;
232  pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[4];
233  tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++;
234  pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[5];
235  tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++;
236  pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[6];
237  tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++;
238  pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[7];
239  }
240 
241  return pix;
242  }
243 
244  template< int ShOrder, int NumberOfSignalFeatures >
245  vnl_vector_fixed<double,3> TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::Classify(itk::Point<double, 3>& pos, int& candidates, vnl_vector_fixed<double,3>& olddir, double angularThreshold, double& w, ItkUcharImgType::Pointer mask)
246  {
247  vnl_vector_fixed<double,3> direction; direction.fill(0);
248 
249  itk::Index<3> idx;
250  m_FeatureImage->TransformPhysicalPointToIndex(pos, idx);
251  if (mask.IsNotNull() && ((mask->GetLargestPossibleRegion().IsInside(idx) && mask->GetPixel(idx)<=0) || !mask->GetLargestPossibleRegion().IsInside(idx)) )
252  return direction;
253 
254  // store feature pixel values in a vigra data type
255  vigra::MultiArray<2, double> featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,NumberOfSignalFeatures+3) );
256  typename FeatureImageType::PixelType featurePixel = GetFeatureValues(pos);
257  for (unsigned int f=0; f<NumberOfSignalFeatures; f++)
258  featureData(0,f) = featurePixel[f];
259 
260  // append normalized previous direction to feature vector
261  int c = 0;
262  vnl_vector_fixed<double,3> ref; ref.fill(0); ref[0]=1;
263  for (unsigned int f=NumberOfSignalFeatures; f<NumberOfSignalFeatures+3; f++)
264  {
265  if (dot_product(ref, olddir)<0)
266  featureData(0,f) = -olddir[c];
267  else
268  featureData(0,f) = olddir[c];
269  c++;
270  }
271 
272  // perform classification
273  vigra::MultiArray<2, double> probs(vigra::Shape2(1, m_Forest->class_count()));
274  m_Forest->predictProbabilities(featureData, probs);
275 
276  double pNonFib = 0; // probability that we left the white matter
277  w = 0; // weight of the predicted direction
278  candidates = 0; // directions with probability > 0
279  for (int i=0; i<m_Forest->class_count(); i++) // for each class (number of possible directions + out-of-wm class)
280  {
281  if (probs(0,i)>0) // if probability of respective class is 0, do nothing
282  {
283  // get label of class (does not correspond to the loop variable i)
284  int classLabel = 0;
285  m_Forest->ext_param_.to_classlabel(i, classLabel);
286 
287  if (classLabel<m_DirectionIndices.size()) // does class label correspond to a direction or to the out-of-wm class?
288  {
289  candidates++; // now we have one direction more with probability > 0 (DO WE NEED THIS???)
290  vnl_vector_fixed<double,3> d = m_DirContainer.GetDirection(m_DirectionIndices.at(classLabel)); // get direction vector assiciated with the respective direction index
291 
292  if (olddir.magnitude()>0) // do we have a previous streamline direction or did we just start?
293  {
294  // TODO: check if hard curvature threshold is necessary.
295  // alternatively try square of dot pruduct as weight.
296  // TODO: check if additional weighting with dot product as directional prior is necessary. are there alternatives on the classification level?
297 
298  double dot = dot_product(d, olddir); // claculate angle between the candidate direction vector and the previous streamline direction
299  if (fabs(dot)>angularThreshold) // is angle between the directions smaller than our hard threshold?
300  {
301  if (dot<0) // make sure we don't walk backwards
302  d *= -1;
303  double w_i = probs(0,i)*fabs(dot);
304  direction += w_i*d; // weight contribution to output direction with its probability and the angular deviation from the previous direction
305  w += w_i; // increase output weight of the final direction
306  }
307  }
308  else
309  {
310  direction += probs(0,i)*d;
311  w += probs(0,i);
312  }
313  }
314  else
315  pNonFib += probs(0,i); // probability that we are not in the whte matter anymore
316  }
317  }
318 
319  // if we did not find a suitable direction, make sure that we return (0,0,0)
320  if (pNonFib>w && w>0)
321  {
322  candidates = 0;
323  w = 0;
324  direction.fill(0.0);
325  }
326 
327  return direction;
328  }
329 
330  template< int ShOrder, int NumberOfSignalFeatures >
// Runs the full training pipeline (validation, preprocessing, feature
// extraction, forest training) and logs the wall-clock duration.
// NOTE(review): the signature line (original line 331) is missing from this extraction.
332  {
333  m_StartTime = std::chrono::system_clock::now();
334  InputDataValidForTraining();
335  PreprocessInputDataForTraining();
336  CalculateFeaturesForTraining();
337  TrainForest();
338  m_EndTime = std::chrono::system_clock::now();
339  std::chrono::hours hh = std::chrono::duration_cast<std::chrono::hours>(m_EndTime - m_StartTime);
340  std::chrono::minutes mm = std::chrono::duration_cast<std::chrono::minutes>(m_EndTime - m_StartTime);
// keep only the minutes that are not already covered by the whole hours
341  mm %= 60;
342  MITK_INFO << "Training took " << hh.count() << "h and " << mm.count() << "m";
343  }
344 
345  template< int ShOrder, int NumberOfSignalFeatures >
// Precondition check for training: raw DWIs and tractograms must be present
// and paired one-to-one. Throws a mitk::Exception on failure.
// NOTE(review): the signature line (original line 346) is missing from this extraction.
347  {
348  if (m_RawData.empty())
349  mitkThrow() << "No diffusion-weighted images set!";
350  if (m_Tractograms.empty())
351  mitkThrow() << "No tractograms set!";
352  if (m_RawData.size()!=m_Tractograms.size())
353  mitkThrow() << "Unequal number of diffusion-weighted images and tractograms detected!";
354  }
355 
356  template< int ShOrder, int NumberOfSignalFeatures >
// Returns true if the forest exists, contains at least one tree, and was
// trained with the expected feature count (signal features + 3 direction features).
// NOTE(review): the signature line (original line 357) is missing from this extraction.
358  {
359  if(m_Forest && m_Forest->tree_count()>0 && m_Forest->feature_count()==(NumberOfSignalFeatures+3))
360  return true;
361  return false;
362  }
363 
364  template< int ShOrder, int NumberOfSignalFeatures >
// Preprocessing for training: spherically interpolates each raw DWI, creates
// all-ones brain masks where none were supplied, derives white-matter masks
// from fiber envelopes where missing, resamples the tractograms, and computes
// the total number of training samples (m_NumberOfSamples) plus the per-dataset
// gray-matter sample counts (m_GmSamples).
// NOTE(review): original lines 365 (signature), 372 (declaration of
// 'qballfilter'), 383 ('newMask'), 414 ('wmmask') and 420 ('env') are missing
// from this extraction.
366  {
368 
369  MITK_INFO << "Spherical signal interpolation and sampling ...";
370  for (unsigned int i=0; i<m_RawData.size(); i++)
371  {
373  qballfilter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_RawData.at(i)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_RawData.at(i)) );
374  qballfilter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_RawData.at(i)));
375  qballfilter->SetLambda(0.006);
376  qballfilter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL);
377  qballfilter->Update();
378  // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage();
379  m_InterpolatedRawImages.push_back(qballfilter->GetOutput());
380 
// no mask supplied for this dataset: create an all-ones mask with the same geometry
381  if (i>=m_MaskImages.size())
382  {
384  newMask->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() );
385  newMask->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() );
386  newMask->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() );
387  newMask->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
388  newMask->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
389  newMask->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
390  newMask->Allocate();
391  newMask->FillBuffer(1);
392  m_MaskImages.push_back(newMask);
393  }
394 
// a null entry in the mask list is likewise replaced by an all-ones mask
395  if (m_MaskImages.at(i)==nullptr)
396  {
397  m_MaskImages.at(i) = ItkUcharImgType::New();
398  m_MaskImages.at(i)->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() );
399  m_MaskImages.at(i)->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() );
400  m_MaskImages.at(i)->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() );
401  m_MaskImages.at(i)->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
402  m_MaskImages.at(i)->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
403  m_MaskImages.at(i)->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() );
404  m_MaskImages.at(i)->Allocate();
405  m_MaskImages.at(i)->FillBuffer(1);
406  }
407  }
408 
409  MITK_INFO << "Resampling fibers and calculating number of samples ...";
410  m_NumberOfSamples = 0;
411  for (unsigned int t=0; t<m_Tractograms.size(); t++)
412  {
413  ItkUcharImgType::Pointer mask = m_MaskImages.at(t);
415  if (t<m_WhiteMatterImages.size() && m_WhiteMatterImages.at(t)!=nullptr)
416  wmmask = m_WhiteMatterImages.at(t);
417  else
418  {
419  MITK_INFO << "No white-matter mask found. Using fiber envelope.";
421  env->SetFiberBundle(m_Tractograms.at(t));
422  env->SetInputImage(mask);
423  env->SetBinaryOutput(true);
424  env->SetUseImageGeometry(true);
425  env->Update();
426  wmmask = env->GetOutput();
427  if (t>=m_WhiteMatterImages.size())
428  m_WhiteMatterImages.push_back(wmmask);
429  else
430  m_WhiteMatterImages.at(t) = wmmask;
431  }
432 
433  // Calculate white-matter samples
// sample distance defaults to half of the smallest voxel spacing
434  if (m_WmSampleDistance<0)
435  {
436  typename InterpolatedRawImageType::Pointer image = m_InterpolatedRawImages.at(t);
437  float minSpacing = 1;
438  if(image->GetSpacing()[0]<image->GetSpacing()[1] && image->GetSpacing()[0]<image->GetSpacing()[2])
439  minSpacing = image->GetSpacing()[0];
440  else if (image->GetSpacing()[1] < image->GetSpacing()[2])
441  minSpacing = image->GetSpacing()[1];
442  else
443  minSpacing = image->GetSpacing()[2];
444  m_WmSampleDistance = minSpacing*0.5;
445  }
446 
447  m_Tractograms.at(t)->ResampleSpline(m_WmSampleDistance);
// the first and last point of each fiber are not used as samples, hence -2 per fiber
448  unsigned int wmSamples = m_Tractograms.at(t)->GetNumberOfPoints()-2*m_Tractograms.at(t)->GetNumFibers();
449  m_NumberOfSamples += wmSamples;
450  MITK_INFO << "Samples inside of WM: " << wmSamples;
451 
452  // calculate gray-matter samples
453  itk::ImageRegionConstIterator<ItkUcharImgType> it(wmmask, wmmask->GetLargestPossibleRegion());
454  int OUTOFWM = 0;
455  while(!it.IsAtEnd())
456  {
457  if (it.Get()==0 && mask->GetPixel(it.GetIndex())>0)
458  OUTOFWM++;
459  ++it;
460  }
461  MITK_INFO << "Non-white matter voxels: " << OUTOFWM;
462 
// either a fixed number of samples per non-WM voxel, or balance against the WM sample count
463  if (m_GmSamplesPerVoxel>0)
464  {
465  m_GmSamples.push_back(m_GmSamplesPerVoxel);
466  m_NumberOfSamples += m_GmSamplesPerVoxel*OUTOFWM;
467  }
468  else if (OUTOFWM>0)
469  {
470  m_GmSamples.push_back(0.5+(double)wmSamples/(double)OUTOFWM);
471  m_NumberOfSamples += m_GmSamples.back()*OUTOFWM;
472  MITK_INFO << "Non-white matter samples per voxel: " << m_GmSamples.back();
473  }
474  else
475  {
476  m_GmSamples.push_back(0);
477  }
478  MITK_INFO << "Samples outside of WM: " << m_GmSamples.back()*OUTOFWM;
479  }
480  MITK_INFO << "Number of samples: " << m_NumberOfSamples;
481  }
482 
483  template< int ShOrder, int NumberOfSignalFeatures >
// Fills m_FeatureData/m_LabelData with training samples: for each non-WM
// voxel, the local signal plus a zero direction and several random directions
// (labeled "out of WM" = directionIndices.size()); for each fiber point, the
// interpolated signal plus the previous fiber direction (labeled with the
// hemisphere direction closest to the actual fiber direction).
// NOTE(review): original lines 484 (signature), 487 (declaration of
// 'directions'), 501 (declaration of 'm_RandGen') and 509 (declaration of
// 'mask') are missing from this extraction.
485  {
486  vnl_vector_fixed<double,3> ref; ref.fill(0); ref[0]=1;
// keep only one representative of each antipodal direction pair (dot with +x > 0)
488  std::vector< int > directionIndices;
489  for (unsigned int f=0; f<2*NumberOfSignalFeatures; f++)
490  {
491  if (dot_product(ref, directions.GetDirection(f))>0)
492  directionIndices.push_back(f);
493  }
494 
495  int numDirectionFeatures = 3;
496 
497  m_FeatureData.reshape( vigra::Shape2(m_NumberOfSamples, NumberOfSignalFeatures+numDirectionFeatures) );
498  m_LabelData.reshape( vigra::Shape2(m_NumberOfSamples,1) );
499  MITK_INFO << "Number of features: " << m_FeatureData.shape(1);
500 
502  m_RandGen->SetSeed();
503  MITK_INFO << "Creating training data ...";
504  int sampleCounter = 0;
505  for (unsigned int t=0; t<m_Tractograms.size(); t++)
506  {
507  typename InterpolatedRawImageType::Pointer image = m_InterpolatedRawImages.at(t);
508  ItkUcharImgType::Pointer wmMask = m_WhiteMatterImages.at(t);
510  if (t<m_MaskImages.size())
511  mask = m_MaskImages.at(t);
512 
513  // non-white matter samples
514  itk::ImageRegionConstIterator<ItkUcharImgType> it(wmMask, wmMask->GetLargestPossibleRegion());
515  while(!it.IsAtEnd())
516  {
517  if (it.Get()==0 && (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0)))
518  {
519  typename InterpolatedRawImageType::PixelType pix = image->GetPixel(it.GetIndex());
520 
521  // null direction
522  for (unsigned int f=0; f<NumberOfSignalFeatures; f++)
523  {
524  m_FeatureData(sampleCounter,f) = pix[directionIndices[f]];
// x!=x is the NaN test; NaN signal values are replaced by 0
525  if(m_FeatureData(sampleCounter,f)!=m_FeatureData(sampleCounter,f))
526  m_FeatureData(sampleCounter,f) = 0;
527  }
528  m_FeatureData(sampleCounter,NumberOfSignalFeatures) = 0;
529  m_FeatureData(sampleCounter,NumberOfSignalFeatures+1) = 0;
530  m_FeatureData(sampleCounter,NumberOfSignalFeatures+2) = 0;
// label directionIndices.size() is the dedicated "out of white matter" class
531  m_LabelData(sampleCounter,0) = directionIndices.size();
532  sampleCounter++;
533 
534  // random directions
// starts at i=1 because the null-direction sample above already counts as one
535  for (int i=1; i<m_GmSamples.at(t); i++)
536  {
537  for (unsigned int f=0; f<NumberOfSignalFeatures; f++)
538  {
539  m_FeatureData(sampleCounter,f) = pix[directionIndices[f]];
540  if(m_FeatureData(sampleCounter,f)!=m_FeatureData(sampleCounter,f))
541  m_FeatureData(sampleCounter,f) = 0;
542  }
543  int c=0;
544  vnl_vector_fixed<double,3> probe;
545  probe[0] = m_RandGen->GetVariate()*2-1;
546  probe[1] = m_RandGen->GetVariate()*2-1;
547  probe[2] = m_RandGen->GetVariate()*2-1;
548  probe.normalize();
// flip onto the reference hemisphere, consistent with the tracking features
549  if (dot_product(ref, probe)<0)
550  probe *= -1;
551  for (unsigned int f=NumberOfSignalFeatures; f<NumberOfSignalFeatures+3; f++)
552  {
553  m_FeatureData(sampleCounter,f) = probe[c];
554  c++;
555  }
556  m_LabelData(sampleCounter,0) = directionIndices.size();
557  sampleCounter++;
558  }
559  }
560  ++it;
561  }
562 
563  // white matter samples
564  mitk::FiberBundle::Pointer fib = m_Tractograms.at(t);
565  vtkSmartPointer< vtkPolyData > polyData = fib->GetFiberPolyData();
566  for (int i=0; i<fib->GetNumFibers(); i++)
567  {
568  vtkCell* cell = polyData->GetCell(i);
569  int numPoints = cell->GetNumberOfPoints();
570  vtkPoints* points = cell->GetPoints();
571 
572  vnl_vector_fixed<double,3> dirOld; dirOld.fill(0.0);
573 
574  for (int j=0; j<numPoints-1; j++)
575  {
576  // calculate direction
577  double* p1 = points->GetPoint(j);
578  itk::Point<float, 3> itkP1;
579  itkP1[0] = p1[0]; itkP1[1] = p1[1]; itkP1[2] = p1[2];
580 
581  vnl_vector_fixed<double,3> dir; dir.fill(0.0);
582 
583  itk::Point<float, 3> itkP2;
584  double* p2 = points->GetPoint(j+1);
585  itkP2[0] = p2[0]; itkP2[1] = p2[1]; itkP2[2] = p2[2];
586  dir[0]=itkP2[0]-itkP1[0];
587  dir[1]=itkP2[1]-itkP1[1];
588  dir[2]=itkP2[2]-itkP1[2];
589 
// skip degenerate segments (duplicate points)
590  if (dir.magnitude()<0.0001)
591  {
592  MITK_INFO << "streamline error!";
593  continue;
594  }
595  dir.normalize();
596  if (dir[0]!=dir[0] || dir[1]!=dir[1] || dir[2]!=dir[2])
597  {
598  MITK_INFO << "ERROR: NaN direction!";
599  continue;
600  }
601 
// first segment only initializes dirOld; no sample is emitted for it
602  if (j==0)
603  {
604  dirOld = dir;
605  continue;
606  }
607 
608  // get voxel values
609  typename InterpolatedRawImageType::PixelType pix = GetImageValues(itkP1, image);
610  for (unsigned int f=0; f<NumberOfSignalFeatures; f++)
611  m_FeatureData(sampleCounter,f) = pix[directionIndices[f]];
612 
613  // direction training features
614  int c = 0;
615  if (dot_product(ref, dirOld)<0)
616  dirOld *= -1;
617 
618  for (unsigned int f=NumberOfSignalFeatures; f<NumberOfSignalFeatures+3; f++)
619  {
620  m_FeatureData(sampleCounter,f) = dirOld[c];
621  c++;
622  }
623 
624  // set label values
// label = index of the hemisphere direction with the largest |dot| to the fiber direction
625  double angle = 0;
626  double m = dir.magnitude();
627  if (m>0.0001)
628  {
629  for (unsigned int f=0; f<NumberOfSignalFeatures; f++)
630  {
631  double a = fabs(dot_product(dir, directions.GetDirection(directionIndices[f])));
632  if (a>angle)
633  {
634  m_LabelData(sampleCounter,0) = f;
635  angle = a;
636  }
637  }
638  }
639 
640  dirOld = dir;
641  sampleCounter++;
642  }
643  }
644  }
645  }
646 
647  template< int ShOrder, int NumberOfSignalFeatures >
649  {
650  MITK_INFO << "Maximum tree depths: " << m_MaxTreeDepth;
651  MITK_INFO << "Sample fraction per tree: " << m_SampleFraction;
652  MITK_INFO << "Number of trees: " << m_NumTrees;
653 
654  std::vector< std::shared_ptr< vigra::RandomForest<int> > > trees;
655  int count = 0;
656 #pragma omp parallel for
657  for (int i = 0; i < m_NumTrees; ++i)
658  {
659  std::shared_ptr< vigra::RandomForest<int> > lrf = std::make_shared< vigra::RandomForest<int> >();
660  lrf->set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal
661  lrf->set_options().sample_with_replacement(true); // if sampled with replacement or not
662  lrf->set_options().samples_per_tree(m_SampleFraction); // Fraction of samples that are used to train a tree
663  lrf->set_options().tree_count(1); // Number of trees that are calculated;
664  lrf->set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node
665  lrf->ext_param_.max_tree_depth = m_MaxTreeDepth;
666 
667  lrf->learn(m_FeatureData, m_LabelData);
668 #pragma omp critical
669  {
670  count++;
671  MITK_INFO << "Tree " << count << " finished training.";
672  trees.push_back(lrf);
673  }
674  }
675 
676  for (int i = 1; i < m_NumTrees; ++i)
677  trees.at(0)->trees_.push_back(trees.at(i)->trees_[0]);
678 
679  m_Forest = trees.at(0);
680  m_Forest->options_.tree_count_ = m_NumTrees;
681  MITK_INFO << "Training finsihed";
682  }
683 
684  template< int ShOrder, int NumberOfSignalFeatures >
686  {
687  MITK_INFO << "Saving forest to " << forestFile;
688  if (IsForestValid())
689  vigra::rf_export_HDF5( *m_Forest, forestFile, "" );
690  else
691  MITK_INFO << "Forest invalid! Could not be saved.";
692  MITK_INFO << "Forest saved successfully.";
693  }
694 
695  template< int ShOrder, int NumberOfSignalFeatures >
// Loads a previously saved random forest from an HDF5 file into m_Forest,
// replacing any existing forest.
// NOTE(review): the signature line (original line 696) is missing from this
// extraction; per the class member list it is void LoadForest(std::string forestFile).
697  {
698  MITK_INFO << "Loading forest from " << forestFile;
699  m_Forest = std::make_shared< vigra::RandomForest<int> >();
700  vigra::rf_import_HDF5( *m_Forest, forestFile);
701  }
702 }
703 
704 #endif
itk::SmartPointer< Self > Pointer
#define MITK_INFO
Definition: mitkLogMacros.h:22
void PreprocessInputDataForTraining()
Generate masks if necessary, resample fibers, spherically interpolate raw DWIs.
DataCollection - Class to facilitate loading/accessing structured data.
void LoadForest(std::string forestFile)
Manages random forests for fiber tractography. The preparation of the features from the input data a...
static ImageType::Pointer GetItkVectorImage(Image *image)
void InputDataValidForTracking()
check if raw data is set and tracking forest is valid
This class takes as input one or more reference image (acquired in the absence of diffusion sensitizi...
static vnl_vector_fixed< double, 3 > GetDirection(int i)
vnl_vector_fixed< double, 3 > Classify(itk::Point< double, 3 > &pos, int &candidates, vnl_vector_fixed< double, 3 > &olddir, double angularThreshold, double &w, ItkUcharImgType::Pointer mask=nullptr)
predicts next progression direction at the given position
#define mitkThrow()
void CalculateFeaturesForTraining()
Calculate GM and WM features using the interpolated raw data, the WM masks and the fibers...
void SaveForest(std::string forestFile)
InterpolatedRawImageType::PixelType GetImageValues(itk::Point< float, 3 > itkP, typename InterpolatedRawImageType::Pointer image)
get trilinearly interpolated raw image values at given world position
void InputDataValidForTraining()
Check if everything is there for training (raw datasets, fiber tracts)
void InitForTracking()
calls InputDataValidForTracking() and creates feature images from the raw input DWI ...
GradientDirectionsContainerType::Pointer GetGradientContainer() const
unsigned short PixelType
FeatureImageType::PixelType GetFeatureValues(itk::Point< float, 3 > itkP)
get trilinearly interpolated feature values at given world position
void TrainForest()
start training process
bool IsForestValid()
true if forest is not null, has more than 0 trees and the correct number of features (NumberOfSignalF...
static itkEventMacro(BoundingShapeInteractionEvent, itk::AnyEvent) class MITKBOUNDINGSHAPE_EXPORT BoundingShapeInteractor Pointer New()
Basic interaction methods for mitk::GeometryData.