Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
mitkMetropolisHastingsSampler.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
18 
19 using namespace mitk;
20 
// Constructor initializer list and body. Sets the default proposal
// probabilities (birth 0.25, death 0.05, shift 0.15, optimal shift 0.1,
// connection 0.45), stores the particle grid, energy computer and random
// generator, and derives the shift-proposal parameters from the particle length.
// NOTE(review): the signature (source line 21) is missing from this listing.
22  : m_ExTemp(0.01)
23  , m_BirthProb(0.25)
24  , m_DeathProb(0.05)
25  , m_ShiftProb(0.15)
26  , m_OptShiftProb(0.1)
27  , m_ConnectionProb(0.45)
28  , m_TractProb(0.5)
29  , m_DelProb(0.1)
30  , m_ChempotParticle(0.0)
31  , m_AcceptedProposals(0)
32 {
33  m_RandGen = randGen;
34  m_ParticleGrid = grid;
35  m_EnergyComputer = enComp;
36 
// NOTE(review): source lines 37-38 are missing from this listing; m_ParticleLength
// (read below) and presumably m_DistanceThreshold are assigned there — confirm.
// Spread parameter of the random shift proposal (passed to DistortVector in MakeProposal).
39  m_Sigma = m_ParticleLength/8.0;
// Coefficient 1/(2*sigma^2) of the Gaussian exponent used by the optimal-shift
// p_rev term in MakeProposal.
40  m_Gamma = 1/(m_Sigma*m_Sigma*2);
// Normalizer that divides p_rev in MakeProposal.
// NOTE(review): for exp(-|d|^2/(2*sigma^2)) the 3-D normalization would be
// (2*pi*sigma^2)^(3/2); this uses (2*pi*sigma)^(3/2) — confirm intent.
41  m_Z = pow(2*M_PI*m_Sigma,3.0/2.0)*(M_PI*m_Sigma/m_ParticleLength);
42 
// Candidate connections need a direction dot product above this value.
43  m_CurvatureThreshold = curvThres;
// Weight of the "stop tracking" entry EndPoint(nullptr,0) in the endpoint
// proposal distribution used for local tracking.
44  m_StopProb = exp(-1/m_TractProb);
45 }
46 
// Sets the relative probabilities of the five proposal types (birth, death,
// shift, optimal shift, connection) used by MakeProposal. If `sum` (see the
// note below) differs from 1 and exceeds mitk::eps, each probability is divided
// by it. The resulting values are printed to std::cout.
47 void MetropolisHastingsSampler::SetProbabilities(float birth, float death, float shift, float optShift, float connect)
48 {
49  m_BirthProb = birth;
50  m_DeathProb = death;
51  m_ShiftProb = shift;
52  m_OptShiftProb = optShift;
53  m_ConnectionProb = connect;
// NOTE(review): source line 54, which declares `sum`, is missing from this
// listing; it is presumably the total of the five probabilities above — confirm.
55  if (sum!=1 && sum>mitk::eps)
56  {
57  m_BirthProb /= sum;
58  m_DeathProb /= sum;
59  m_ShiftProb /= sum;
60  m_OptShiftProb /= sum;
61  m_ConnectionProb /= sum;
62  }
63  std::cout << "Update proposal probabilities" << std::endl;
64  std::cout << "Birth: " << m_BirthProb << std::endl;
65  std::cout << "Death: " << m_DeathProb << std::endl;
66  std::cout << "Shift: " << m_ShiftProb << std::endl;
67  std::cout << "Optimal shift: " << m_OptShiftProb << std::endl;
68  std::cout << "Connection: " << m_ConnectionProb << std::endl;
69 }
70 
71 // print proposal times
73 {
74  double sum = m_BirthTime.GetTotal()+m_DeathTime.GetTotal()+m_ShiftTime.GetTotal()+m_OptShiftTime.GetTotal()+m_ConnectionTime.GetTotal();
75  std::cout << "Proposal time probes (toal%/mean)" << std::endl;
76  std::cout << "Birth: " << 100*m_BirthTime.GetTotal()/sum << "/" << m_BirthTime.GetMean()*1000 << std::endl;
77  std::cout << "Death: " << 100*m_DeathTime.GetTotal()/sum << "/" << m_DeathTime.GetMean()*1000 << std::endl;
78  std::cout << "Shift: " << 100*m_ShiftTime.GetTotal()/sum << "/" << m_ShiftTime.GetMean()*1000 << std::endl;
79  std::cout << "Optimal shift: " << 100*m_OptShiftTime.GetTotal()/sum << " - " << m_OptShiftTime.GetMean()*1000 << std::endl;
80  std::cout << "Connection: " << 100*m_ConnectionTime.GetTotal()/sum << "/" << m_ConnectionTime.GetMean()*1000 << std::endl;
81 }
82 
83 // update temperature of simulated annealing process
// Stores `val` in m_InTemp, the temperature that divides the internal-energy
// terms in MakeProposal.
// NOTE(review): source lines 84 (signature) and 87 are missing from this
// listing. m_Density, used in the death acceptance ratio, is not assigned
// anywhere visible; it may be updated on line 87 — confirm.
85 {
86  m_InTemp = val;
88 }
89 
90 // add small random number drawn from gaussian to each vector element
91 void MetropolisHastingsSampler::DistortVector(float sigma, vnl_vector_fixed<float, 3>& vec)
92 {
93  vec[0] += m_RandGen->GetNormalVariate(0.0, sigma);
94  vec[1] += m_RandGen->GetNormalVariate(0.0, sigma);
95  vec[2] += m_RandGen->GetNormalVariate(0.0, sigma);
96 }
97 
98 // generate normalized random vector
100 {
101  vnl_vector_fixed<float, 3> vec;
102  vec[0] = m_RandGen->GetNormalVariate();
103  vec[1] = m_RandGen->GetNormalVariate();
104  vec[2] = m_RandGen->GetNormalVariate();
105  vec.normalize();
106  return vec;
107 }
108 
109 // generate actual proposal (birth, death, shift and connection of particle)
// Draws one uniform variate `randnum` and selects a proposal type by comparing
// it against cumulative thresholds: m_BirthProb, then +m_DeathProb, then
// +m_ShiftProb, then the optimal-shift condition, and otherwise connection.
// Each proposal type is timed with its own probe.
// NOTE(review): this listing is missing the signature (line 110) and several
// statements, including the optimal-shift branch condition (line 204).
111 {
112  float randnum = m_RandGen->GetVariate();
113 
114  // Birth Proposal
115  if (randnum < m_BirthProb)
116  {
117  m_BirthTime.Start();
// Candidate particle at R with random direction N.
// NOTE(review): line 119 (which presumably fills R) and line 125 (which
// declares the acceptance ratio `prob`) are missing from this listing.
118  vnl_vector_fixed<float, 3> R;
120  vnl_vector_fixed<float, 3> N = GetRandomDirection();
121  Particle prop;
122  prop.GetPos() = R;
123  prop.GetDir() = N;
124 
126 
127  float ex_energy = m_EnergyComputer->ComputeExternalEnergy(R,N,nullptr);
128  float in_energy = m_EnergyComputer->ComputeInternalEnergy(&prop);
129  prob *= exp((in_energy/m_InTemp+ex_energy/m_ExTemp)) ;
130 
// Accept with probability min(1, prob). Line 133, which presumably creates
// `p` in the grid, is missing; pos/dir are written only if `p` is non-null.
131  if (prob > 1 || m_RandGen->GetVariate() < prob)
132  {
134  if (p!=nullptr)
135  {
136  p->GetPos() = R;
137  p->GetDir() = N;
139  }
140  }
141  m_BirthTime.Stop();
142  }
143  // Death Proposal
144  else if (randnum < m_BirthProb+m_DeathProb)
145  {
146  m_DeathTime.Start();
// NOTE(review): line 147 is missing from this listing. It presumably guards
// m_NumParticles > 0; without such a guard the modulo below would divide by
// zero (undefined behavior) — confirm.
148  {
149  int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles;
150  Particle *dp = m_ParticleGrid->GetParticle(pnum);
// Only particles with no connection on either end are candidates for removal.
151  if (dp->pID == -1 && dp->mID == -1)
152  {
153  float ex_energy = m_EnergyComputer->ComputeExternalEnergy(dp->GetPos(),dp->GetDir(),dp);
154  float in_energy = m_EnergyComputer->ComputeInternalEnergy(dp);
155 
156  float prob = m_ParticleGrid->m_NumParticles * (m_BirthProb) /(m_Density*m_DeathProb); //*SpatProb(dp->R);
157  prob *= exp(-(in_energy/m_InTemp+ex_energy/m_ExTemp)) ;
// NOTE(review): lines 160-161 (the action taken on acceptance) are missing.
158  if (prob > 1 || m_RandGen->GetVariate() < prob)
159  {
162  }
163  }
164  }
165  m_DeathTime.Stop();
166  }
167  // Shift Proposal
168  else if (randnum < m_BirthProb+m_DeathProb+m_ShiftProb)
169  {
// NOTE(review): line 170 (presumably the same m_NumParticles guard) is missing.
171  {
172  m_ShiftTime.Start();
173  int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles;
174  Particle *p = m_ParticleGrid->GetParticle(pnum);
175  Particle prop_p = *p;
176 
// Randomly perturb a copy of the particle. Line 178 (presumably a direction
// perturbation) is missing; the direction is re-normalized afterwards.
177  DistortVector(m_Sigma, prop_p.GetPos());
179  prop_p.GetDir().normalize();
180 
181 
// NOTE(review): the next statement continues on missing lines 183-184, which
// presumably also compute `in_energy` used below.
182  float ex_energy = m_EnergyComputer->ComputeExternalEnergy(prop_p.GetPos(),prop_p.GetDir(),p)
185 
186  float prob = exp(ex_energy/m_ExTemp+in_energy/m_InTemp);
187  if (m_RandGen->GetVariate() < prob)
188  {
189  vnl_vector_fixed<float, 3> Rtmp = p->GetPos();
190  vnl_vector_fixed<float, 3> Ntmp = p->GetDir();
191  p->GetPos() = prop_p.GetPos();
192  p->GetDir() = prop_p.GetDir();
// If the grid rejects the move, restore the previous position and direction.
193  if (!m_ParticleGrid->TryUpdateGrid(pnum))
194  {
195  p->GetPos() = Rtmp;
196  p->GetDir() = Ntmp;
197  }
199  }
200  m_ShiftTime.Stop();
201  }
202  }
203  // Optimal Shift Proposal
// NOTE(review): the branch condition (line 204) and the inner guard (line 206)
// are missing from this listing.
205  {
207  {
208  m_OptShiftTime.Start();
209  int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles;
210  Particle *p = m_ParticleGrid->GetParticle(pnum);
211 
212  bool no_proposal = false;
213  Particle prop_p = *p;
// Connected on both ends: move to the midpoint of the two neighbors' connected
// endpoints, with direction pointing from the minus neighbor to the plus neighbor.
214  if (p->pID != -1 && p->mID != -1)
215  {
216  Particle *plus = m_ParticleGrid->GetParticle(p->pID);
217  int ep_plus = (plus->pID == p->ID)? 1 : -1;
218  Particle *minus = m_ParticleGrid->GetParticle(p->mID);
219  int ep_minus = (minus->pID == p->ID)? 1 : -1;
220  prop_p.GetPos() = (plus->GetPos() + plus->GetDir() * (m_ParticleLength * ep_plus) + minus->GetPos() + minus->GetDir() * (m_ParticleLength * ep_minus));
221  prop_p.GetPos() *= 0.5;
222  prop_p.GetDir() = plus->GetPos() - minus->GetPos();
223  prop_p.GetDir().normalize();
224  }
// Connected on one end only: place the particle 2*m_ParticleLength beyond the
// neighbor along the neighbor's connected end, with the neighbor's direction.
225  else if (p->pID != -1)
226  {
227  Particle *plus = m_ParticleGrid->GetParticle(p->pID);
228  int ep_plus = (plus->pID == p->ID)? 1 : -1;
229  prop_p.GetPos() = plus->GetPos() + plus->GetDir() * (m_ParticleLength * ep_plus * 2);
230  prop_p.GetDir() = plus->GetDir();
231  }
232  else if (p->mID != -1)
233  {
234  Particle *minus = m_ParticleGrid->GetParticle(p->mID);
235  int ep_minus = (minus->pID == p->ID)? 1 : -1;
236  prop_p.GetPos() = minus->GetPos() + minus->GetDir() * (m_ParticleLength * ep_minus * 2);
237  prop_p.GetDir() = minus->GetDir();
238  }
// Unconnected particles have no optimal position.
239  else
240  no_proposal = true;
241 
242  if (!no_proposal)
243  {
// p_rev: Gaussian-type weight of the position offset and (1 - cos^2) of the
// direction change (coefficient m_Gamma, divided by m_Z). It is presumably the
// reverse-move density of a random shift. It enters the acceptance ratio below
// together with m_ShiftProb and m_OptShiftProb.
244  float cos = dot_product(prop_p.GetDir(), p->GetDir());
245  float p_rev = exp(-((prop_p.GetPos()-p->GetPos()).squared_magnitude() + (1-cos*cos))*m_Gamma)/m_Z;
246 
// NOTE(review): the next statement continues on missing lines 248-249, which
// presumably also compute `in_energy` used below.
247  float ex_energy = m_EnergyComputer->ComputeExternalEnergy(prop_p.GetPos(),prop_p.GetDir(),p)
250 
251  float prob = exp(ex_energy/m_ExTemp+in_energy/m_InTemp)*m_ShiftProb*p_rev/(m_OptShiftProb+m_ShiftProb*p_rev);
252 
253  if (m_RandGen->GetVariate() < prob)
254  {
255  vnl_vector_fixed<float, 3> Rtmp = p->GetPos();
256  vnl_vector_fixed<float, 3> Ntmp = p->GetDir();
257  p->GetPos() = prop_p.GetPos();
258  p->GetDir() = prop_p.GetDir();
// If the grid rejects the move, restore the previous position and direction.
259  if (!m_ParticleGrid->TryUpdateGrid(pnum))
260  {
261  p->GetPos() = Rtmp;
262  p->GetDir() = Ntmp;
263  }
265  }
266  }
267  m_OptShiftTime.Stop();
268  }
269  }
270  // Connection Proposal
// Detach the track hanging off a random end of a random particle, propose a
// replacement track from the same end, then either implement it or restore the
// saved backup track.
271  else
272  {
// NOTE(review): line 273 (presumably the m_NumParticles guard) is missing.
274  {
275  m_ConnectionTime.Start();
276  int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles;
277  Particle *p = m_ParticleGrid->GetParticle(pnum);
278 
279  EndPoint P;
280  P.p = p;
281  P.ep = (m_RandGen->GetVariate() > 0.5)? 1 : -1; // direction of the new tract
282 
283  RemoveAndSaveTrack(P); // remove old tract and save it for later
284  if (m_BackupTrack.m_Probability != 0)
285  {
286  MakeTrackProposal(P); // propose new tract starting from P
287 
// NOTE(review): lines 288 and 290-291, which compute the acceptance ratio
// `prob`, are missing from this listing.
289 
292  if (m_RandGen->GetVariate() < prob)
293  {
294  ImplementTrack(m_ProposalTrack); // accept proposed tract
296  }
297  else
298  {
299  ImplementTrack(m_BackupTrack); // reject proposed tract and restore old one
300  }
301  }
// NOTE(review): line 303, the body of this else, is missing from this listing.
// As shown here, the else would govern m_ConnectionTime.Stop().
302  else
304  m_ConnectionTime.Stop();
305  }
306  }
307 }
308 
309 // establish connections between particles stored in input Track
311 {
312  for (int k = 1; k < T.m_Length;k++)
313  m_ParticleGrid->CreateConnection(T.track[k-1].p,T.track[k-1].ep,T.track[k].p,-T.track[k].ep);
314 }
315 
316 // remove pending track from random particle, save it in m_BackupTrack and calculate its probability
// Starting at endpoint P, disconnects successive links along the track and
// records each visited endpoint in m_BackupTrack.track. After each removed link,
// it continues with probability m_DelProb. It accumulates the internal
// connection energy and the product of m_SimpSamp.probFor(Next) (the proposal
// probability of re-creating each link), which MakeProposal needs to evaluate
// the reverse move.
// NOTE(review): the signature (line 317) and lines 335, 344 and 371 are missing
// from this listing. Line 371 presumably rebuilds m_SimpSamp for Current before
// probFor is queried — confirm.
318 {
319  EndPoint Current = P;
320  int cnt = 0;
321  float energy = 0;
322  float AccumProb = 1.0;
323  m_BackupTrack.track[cnt] = Current;
324  EndPoint Next;
325 
326  for (;;)
327  {
// Follow (and cut) the link on the side of Current selected by Current.ep.
328  Next.p = nullptr;
329  if (Current.ep == 1)
330  {
331  if (Current.p->pID != -1)
332  {
333  Next.p = m_ParticleGrid->GetParticle(Current.p->pID);
334  Current.p->pID = -1;
336  }
337  }
338  else if (Current.ep == -1)
339  {
340  if (Current.p->mID != -1)
341  {
342  Next.p = m_ParticleGrid->GetParticle(Current.p->mID);
343  Current.p->mID = -1;
345  }
346  }
347  else
348  { fprintf(stderr,"MetropolisHastingsSampler_randshift: Connection inconsistent 3\n"); break; }
349 
350  if (Next.p == nullptr) // no successor
351  {
352  Next.ep = 0; // mark as empty successor
353  break;
354  }
355  else
356  {
// Cut the back-link on the neighbor and note which of its ends was connected.
// On an inconsistent back-link the loop stops, leaving the removal partial.
357  if (Next.p->pID == Current.p->ID)
358  {
359  Next.p->pID = -1;
360  Next.ep = 1;
361  }
362  else if (Next.p->mID == Current.p->ID)
363  {
364  Next.p->mID = -1;
365  Next.ep = -1;
366  }
367  else
368  { fprintf(stderr,"MetropolisHastingsSampler_randshift: Connection inconsistent 4\n"); break; }
369  }
370 
372  AccumProb *= (m_SimpSamp.probFor(Next));
373 
// NOTE(review): unreachable as written, because Next.p == nullptr already
// breaks out of the loop above.
374  if (Next.p == nullptr) // no successor -> break
375  break;
376 
377  energy += m_EnergyComputer->ComputeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep);
378 
// Continue from the opposite end of the neighbor.
379  Current = Next;
380  Current.ep *= -1;
381  cnt++;
382  m_BackupTrack.track[cnt] = Current;
383 
// Keep deleting with probability m_DelProb.
384  if (m_RandGen->GetVariate() > m_DelProb)
385  break;
386  }
387  m_BackupTrack.m_Energy = energy;
388  m_BackupTrack.m_Probability = AccumProb;
389  m_BackupTrack.m_Length = cnt+1;
390 }
391 
392 // generate new track using kind of a local tracking starting from P in the given direction, store it in m_ProposalTrack and calculate its probability
// Repeatedly samples the next endpoint from m_SimpSamp (index 0 means "stop")
// until tracking stops or reaches an already connected end. It accumulates the
// connection energy and the product of the sampling probabilities.
// Visited particles are labeled 1 while tracking and reset to 0 at the end.
// NOTE(review): the signature (line 393), line 412 (presumably rebuilding
// m_SimpSamp for Current) and line 437 (presumably setting
// m_ProposalTrack.m_Length from cnt) are missing from this listing.
// NOTE(review): with the length cap below commented out, nothing visible here
// bounds cnt against the size of m_ProposalTrack.track — confirm its capacity.
394 {
395  EndPoint Current = P;
396  int cnt = 0;
397  float energy = 0;
398  float AccumProb = 1.0;
399  m_ProposalTrack.track[cnt++] = Current;
400  Current.p->label = 1;
401 
402  for (;;)
403  {
404  // next candidate is already connected
405  if ((Current.ep == 1 && Current.p->pID != -1) || (Current.ep == -1 && Current.p->mID != -1))
406  break;
407 
408  // track too long
409 // if (cnt > 250)
410 // break;
411 
413 
414  int k = m_SimpSamp.draw(m_RandGen->GetVariate());
415 
416  // stop tracking proposed
417  if (k==0)
418  break;
419 
420  EndPoint Next = m_SimpSamp.objs[k];
421  float probability = m_SimpSamp.probFor(k);
422 
423  // accumulate energy and proposal distribution
424  energy += m_EnergyComputer->ComputeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep);
425  AccumProb *= probability;
426 
427  // track to next endpoint
428  Current = Next;
429  Current.ep *= -1;
430 
431  Current.p->label = 1; // put label to avoid loops
432  m_ProposalTrack.track[cnt++] = Current;
433  }
434 
435  m_ProposalTrack.m_Energy = energy;
436  m_ProposalTrack.m_Probability = AccumProb;
438 
439  // clear labels
440  for (int j = 0; j < m_ProposalTrack.m_Length;j++)
441  m_ProposalTrack.track[j].p->label = 0;
442 }
443 
444 // get neighbouring particles of P and calculate the corresponding connection probabilities
// Rebuilds m_SimpSamp. Entry 0 is the "stop" option EndPoint(nullptr,0) with
// weight m_StopProb. It is followed by every free endpoint of another unlabeled
// particle that meets two conditions:
//  - its squared distance to R, the end of p on side ep, is below m_DistanceThreshold;
//  - the signed direction dot product exceeds m_CurvatureThreshold.
// Each candidate is weighted by exp(internal connection energy / m_TractProb).
// NOTE(review): the signature (line 445), line 452 (presumably querying grid
// neighbors of R) and line 459 (presumably fetching the next neighbor `p2`)
// are missing from this listing.
446 {
447  Particle *p = P.p;
448  int ep = P.ep;
449 
450  float dist,dot;
451  vnl_vector_fixed<float, 3> R = p->GetPos() + (p->GetDir() * (ep*m_ParticleLength) );
453  m_SimpSamp.clear();
454 
455  m_SimpSamp.add(m_StopProb,EndPoint(nullptr,0));
456 
457  for (;;)
458  {
460  if (p2 == nullptr) break;
// Skip p itself and particles already on the current proposal track (label != 0).
461  if (p!=p2 && p2->label == 0)
462  {
// Candidate: p2's free minus end (its position minus dir*m_ParticleLength).
463  if (p2->mID == -1)
464  {
465  dist = (p2->GetPos() - p2->GetDir() * m_ParticleLength - R).squared_magnitude();
466  if (dist < m_DistanceThreshold)
467  {
468  dot = dot_product(p2->GetDir(),p->GetDir()) * ep;
469  if (dot > m_CurvatureThreshold)
470  {
471  float en = m_EnergyComputer->ComputeInternalEnergyConnection(p,ep,p2,-1);
472  m_SimpSamp.add(exp(en/m_TractProb),EndPoint(p2,-1));
473  }
474  }
475  }
// Candidate: p2's free plus end (its position plus dir*m_ParticleLength).
476  if (p2->pID == -1)
477  {
478  dist = (p2->GetPos() + p2->GetDir() * m_ParticleLength - R).squared_magnitude();
479  if (dist < m_DistanceThreshold)
480  {
481  dot = dot_product(p2->GetDir(),p->GetDir()) * (-ep);
482  if (dot > m_CurvatureThreshold)
483  {
484  float en = m_EnergyComputer->ComputeInternalEnergyConnection(p,ep,p2,+1);
485  m_SimpSamp.add(exp(en/m_TractProb),EndPoint(p2,+1));
486  }
487  }
488  }
489  }
490  }
491 }
492 
493 // return number of accepted proposals
495 {
496  return m_AcceptedProposals;
497 }
498 
499 
void CreateConnection(Particle *P1, int ep1, Particle *P2, int ep2)
void add(float p, EndPoint obj)
Definition: mitkSimpSamp.h:59
Track m_ProposalTrack
stores proposal track
A particle is the basic element of the Gibbs fiber tractography method.
Definition: mitkParticle.h:30
vnl_vector_fixed< float, 3 > & GetDir()
Definition: mitkParticle.h:56
void SetProbabilities(float birth, float death, float shift, float optShift, float connect)
update the probabilities of the individual proposal types
virtual float ComputeInternalEnergy(Particle *dp)=0
void ComputeNeighbors(vnl_vector_fixed< float, 3 > &R)
Contains and manages particles.
std::vector< EndPoint > track
float m_ConnectionProb
probability for particle connection proposal
Calculates internal and external energy of the new particle configuration proposal.
SimpSamp m_SimpSamp
neighbouring particles and their probabilities for the local tracking
float m_DeathProb
probability for particle death
float probFor(int idx)
Definition: mitkSimpSamp.h:104
DataCollection - Class to facilitate loading/accessing structured data.
void PrintProposalTimes()
print the state of the proposal time probes
float m_OptShiftProb
probability for optimal particle shift
vnl_vector_fixed< float, 3 > & GetPos()
Definition: mitkParticle.h:51
ParticleGrid * m_ParticleGrid
stores all particles
EndPoint * objs
Definition: mitkSimpSamp.h:39
float m_BirthProb
probability for particle birth
int draw(float prob)
Definition: mitkSimpSamp.h:66
Particle * GetParticle(int ID)
float m_ExTemp
temperature applied to the external energy term
unsigned int m_AcceptedProposals
counts accepted proposals
virtual float ComputeExternalEnergy(vnl_vector_fixed< float, 3 > &R, vnl_vector_fixed< float, 3 > &N, Particle *dp)=0
EnergyComputer * m_EnergyComputer
computes internal and external energy of particles
void DrawRandomPosition(vnl_vector_fixed< float, 3 > &R)
vnl_vector_fixed< float, 3 > GetRandomDirection()
MetropolisHastingsSampler(ParticleGrid *grid, EnergyComputer *enComp, ItkRandGenType *randGen, float curvThres)
itk::Statistics::MersenneTwisterRandomVariateGenerator ItkRandGenType
float m_InTemp
simulated annealing temperature (applied to the internal energy term; set by SetTemperature)
float m_ShiftProb
probability for particle shift
float m_CurvatureThreshold
curvature threshold: minimum dot product between the directions of connected particles (limits the maximum angle)
MITKCORE_EXPORT const ScalarType eps
virtual float ComputeInternalEnergyConnection(Particle *p1, int ep1)=0
ItkRandGenType * m_RandGen
random generator
Track m_BackupTrack
stores the track removed for a new proposal track
Particle * GetNextNeighbor()
void MakeProposal()
make proposal for birth/death/shift/connection of particles
Particle * p
Definition: mitkParticle.h:87
float m_DistanceThreshold
threshold for the maximum squared distance between connected particle endpoints
unsigned char label
Definition: mitkParticle.h:49
Particle * NewParticle(vnl_vector_fixed< float, 3 > R)
void DistortVector(float sigma, vnl_vector_fixed< float, 3 > &vec)