Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
mitkPyramidalRegistrationMethodAccessFunctor.txx
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 #include "mitkPyramidalRegistrationMethod.h"
18 #include "mitkPyramidalRegistrationMethodAccessFunctor.h"
19 
20 #include <mitkImageCast.h>
21 
22 #include <itkHistogramMatchingImageFilter.h>
23 #include <itkLinearInterpolateImageFunction.h>
24 #include <itkMultiResolutionImageRegistrationMethod.h>
25 #include <itkRescaleIntensityImageFilter.h>
26 
27 #include "mitkMetricFactory.h"
28 #include "mitkOptimizerFactory.h"
29 #include "mitkRegistrationInterfaceCommand.h"
30 #include "mitkTransformFactory.h"
31 
32 namespace mitk
33 {
34  template <typename TPixel, unsigned int VImageDimension>
35  void PyramidalRegistrationMethodAccessFunctor::AccessItkImage(const itk::Image<TPixel, VImageDimension> *itkImage1,
36  PyramidalRegistrationMethod *method)
37  {
38  // TYPEDEFS
39  typedef itk::Image<TPixel, VImageDimension> ImageType;
40 
41  typedef float InternalPixelType;
42  typedef itk::Image<InternalPixelType, VImageDimension> InternalImageType;
43 
44  typedef typename itk::Transform<double, VImageDimension, VImageDimension> TransformType;
45  typedef mitk::TransformFactory<InternalPixelType, VImageDimension> TransformFactoryType;
46  typedef itk::LinearInterpolateImageFunction<InternalImageType, double> InterpolatorType;
47  typedef mitk::MetricFactory<InternalPixelType, VImageDimension> MetricFactoryType;
48  typedef itk::RecursiveMultiResolutionPyramidImageFilter<InternalImageType, InternalImageType> ImagePyramidType;
49  typedef itk::DiscreteGaussianImageFilter<ImageType, InternalImageType> GaussianFilterType;
50 
51  typedef itk::MultiResolutionImageRegistrationMethod<InternalImageType, InternalImageType> RegistrationType;
52  typedef RegistrationInterfaceCommand<RegistrationType, TPixel> CommandType;
53 
54  typedef itk::CastImageFilter<ImageType, InternalImageType> CastImageFilterType;
55 
56  itk::Array<double> initialParameters;
57  if (method->m_TransformParameters->GetInitialParameters().size())
58  {
59  initialParameters = method->m_TransformParameters->GetInitialParameters();
60  }
61 
62  // LOAD PARAMETERS
63  itk::Array<double> transformValues = method->m_Preset->getTransformValues(method->m_Presets[0]);
64  itk::Array<double> metricValues = method->m_Preset->getMetricValues(method->m_Presets[0]);
65  itk::Array<double> optimizerValues = method->m_Preset->getOptimizerValues(method->m_Presets[0]);
66  method->m_TransformParameters = method->ParseTransformParameters(transformValues);
67  method->m_MetricParameters = method->ParseMetricParameters(metricValues);
68  method->m_OptimizerParameters = method->ParseOptimizerParameters(optimizerValues);
69 
70  // The fixed and the moving image
71  typename InternalImageType::Pointer fixedImage = InternalImageType::New();
72  typename InternalImageType::Pointer movingImage = InternalImageType::New();
73 
74  mitk::CastToItkImage(method->m_ReferenceImage, fixedImage);
75 
76  // Blur the moving image
77  if (method->m_BlurMovingImage)
78  {
79  typename GaussianFilterType::Pointer gaussianFilter = GaussianFilterType::New();
80  gaussianFilter->SetInput(itkImage1);
81  gaussianFilter->SetVariance(6.0);
82  gaussianFilter->SetMaximumError(0.1);
83  // gaussianFilter->SetMaximumKernelWidth ( 3 );
84  gaussianFilter->Update();
85  movingImage = gaussianFilter->GetOutput();
86  }
87  else
88  {
89  typename CastImageFilterType::Pointer castImageFilter = CastImageFilterType::New();
90  castImageFilter->SetInput(itkImage1);
91  castImageFilter->Update();
92  movingImage = castImageFilter->GetOutput();
93  }
94 
95  if (method->m_MatchHistograms)
96  {
97  typedef itk::RescaleIntensityImageFilter<InternalImageType, InternalImageType> FilterType;
98  typedef itk::HistogramMatchingImageFilter<InternalImageType, InternalImageType> HEFilterType;
99 
100  typename FilterType::Pointer inputRescaleFilter = FilterType::New();
101  typename FilterType::Pointer referenceRescaleFilter = FilterType::New();
102 
103  referenceRescaleFilter->SetInput(fixedImage);
104  inputRescaleFilter->SetInput(movingImage);
105 
106  const float desiredMinimum = 0.0;
107  const float desiredMaximum = 255.0;
108 
109  referenceRescaleFilter->SetOutputMinimum(desiredMinimum);
110  referenceRescaleFilter->SetOutputMaximum(desiredMaximum);
111  referenceRescaleFilter->UpdateLargestPossibleRegion();
112  inputRescaleFilter->SetOutputMinimum(desiredMinimum);
113  inputRescaleFilter->SetOutputMaximum(desiredMaximum);
114  inputRescaleFilter->UpdateLargestPossibleRegion();
115 
116  // Histogram match the images
117  typename HEFilterType::Pointer intensityEqualizeFilter = HEFilterType::New();
118 
119  intensityEqualizeFilter->SetReferenceImage(inputRescaleFilter->GetOutput());
120  intensityEqualizeFilter->SetInput(referenceRescaleFilter->GetOutput());
121  intensityEqualizeFilter->SetNumberOfHistogramLevels(64);
122  intensityEqualizeFilter->SetNumberOfMatchPoints(12);
123  intensityEqualizeFilter->ThresholdAtMeanIntensityOn();
124  intensityEqualizeFilter->Update();
125 
126  // fixedImage = referenceRescaleFilter->GetOutput();
127  // movingImage = IntensityEqualizeFilter->GetOutput();
128 
129  fixedImage = intensityEqualizeFilter->GetOutput();
130  movingImage = inputRescaleFilter->GetOutput();
131  }
132 
133  typename TransformFactoryType::Pointer transFac = TransformFactoryType::New();
134  transFac->SetTransformParameters(method->m_TransformParameters);
135  transFac->SetFixedImage(fixedImage);
136  transFac->SetMovingImage(movingImage);
137  typename TransformType::Pointer transform = transFac->GetTransform();
138 
139  typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
140  typename MetricFactoryType::Pointer metFac = MetricFactoryType::New();
141  metFac->SetMetricParameters(method->m_MetricParameters);
142 
143  typename OptimizerFactory::Pointer optFac = OptimizerFactory::New();
144  optFac->SetOptimizerParameters(method->m_OptimizerParameters);
145  optFac->SetNumberOfTransformParameters(transform->GetNumberOfParameters());
146  typename PyramidalRegistrationMethod::OptimizerType::Pointer optimizer = optFac->GetOptimizer();
147 
148  // optimizer scales
149  if (method->m_TransformParameters->GetUseOptimizerScales())
150  {
151  itk::Array<double> optimizerScales = method->m_TransformParameters->GetScales();
152  typename PyramidalRegistrationMethod::OptimizerType::ScalesType scales(transform->GetNumberOfParameters());
153  for (unsigned int i = 0; i < scales.Size(); i++)
154  {
155  scales[i] = optimizerScales[i];
156  }
157  optimizer->SetScales(scales);
158  }
159 
160  typename ImagePyramidType::Pointer fixedImagePyramid = ImagePyramidType::New();
161  typename ImagePyramidType::Pointer movingImagePyramid = ImagePyramidType::New();
162 
163  if (method->m_FixedSchedule.size() > 0 && method->m_MovingSchedule.size() > 0)
164  {
165  fixedImagePyramid->SetSchedule(method->m_FixedSchedule);
166  movingImagePyramid->SetSchedule(method->m_MovingSchedule);
167  // Otherwise just use the default schedule
168  }
169 
170  typename RegistrationType::Pointer registration = RegistrationType::New();
171  registration->SetOptimizer(optimizer);
172  registration->SetTransform(transform);
173  registration->SetInterpolator(interpolator);
174  registration->SetMetric(metFac->GetMetric());
175  registration->SetFixedImagePyramid(fixedImagePyramid);
176  registration->SetMovingImagePyramid(movingImagePyramid);
177  registration->SetFixedImage(fixedImage);
178  registration->SetMovingImage(movingImage);
179  registration->SetFixedImageRegion(fixedImage->GetBufferedRegion());
180 
181  if (transFac->GetTransformParameters()->GetInitialParameters().size())
182  {
183  registration->SetInitialTransformParameters(transFac->GetTransformParameters()->GetInitialParameters());
184  }
185  else
186  {
187  itk::Array<double> zeroInitial;
188  zeroInitial.set_size(transform->GetNumberOfParameters());
189  zeroInitial.fill(0.0);
190  zeroInitial[0] = 1.0;
191  zeroInitial[4] = 1.0;
192  zeroInitial[8] = 1.0;
193  registration->SetInitialTransformParameters(zeroInitial);
194  }
195 
196  if (method->m_UseMask)
197  {
198  itk::ImageMaskSpatialObject<VImageDimension> *mask =
199  dynamic_cast<itk::ImageMaskSpatialObject<VImageDimension> *>(method->m_BrainMask.GetPointer());
200  registration->GetMetric()->SetFixedImageMask(mask);
201  }
202 
203  // registering command observer with the optimizer
204  if (method->m_Observer.IsNotNull())
205  {
206  method->m_Observer->AddStepsToDo(20);
207  optimizer->AddObserver(itk::AnyEvent(), method->m_Observer);
208  registration->AddObserver(itk::AnyEvent(), method->m_Observer);
209  transform->AddObserver(itk::AnyEvent(), method->m_Observer);
210  }
211 
212  typename CommandType::Pointer command = CommandType::New();
213  command->observer = method->m_Observer;
214  command->m_Presets = method->m_Presets;
215  command->m_UseMask = method->m_UseMask;
216  command->m_BrainMask = method->m_BrainMask;
217 
218  registration->AddObserver(itk::IterationEvent(), command);
219  registration->SetSchedules(method->m_FixedSchedule, method->m_MovingSchedule);
220 
221  // Start the registration process
222  try
223  {
224  registration->Update();
225  }
226  catch (itk::ExceptionObject &err)
227  {
228  std::cout << "ExceptionObject caught !" << std::endl;
229  std::cout << err << std::endl;
230  }
231  if (method->m_Observer.IsNotNull())
232  {
233  optimizer->RemoveAllObservers();
234  registration->RemoveAllObservers();
235  transform->RemoveAllObservers();
236  method->m_Observer->SetRemainingProgress(15);
237  }
238  if (method->m_Observer.IsNotNull())
239  {
240  method->m_Observer->SetRemainingProgress(5);
241  }
242  }
243 
244 } // end namespace