Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
mitkLevenbergMarquardtModelFitFunctor.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
14 
17 #include <chrono>
18 #include <mitkExceptionMacro.h>
19 
// Default constructor. Initializes the fit settings to their defaults:
//   m_Epsilon = 1e-5, m_GradientTolerance = 1e-3, m_ValueTolerance = 1e-5,
//   m_Iterations = 1000, m_DerivativeStepLength = 1e-5,
//   m_ActivateFailureThreshold = true.
21 LevenbergMarquardtModelFitFunctor(): m_Epsilon(1e-5), m_GradientTolerance(1e-3),
22  m_ValueTolerance(1e-5), m_Iterations(1000), m_DerivativeStepLength(1e-5),
23  m_ActivateFailureThreshold(true)
24 {};
25 
// Empty definition body.
// NOTE(review): the declaration line belonging to this body is not part of
// this listing - presumably the destructor; confirm against the header.
28 {};
29 
// Returns a name list that holds the single entry "sum_diff^2".
// NOTE(review): the signature line is missing from this listing; presumably
// this names the criterion value computed by GetCriteria() - confirm.
33 {
34  ParameterNamesType names;
35  names.push_back("sum_diff^2");
36  return names;
37 };
38 
// Evaluates the fit criterion for a given parameter set.
// The metric is configured with `model` and `sample`; its value at
// `parameters` is stored as entry 0 of the returned array.
// NOTE(review): the lines that create `metric` and `result` are missing from
// this listing, so their concrete types cannot be stated here.
41 GetCriteria(const ModelBase* model, const ParametersType& parameters,
42  const SignalType& sample) const
43 {
46  metric->SetModel(model);
47  metric->SetSample(sample);
48 
50  result[0] = metric->GetValue(parameters);
51 
52  return result;
53 };
54 
// Builds the cost function used to fit `model` to the signal `value`.
// The metric receives the model, the sample and m_DerivativeStepLength.
// If a constraint checker is set, the metric is wrapped by a decorator and the
// decorator is returned instead of the plain metric.
// NOTE(review): the first signature line and the lines that create `metric`
// and `decorator` are missing from this listing.
56  const SignalType& value, const ModelBase* model) const
57 {
60  metric->SetModel(model);
61  metric->SetSample(value);
62  metric->SetDerivativeStepLength(m_DerivativeStepLength);
63 
64  mitk::MVModelFitCostFunction::Pointer result = metric.GetPointer();
65 
66  if (m_ConstraintChecker.IsNotNull())
67  {
// The decorator wraps the metric and uses the checker's failed-constraint
// value as its failure threshold.
70  decorator->SetConstraintChecker(m_ConstraintChecker);
71  decorator->SetWrappedCostFunction(metric);
72  decorator->SetFailureThreshold(m_ConstraintChecker->GetFailedConstraintValue());
73 
// The decorator gets the same model and sample as the wrapped metric.
74  decorator->SetModel(model);
75  decorator->SetSample(value);
76  decorator->SetActivateFailureThreshold(m_ActivateFailureThreshold);
77  result = decorator;
78  }
79 
80  return result;
81 };
82 
// Returns the names of the debug parameters: "optimization_time",
// "nr_of_iterations" and "stop_condition". When a constraint checker is set,
// the three constraint_* names are appended as well.
// NOTE(review): the signature line is missing from this listing.
85 {
86  ParameterNamesType result;
87  result.push_back("optimization_time");
88  result.push_back("nr_of_iterations");
89  result.push_back("stop_condition");
// Constraint related entries are only listed when a constraint checker is set.
90  if (m_ConstraintChecker.IsNotNull())
91  {
92  result.push_back("constraint_penalty_ratio");
93  result.push_back("constraint_failure_ratio");
94  result.push_back("constraint_last_failed_parameter");
95  }
96  return result;
97 };
98 
101 DoModelFit(const SignalType& value, const ModelBase* model,
102  const ModelBase::ParametersType& initialParameters,
103  DebugParameterMapType& debugParameters) const
104 {
105  std::chrono::time_point<std::chrono::system_clock> startTime;
106  startTime = std::chrono::system_clock::now();
107  ::itk::LevenbergMarquardtOptimizer::ParametersType internalInitParam = initialParameters;
108  ::itk::LevenbergMarquardtOptimizer::ScalesType scales = m_Scales;
109 
110  if (initialParameters.GetNumberOfElements() != model->GetNumberOfParameters())
111  {
112  MITK_DEBUG <<
113  "Size of initial parameters of fit functor optimizer do not match number of model parameters. Renitialize parameters with 0.0.";
114  internalInitParam.SetSize(model->GetNumberOfParameters());
115  internalInitParam.Fill(0.0);
116  }
117 
118  if (m_Scales.GetNumberOfElements() != model->GetNumberOfParameters())
119  {
120  MITK_DEBUG <<
121  "Size of scales of fit functor optimizer do not match number of model parameters. Reinitialize scales with 1.0.";
122  scales.SetSize(model->GetNumberOfParameters());
123  scales.Fill(1.0);
124  }
125 
126  mitk::MVModelFitCostFunction::Pointer metric = this->GenerateCostFunction(value, model);
127 
128  ::itk::LevenbergMarquardtOptimizer::Pointer optimizer = ::itk::LevenbergMarquardtOptimizer::New();
129 
130  optimizer->SetCostFunction(metric);
131  optimizer->SetEpsilonFunction(m_Epsilon);
132  optimizer->SetGradientTolerance(m_GradientTolerance);
133  optimizer->SetNumberOfIterations(m_Iterations);
134  optimizer->SetScales(scales);
135  optimizer->SetInitialPosition(internalInitParam);
136 
137  optimizer->StartOptimization();
138 
139  itk::Optimizer::ParametersType position = optimizer->GetCurrentPosition();
140 
141  std::chrono::time_point<std::chrono::system_clock> stopTime;
142  stopTime = std::chrono::system_clock::now();
143  debugParameters.clear();
144  if (this->GetDebugParameterMaps())
145  {
146  const auto timeDiff = std::chrono::duration_cast<std::chrono::milliseconds>(stopTime - startTime).count();
147  debugParameters.insert(std::make_pair("optimization_time", timeDiff));
148 
149  ParameterImagePixelType value = optimizer->GetOptimizer()->get_num_iterations();
150  debugParameters.insert(std::make_pair("nr_of_iterations", value));
151  value = optimizer->GetOptimizer()->get_failure_code();
152  debugParameters.insert(std::make_pair("stop_condition", value));
153 
154 
155  const ::mitk::MVConstrainedCostFunctionDecorator* decorator = dynamic_cast<const ::mitk::MVConstrainedCostFunctionDecorator*>(metric.GetPointer());
156  if (decorator)
157  {
158  value = decorator->GetPenaltyRatio();
159  debugParameters.insert(std::make_pair("constraint_penalty_ratio", value));
160  value = decorator->GetFailureRatio();
161  debugParameters.insert(std::make_pair("constraint_failure_ratio", value));
162  value = decorator->GetFailedParameter();
163  debugParameters.insert(std::make_pair("constraint_last_failed_parameter", value));
164  }
165  else
166  {
167  if (m_ConstraintChecker.IsNotNull())
168  {
169  mitkThrow() << "Fit functor has invalid state/wrong implementation. Constraint checker is set, but used metric seems to be no MVContstrainedCostFunctionDecorator.";
170  }
171  }
172  }
173 
174  return position;
175 };
virtual MVModelFitCostFunction::Pointer GenerateCostFunction(const SignalType &value, const ModelBase *model) const
Base class for (dynamic) models. A model can be used to calculate its signal given the discrete time ...
Definition: mitkModelBase.h:47
ModelTraitsInterface::ParametersType ParametersType
Definition: mitkModelBase.h:59
std::map< std::string, ParameterImagePixelType > DebugParameterMapType
#define MITK_DEBUG
Definition: mitkLogMacros.h:22
ModelBase::ParameterNamesType ParameterNamesType
virtual bool GetDebugParameterMaps() const
virtual ParametersSizeType GetNumberOfParameters() const =0
#define mitkThrow()
This class is used to add constraints to any multi valued model fit cost function.
ParametersType DoModelFit(const SignalType &value, const ModelBase *model, const ModelBase::ParametersType &initialParameters, DebugParameterMapType &debugParameters) const override
OutputPixelArrayType GetCriteria(const ModelBase *model, const ParametersType &parameters, const SignalType &sample) const override