1 /*===================================================================
3 The Medical Imaging Interaction Toolkit (MITK)
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
13 See LICENSE.txt or http://www.mitk.org for details.
15 ===================================================================*/
17 #ifndef __itkNonLocalMeansDenoisingFilter_txx
18 #define __itkNonLocalMeansDenoisingFilter_txx
24 #define _USE_MATH_DEFINES
27 #include "itkImageRegionIterator.h"
28 #include "itkNeighborhoodIterator.h"
29 #include <itkImageRegionIteratorWithIndex.h>
35 template< class TPixelType >
36 NonLocalMeansDenoisingFilter< TPixelType >
37 ::NonLocalMeansDenoisingFilter()  // Constructor: sets default filter parameters. NOTE(review): embedded numbering skips 38 and 42-44, so e.g. any initialisers for m_SearchRadius / m_Variance (both read later) are not visible here.
39 m_ComparisonRadius(1),  // default half-width of the patch compared around each voxel
40 m_UseJointInformation(false),  // default: filter each vector component separately
41 m_UseRicianAdaption(false),  // default: average raw intensities (no squaring, no 2*variance subtraction)
45 this->SetNumberOfRequiredInputs( 1 );  // one input image (input 0) is required
48 template< class TPixelType >
50 NonLocalMeansDenoisingFilter< TPixelType >
51 ::BeforeThreadedGenerateData()  // Logs the parameters, prepares m_Mask and limits the output requested region to the mask's bounding box. NOTE(review): the listing elides lines here (e.g. the return type on 49 and the if/else structure around 62-72 and 101-110).
54 MITK_INFO << "SearchRadius: " << m_SearchRadius;  // log the configuration used for this run
55 MITK_INFO << "ComparisonRadius: " << m_ComparisonRadius;
56 MITK_INFO << "Noisevariance: " << m_Variance;
57 MITK_INFO << "Use Rician Adaption: " << std::boolalpha << m_UseRicianAdaption;
58 MITK_INFO << "Use Joint Information: " << std::boolalpha << m_UseJointInformation;
61 typename InputImageType::Pointer inputImagePointer = static_cast< InputImageType * >( this->ProcessObject::GetInput(0) );
64 // If no mask was set, generate a mask covering the complete image (the guarding condition is on an elided line)
66 m_Mask = MaskImageType::New();
67 m_Mask->SetRegions(inputImagePointer->GetLargestPossibleRegion());  // same extent as the input image
69 m_Mask->FillBuffer(1);  // every voxel is foreground. NOTE(review): FillBuffer needs an allocated buffer; an Allocate() call, if present, is on elided line 68 - confirm.
73 // Calculation of the smallest region enclosing the masked voxels (bounding box of the mask)
75 typename OutputImageType::Pointer outputImage =
76 static_cast< OutputImageType * >(this->ProcessObject::GetOutput(0));
77 ImageRegionIterator< OutputImageType > oit(outputImage, inputImagePointer->GetLargestPossibleRegion());  // walks the whole output extent
79 ImageRegionIterator< MaskImageType > mit(m_Mask, m_Mask->GetLargestPossibleRegion());  // walks the whole mask
81 typename MaskImageType::IndexType minIndex;
82 typename MaskImageType::IndexType maxIndex;  // NOTE(review): the running min/max below only work if these start at extreme sentinel values; their initialisation is on elided lines 83-84 - confirm.
85 typename OutputImageType::PixelType outpix;
86 outpix.SetSize(inputImagePointer->GetVectorLength());  // one element per input vector component
88 while (!mit.IsAtEnd())  // scan the mask voxel by voxel (loop-body lines 89-92 and 101-110 are elided)
93 // calculation of the start & end index of the smallest masked region
94 minIndex[0] = minIndex[0] < mit.GetIndex()[0] ? minIndex[0] : mit.GetIndex()[0];  // running per-axis minimum
95 minIndex[1] = minIndex[1] < mit.GetIndex()[1] ? minIndex[1] : mit.GetIndex()[1];
96 minIndex[2] = minIndex[2] < mit.GetIndex()[2] ? minIndex[2] : mit.GetIndex()[2];  // only axes 0-2 are tracked, so a 3-D image is assumed
98 maxIndex[0] = maxIndex[0] > mit.GetIndex()[0] ? maxIndex[0] : mit.GetIndex()[0];  // running per-axis maximum
99 maxIndex[1] = maxIndex[1] > mit.GetIndex()[1] ? maxIndex[1] : mit.GetIndex()[1];
100 maxIndex[2] = maxIndex[2] > mit.GetIndex()[2] ? maxIndex[2] : mit.GetIndex()[2];
111 // calculation of the masked region
112 typename OutputImageType::SizeType size;
113 size[0] = maxIndex[0] - minIndex[0] + 1;  // +1 because both bounds are inclusive
114 size[1] = maxIndex[1] - minIndex[1] + 1;
115 size[2] = maxIndex[2] - minIndex[2] + 1;
117 typename OutputImageType::RegionType region (minIndex, size);
119 outputImage->SetRequestedRegion(region);  // only the mask's bounding box is requested from the output
122 m_CurrentVoxelCount = 0;  // reset the counter that ThreadedGenerateData increments
125 template< class TPixelType >
127 NonLocalMeansDenoisingFilter< TPixelType >
128 ::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, ThreadIdType )  // Non-local-means pass over outputRegionForThread: each masked voxel becomes a weighted average of the voxels in its search window, weighted by patch similarity. NOTE(review): many lines of this body are elided in the listing (braces, sumk/size/w/summw/Z declarations, iterator increments).
132 // initialize iterators
133 typename OutputImageType::Pointer outputImage =
134 static_cast< OutputImageType * >(this->ProcessObject::GetOutput(0));
136 ImageRegionIterator< OutputImageType > oit(outputImage, outputRegionForThread);
139 ImageRegionIterator< MaskImageType > mit(m_Mask, outputRegionForThread);  // mask is read over the same region as the output
144 typedef ImageRegionIteratorWithIndex <InputImageType> InputIteratorType;
145 typename InputImageType::Pointer inputImagePointer = NULL;
146 inputImagePointer = static_cast< InputImageType * >( this->ProcessObject::GetInput(0) );
148 InputIteratorType git(inputImagePointer, outputRegionForThread );  // with-index iterator: the current voxel index drives the window loops below
151 // iterate over this thread's region (git, oit and mit all cover outputRegionForThread)
152 while( !git.IsAtEnd() )
154 typename OutputImageType::PixelType outpix;
155 outpix.SetSize (inputImagePointer->GetVectorLength());
157 if (mit.Get() != 0 && !this->GetAbortGenerateData())  // only masked voxels are denoised, and no new work starts once an abort is requested
159 if(!m_UseJointInformation)  // per-component path: each vector component is filtered with its own weights
161 for (int i = 0; i < (int)inputImagePointer->GetVectorLength(); ++i)
166 std::vector<double> wj;  // weight of each accepted neighbour j
167 std::vector<double> p;  // value contributed by neighbour j (squared when m_UseRicianAdaption)
168 typename InputIteratorType::IndexType index;
169 index = git.GetIndex();
171 for (int x = index.GetElement(0) - m_SearchRadius; x <= index.GetElement(0) + m_SearchRadius; ++x)  // search window: all voxels within m_SearchRadius along x, y and z
173 for (int y = index.GetElement(1) - m_SearchRadius; y <= index.GetElement(1) + m_SearchRadius; ++y)
175 for (int z = index.GetElement(2) - m_SearchRadius; z <= index.GetElement(2) + m_SearchRadius; ++z)
177 typename InputIteratorType::IndexType indexV;
178 indexV.SetElement(0, x);
179 indexV.SetElement(1, y);
180 indexV.SetElement(2, z);
181 if (inputImagePointer->GetLargestPossibleRegion().IsInside(indexV))  // skip candidates outside the image
183 TPixelType pixelJ = inputImagePointer->GetPixel(indexV)[i];
186 for (int xi = index.GetElement(0) - m_ComparisonRadius, xj = x - m_ComparisonRadius; xi <= index.GetElement(0) + m_ComparisonRadius; ++xi, ++xj)  // walk the patch around index (i-side) and around the candidate (j-side) in lockstep
188 for (int yi = index.GetElement(1) - m_ComparisonRadius, yj = y - m_ComparisonRadius; yi <= index.GetElement(1) + m_ComparisonRadius; ++yi, ++yj)
190 for (int zi = index.GetElement(2) - m_ComparisonRadius, zj = z - m_ComparisonRadius; zi <= index.GetElement(2) + m_ComparisonRadius; ++zi, ++zj)
192 typename InputIteratorType::IndexType indexI, indexJ;
193 indexI.SetElement(0, xi);
194 indexI.SetElement(1, yi);
195 indexI.SetElement(2, zi);
196 indexJ.SetElement(0, xj);
197 indexJ.SetElement(1, yj);
198 indexJ.SetElement(2, zj);
201 // Compare neighborhoods ni & nj
202 if (inputImagePointer->GetLargestPossibleRegion().IsInside(indexI) && inputImagePointer->GetLargestPossibleRegion().IsInside(indexJ))  // only voxel pairs with both ends inside the image contribute
204 int diff = inputImagePointer->GetPixel(indexI)[i] - inputImagePointer->GetPixel(indexJ)[i];  // NOTE(review): int truncates fractional differences if TPixelType is floating point, and diff*diff below overflows int for |diff| > 46340 - consider double.
205 sumk += (double)(diff*diff);  // accumulate the squared patch distance (sumk is declared on an elided line)
211 // weight all neighborhoods
212 w = std::exp( - sumk / size / m_Variance);  // weight = exp(-sumk / size / m_Variance). NOTE(review): size and w are set on elided lines - presumably size is the number of compared voxel pairs, making this a mean squared distance; confirm.
214 if (m_UseRicianAdaption)
216 p.push_back((double)(pixelJ*pixelJ));  // Rician adaption averages squared intensities. NOTE(review): the product is formed before the double cast; for unsigned short values near 65535 the promoted int product overflows - square after converting to double.
220 p.push_back((double)(pixelJ));
227 for (unsigned int n = 0; n < wj.size(); ++n)
229 sumj += (wj[n]/summw) * p[n];  // normalised weighted average. NOTE(review): summw (and the wj.push_back) are on elided lines - presumably summw is the sum of wj; confirm.
231 if (m_UseRicianAdaption)
233 sumj -=2 * m_Variance;  // subtract 2 * m_Variance from the averaged squared value before the square root below
242 if (m_UseRicianAdaption)
244 outval = std::floor(std::sqrt(sumj) + 0.5);  // back to intensity scale, rounded half-up. NOTE(review): sqrt of a negative sumj yields NaN; any clamp to 0 would be on elided lines 234-241 - confirm.
248 outval = std::floor(sumj + 0.5);  // round half-up to the nearest integer
250 outpix.SetElement(i, outval);
256 // Joint path: same procedure, but patch distances and averages use the whole vector at once
259 itk::VariableLengthVector<double> sumj;
260 sumj.SetSize(inputImagePointer->GetVectorLength());  // NOTE(review): SetSize does not zero the elements; a Fill(0), if present, is on elided lines 261-262 - confirm.
263 std::vector<double> wj;  // weight of each accepted neighbour j
264 std::vector<itk::VariableLengthVector <double> > p;  // vector contributed by neighbour j (component-wise squared when m_UseRicianAdaption)
265 typename InputIteratorType::IndexType index;
266 index = git.GetIndex();
268 for (int x = index.GetElement(0) - m_SearchRadius; x <= index.GetElement(0) + m_SearchRadius; ++x)  // search window, as in the per-component path
270 for (int y = index.GetElement(1) - m_SearchRadius; y <= index.GetElement(1) + m_SearchRadius; ++y)
272 for (int z = index.GetElement(2) - m_SearchRadius; z <= index.GetElement(2) + m_SearchRadius; ++z)
274 typename InputIteratorType::IndexType indexV;
275 indexV.SetElement(0, x);
276 indexV.SetElement(1, y);
277 indexV.SetElement(2, z);
278 if (inputImagePointer->GetLargestPossibleRegion().IsInside(indexV))  // skip candidates outside the image
280 typename InputImageType::PixelType pixelJ = inputImagePointer->GetPixel(indexV);  // whole vector of the candidate voxel
283 for (int xi = index.GetElement(0) - m_ComparisonRadius, xj = x - m_ComparisonRadius; xi <= index.GetElement(0) + m_ComparisonRadius; ++xi, ++xj)  // patch walk in lockstep, as above
285 for (int yi = index.GetElement(1) - m_ComparisonRadius, yj = y - m_ComparisonRadius; yi <= index.GetElement(1) + m_ComparisonRadius; ++yi, ++yj)
287 for (int zi = index.GetElement(2) - m_ComparisonRadius, zj = z - m_ComparisonRadius; zi <= index.GetElement(2) + m_ComparisonRadius; ++zi, ++zj)
289 typename InputIteratorType::IndexType indexI, indexJ;
290 indexI.SetElement(0, xi);
291 indexI.SetElement(1, yi);
292 indexI.SetElement(2, zi);
293 indexJ.SetElement(0, xj);
294 indexJ.SetElement(1, yj);
295 indexJ.SetElement(2, zj);
296 // Compare neighborhoods ni & nj
297 if (inputImagePointer->GetLargestPossibleRegion().IsInside(indexI) && inputImagePointer->GetLargestPossibleRegion().IsInside(indexJ))  // only voxel pairs with both ends inside the image contribute
299 typename InputImageType::PixelType diff = inputImagePointer->GetPixel(indexI) - inputImagePointer->GetPixel(indexJ);  // NOTE(review): the subtraction is done in the input pixel's component type, so unsigned components may wrap around - confirm TPixelType is signed or convert to double first.
300 sumk += (double)(diff.GetSquaredNorm());  // squared Euclidean distance over all components
306 // weight all neighborhoods
307 size *= inputImagePointer->GetVectorLength() + 1;  // NOTE(review): normalises by (vector length + 1) per pair; the +1 is not explained by visible code, and if size is not re-initialised per candidate (elided lines) this multiplication compounds - confirm.
308 w = std::exp( - (sumk / size) / m_Variance);  // weight = exp(-(sumk / size) / m_Variance)
310 if (m_UseRicianAdaption)
312 itk::VariableLengthVector <double> m;
313 m.SetSize(inputImagePointer->GetVectorLength());
314 for (unsigned int i = 0; i < inputImagePointer->GetVectorLength(); ++i)
316 double sp = (double)(pixelJ.GetElement(i) * pixelJ.GetElement(i));  // squared component. NOTE(review): the product is formed before the double cast and can overflow int for large unsigned short values.
331 for (unsigned int n = 0; n < wj.size(); ++n)
333 sumj = sumj + ((wj[n]/Z) * p[n]);  // normalised weighted vector average; Z is computed on elided lines (presumably the sum of wj - confirm)
335 if (m_UseRicianAdaption)
337 sumj = sumj - (2 * m_Variance);  // subtract 2 * m_Variance from every component
340 for (unsigned int i = 0; i < inputImagePointer->GetVectorLength(); ++i)
342 double a = sumj.GetElement(i);
348 if (m_UseRicianAdaption)
350 outval = std::floor(std::sqrt(a) + 0.5);  // NOTE(review): NaN if a < 0 unless elided lines 343-347 clamp it - confirm.
354 outval = std::floor(a + 0.5);  // round half-up to the nearest integer
356 outpix.SetElement(i, outval);
367 ++m_CurrentVoxelCount;  // NOTE(review): shared member written from every thread; unless it is atomic (its declaration is not visible) this is a data race - confirm.
372 MITK_INFO << "One Thread finished calculation";
375 template< class TPixelType >
376 void NonLocalMeansDenoisingFilter< TPixelType >::SetInputImage(const InputImageType* image)  // Sets the image to denoise as input 0.
378 this->SetNthInput(0, const_cast< InputImageType* >(image));  // const_cast presumably because SetNthInput takes a non-const pointer
381 template< class TPixelType >
382 void NonLocalMeansDenoisingFilter< TPixelType >::SetInputMask(MaskImageType* mask)  // Sets the optional mask (ThreadedGenerateData skips voxels where the mask is 0). NOTE(review): the body (embedded lines 383-387) is not visible in this listing.
388 #endif // __itkNonLocalMeansDenoisingFilter_txx