Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
mitkMetropolisHastingsSampler.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
18 
19 using namespace mitk;
20 
22  : m_ExTemp(0.01)
23  , m_BirthProb(0.25)
24  , m_DeathProb(0.05)
25  , m_ShiftProb(0.15)
26  , m_OptShiftProb(0.1)
27  , m_ConnectionProb(0.45)
28  , m_TractProb(0.5)
29  , m_DelProb(0.1)
30  , m_ChempotParticle(0.0)
31  , m_AcceptedProposals(0)
32 {
33  m_RandGen = randGen;
34  m_ParticleGrid = grid;
35  m_EnergyComputer = enComp;
36 
39  m_Sigma = m_ParticleLength/8.0;
40  m_Gamma = 1/(m_Sigma*m_Sigma*2);
41  m_Z = pow(2*M_PI*m_Sigma,3.0/2.0)*(M_PI*m_Sigma/m_ParticleLength);
42 
43  m_CurvatureThreshold = curvThres;
44  m_StopProb = exp(-1/m_TractProb);
45 }
46 
47 void MetropolisHastingsSampler::SetProbabilities(float birth, float death, float shift, float optShift, float connect) // stores the five proposal probabilities in the members, rescales them if required, then prints them to std::cout
48 {
49  m_BirthProb = birth;
50  m_DeathProb = death;
51  m_ShiftProb = shift;
52  m_OptShiftProb = optShift;
53  m_ConnectionProb = connect;
55  if (sum!=1 && sum>mitk::eps) // NOTE(review): the declaration of sum (listing line 54) is not visible in this listing - presumably the total of the five values above; confirm. Rescaling is skipped when sum equals 1 or does not exceed mitk::eps
56  {
57  m_BirthProb /= sum; // every stored probability is divided by the same sum
58  m_DeathProb /= sum;
59  m_ShiftProb /= sum;
60  m_OptShiftProb /= sum;
61  m_ConnectionProb /= sum;
62  }
63  std::cout << "Update proposal probabilities" << std::endl; // report the values as stored after the optional rescaling
64  std::cout << "Birth: " << m_BirthProb << std::endl;
65  std::cout << "Death: " << m_DeathProb << std::endl;
66  std::cout << "Shift: " << m_ShiftProb << std::endl;
67  std::cout << "Optimal shift: " << m_OptShiftProb << std::endl;
68  std::cout << "Connection: " << m_ConnectionProb << std::endl;
69 }
70 
71 // print proposal times
73 {
74  double sum = m_BirthTime.GetTotal()+m_DeathTime.GetTotal()+m_ShiftTime.GetTotal()+m_OptShiftTime.GetTotal()+m_ConnectionTime.GetTotal();
75  std::cout << "Proposal time probes (toal%/mean)" << std::endl;
76  std::cout << "Birth: " << 100*m_BirthTime.GetTotal()/sum << "/" << m_BirthTime.GetMean()*1000 << std::endl;
77  std::cout << "Death: " << 100*m_DeathTime.GetTotal()/sum << "/" << m_DeathTime.GetMean()*1000 << std::endl;
78  std::cout << "Shift: " << 100*m_ShiftTime.GetTotal()/sum << "/" << m_ShiftTime.GetMean()*1000 << std::endl;
79  std::cout << "Optimal shift: " << 100*m_OptShiftTime.GetTotal()/sum << " - " << m_OptShiftTime.GetMean()*1000 << std::endl;
80  std::cout << "Connection: " << 100*m_ConnectionTime.GetTotal()/sum << "/" << m_ConnectionTime.GetMean()*1000 << std::endl;
81 }
82 
83 // update temperature of simulated annealing process
85 {
86  m_InTemp = val;
88 }
89 
90 // add a small random number drawn from a gaussian to each of the three elements of vec (vec is modified in place)
91 void MetropolisHastingsSampler::DistortVector(float sigma, vnl_vector_fixed<float, 3>& vec)
92 {
93  vec[0] += m_RandGen->GetNormalVariate(0.0, sigma); // one separate draw per element, taken in index order 0, 1, 2
94  vec[1] += m_RandGen->GetNormalVariate(0.0, sigma); // NOTE(review): sigma is passed as the second GetNormalVariate argument - confirm whether that parameter is a variance or a standard deviation
95  vec[2] += m_RandGen->GetNormalVariate(0.0, sigma);
96 }
97 
98 // generate normalized random vector
100 {
101  vnl_vector_fixed<float, 3> vec;
102  vec[0] = m_RandGen->GetNormalVariate();
103  vec[1] = m_RandGen->GetNormalVariate();
104  vec[2] = m_RandGen->GetNormalVariate();
105  vec.normalize();
106  return vec;
107 }
108 
109 // generate actual proposal (birth, death, shift and connection of particle)
111 {
112  float randnum = m_RandGen->GetVariate();
113 
114  // Birth Proposal
115  if (randnum < m_BirthProb)
116  {
117  m_BirthTime.Start();
118  vnl_vector_fixed<float, 3> R;
120  vnl_vector_fixed<float, 3> N = GetRandomDirection();
121  Particle prop;
122  prop.GetPos() = R;
123  prop.GetDir() = N;
124 
126 
127  float ex_energy = m_EnergyComputer->ComputeExternalEnergy(R,N,nullptr);
128  float in_energy = m_EnergyComputer->ComputeInternalEnergy(&prop);
129  prob *= exp((in_energy/m_InTemp+ex_energy/m_ExTemp)) ;
130 
131  if (prob > 1 || m_RandGen->GetVariate() < prob)
132  {
134  if (p!=nullptr)
135  {
136  p->GetPos() = R;
137  p->GetDir() = N;
139  }
140  }
141  m_BirthTime.Stop();
142  }
143  // Death Proposal
144  else if (randnum < m_BirthProb+m_DeathProb)
145  {
146  m_DeathTime.Start();
148  {
149  int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles;
150  Particle *dp = m_ParticleGrid->GetParticle(pnum);
151  if (dp->pID == -1 && dp->mID == -1)
152  {
153  float ex_energy = m_EnergyComputer->ComputeExternalEnergy(dp->GetPos(),dp->GetDir(),dp);
154  float in_energy = m_EnergyComputer->ComputeInternalEnergy(dp);
155 
156  float prob = m_ParticleGrid->m_NumParticles * (m_BirthProb) /(m_Density*m_DeathProb); //*SpatProb(dp->R);
157  prob *= exp(-(in_energy/m_InTemp+ex_energy/m_ExTemp)) ;
158  if (prob > 1 || m_RandGen->GetVariate() < prob)
159  {
162  }
163  }
164  }
165  m_DeathTime.Stop();
166  }
167  // Shift Proposal
168  else if (randnum < m_BirthProb+m_DeathProb+m_ShiftProb)
169  {
171  {
172  m_ShiftTime.Start();
173  int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles;
174  Particle *p = m_ParticleGrid->GetParticle(pnum);
175  Particle prop_p = *p;
176 
177  DistortVector(m_Sigma, prop_p.GetPos());
179  prop_p.GetDir().normalize();
180 
181 
182  float ex_energy = m_EnergyComputer->ComputeExternalEnergy(prop_p.GetPos(),prop_p.GetDir(),p)
185 
186  float prob = exp(ex_energy/m_ExTemp+in_energy/m_InTemp);
187  if (m_RandGen->GetVariate() < prob)
188  {
189  vnl_vector_fixed<float, 3> Rtmp = p->GetPos();
190  vnl_vector_fixed<float, 3> Ntmp = p->GetDir();
191  p->GetPos() = prop_p.GetPos();
192  p->GetDir() = prop_p.GetDir();
193  if (!m_ParticleGrid->TryUpdateGrid(pnum))
194  {
195  p->GetPos() = Rtmp;
196  p->GetDir() = Ntmp;
197  }
199  }
200  m_ShiftTime.Stop();
201  }
202  }
203  // Optimal Shift Proposal
205  {
207  {
208  m_OptShiftTime.Start();
209  int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles;
210  Particle *p = m_ParticleGrid->GetParticle(pnum);
211 
212  bool no_proposal = false;
213  Particle prop_p = *p;
214  if (p->pID != -1 && p->mID != -1)
215  {
216  Particle *plus = m_ParticleGrid->GetParticle(p->pID);
217  int ep_plus = (plus->pID == p->ID)? 1 : -1;
218  Particle *minus = m_ParticleGrid->GetParticle(p->mID);
219  int ep_minus = (minus->pID == p->ID)? 1 : -1;
220  prop_p.GetPos() = (plus->GetPos() + plus->GetDir() * (m_ParticleLength * ep_plus) + minus->GetPos() + minus->GetDir() * (m_ParticleLength * ep_minus));
221  prop_p.GetPos() *= 0.5;
222  prop_p.GetDir() = plus->GetPos() - minus->GetPos();
223  prop_p.GetDir().normalize();
224  }
225  else if (p->pID != -1)
226  {
227  Particle *plus = m_ParticleGrid->GetParticle(p->pID);
228  int ep_plus = (plus->pID == p->ID)? 1 : -1;
229  prop_p.GetPos() = plus->GetPos() + plus->GetDir() * (m_ParticleLength * ep_plus * 2);
230  prop_p.GetDir() = plus->GetDir();
231  }
232  else if (p->mID != -1)
233  {
234  Particle *minus = m_ParticleGrid->GetParticle(p->mID);
235  int ep_minus = (minus->pID == p->ID)? 1 : -1;
236  prop_p.GetPos() = minus->GetPos() + minus->GetDir() * (m_ParticleLength * ep_minus * 2);
237  prop_p.GetDir() = minus->GetDir();
238  }
239  else
240  no_proposal = true;
241 
242  if (!no_proposal)
243  {
244  float cos = dot_product(prop_p.GetDir(), p->GetDir());
245  float p_rev = exp(-((prop_p.GetPos()-p->GetPos()).squared_magnitude() + (1-cos*cos))*m_Gamma)/m_Z;
246 
247  float ex_energy = m_EnergyComputer->ComputeExternalEnergy(prop_p.GetPos(),prop_p.GetDir(),p)
250 
251  float prob = exp(ex_energy/m_ExTemp+in_energy/m_InTemp)*m_ShiftProb*p_rev/(m_OptShiftProb+m_ShiftProb*p_rev);
252 
253  if (m_RandGen->GetVariate() < prob)
254  {
255  vnl_vector_fixed<float, 3> Rtmp = p->GetPos();
256  vnl_vector_fixed<float, 3> Ntmp = p->GetDir();
257  p->GetPos() = prop_p.GetPos();
258  p->GetDir() = prop_p.GetDir();
259  if (!m_ParticleGrid->TryUpdateGrid(pnum))
260  {
261  p->GetPos() = Rtmp;
262  p->GetDir() = Ntmp;
263  }
265  }
266  }
267  m_OptShiftTime.Stop();
268  }
269  }
270  // Connection Proposal
271  else
272  {
274  {
275  m_ConnectionTime.Start();
276  int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles;
277  Particle *p = m_ParticleGrid->GetParticle(pnum);
278 
279  EndPoint P;
280  P.p = p;
281  P.ep = (m_RandGen->GetVariate() > 0.5)? 1 : -1; // direction of the new tract
282 
283  RemoveAndSaveTrack(P); // remove old tract and save it for later
284  if (m_BackupTrack.m_Probability != 0)
285  {
286  MakeTrackProposal(P); // propose new tract starting from P
287 
289 
292  if (m_RandGen->GetVariate() < prob)
293  {
294  ImplementTrack(m_ProposalTrack); // accept proposed tract
296  }
297  else
298  {
299  ImplementTrack(m_BackupTrack); // reject proposed tract and restore old one
300  }
301  }
302  else
304  m_ConnectionTime.Stop();
305  }
306  }
307 }
308 
309 // establish connections between particles stored in input Track
311 {
312  for (int k = 1; k < T.m_Length;k++)
313  m_ParticleGrid->CreateConnection(T.track[k-1].p,T.track[k-1].ep,T.track[k].p,-T.track[k].ep);
314 }
315 
316 // remove pending track from random particle, save it in m_BackupTrack and calculate its probability
318 {
319  EndPoint Current = P;
320  int cnt = 0;
321  float energy = 0;
322  float AccumProb = 1.0;
323  m_BackupTrack.track[cnt] = Current;
324  EndPoint Next;
325 
326  for (;;)
327  {
328  Next.p = nullptr;
329  if (Current.ep == 1)
330  {
331  if (Current.p->pID != -1)
332  {
333  Next.p = m_ParticleGrid->GetParticle(Current.p->pID);
334  Current.p->pID = -1;
336  }
337  }
338  else if (Current.ep == -1)
339  {
340  if (Current.p->mID != -1)
341  {
342  Next.p = m_ParticleGrid->GetParticle(Current.p->mID);
343  Current.p->mID = -1;
345  }
346  }
347  else
348  { fprintf(stderr,"MetropolisHastingsSampler_randshift: Connection inconsistent 3\n"); break; }
349 
350  if (Next.p == nullptr) // no successor
351  {
352  Next.ep = 0; // mark as empty successor
353  break;
354  }
355  else
356  {
357  if (Next.p->pID == Current.p->ID)
358  {
359  Next.p->pID = -1;
360  Next.ep = 1;
361  }
362  else if (Next.p->mID == Current.p->ID)
363  {
364  Next.p->mID = -1;
365  Next.ep = -1;
366  }
367  else
368  { fprintf(stderr,"MetropolisHastingsSampler_randshift: Connection inconsistent 4\n"); break; }
369  }
370 
372  AccumProb *= (m_SimpSamp.probFor(Next));
373 
374  if (Next.p == nullptr) // no successor -> break
375  break;
376 
377  energy += m_EnergyComputer->ComputeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep);
378 
379  Current = Next;
380  Current.ep *= -1;
381  cnt++;
382  m_BackupTrack.track[cnt] = Current;
383 
384  if (m_RandGen->GetVariate() > m_DelProb)
385  break;
386  }
387  m_BackupTrack.m_Energy = energy;
388  m_BackupTrack.m_Probability = AccumProb;
389  m_BackupTrack.m_Length = cnt+1;
390 }
391 
392 // generate a new track by a kind of local tracking that starts from P in the given direction, store it in m_ProposalTrack and calculate its probability
394 {
395  EndPoint Current = P;
396  int cnt = 0;
397  float energy = 0;
398  float AccumProb = 1.0;
399  m_ProposalTrack.track[cnt++] = Current;
400  Current.p->label = 1;
401 
402  for (;;)
403  {
404  // next candidate is already connected
405  if ((Current.ep == 1 && Current.p->pID != -1) || (Current.ep == -1 && Current.p->mID != -1))
406  break;
407 
408  // track too long
409 // if (cnt > 250)
410 // break;
411 
413 
414  int k = m_SimpSamp.draw(m_RandGen->GetVariate());
415 
416  // stop tracking proposed
417  if (k==0)
418  break;
419 
420  EndPoint Next = m_SimpSamp.objs[k];
421  float probability = m_SimpSamp.probFor(k);
422 
423  // accumulate energy and proposal distribution
424  energy += m_EnergyComputer->ComputeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep);
425  AccumProb *= probability;
426 
427  // track to next endpoint
428  Current = Next;
429  Current.ep *= -1;
430 
431  Current.p->label = 1; // put label to avoid loops
432  m_ProposalTrack.track[cnt++] = Current;
433  }
434 
435  m_ProposalTrack.m_Energy = energy;
436  m_ProposalTrack.m_Probability = AccumProb;
438 
439  // clear labels
440  for (int j = 0; j < m_ProposalTrack.m_Length;j++)
441  m_ProposalTrack.track[j].p->label = 0;
442 }
443 
444 // get neighbouring particles of P and calculate the corresponding connection probabilities
446 {
447  Particle *p = P.p;
448  int ep = P.ep;
449 
450  float dist,dot;
451  vnl_vector_fixed<float, 3> R = p->GetPos() + (p->GetDir() * (ep*m_ParticleLength) );
453  m_SimpSamp.clear();
454 
455  m_SimpSamp.add(m_StopProb,EndPoint(nullptr,0));
456 
457  for (;;)
458  {
460  if (p2 == nullptr) break;
461  if (p!=p2 && p2->label == 0)
462  {
463  if (p2->mID == -1)
464  {
465  dist = (p2->GetPos() - p2->GetDir() * m_ParticleLength - R).squared_magnitude();
466  if (dist < m_DistanceThreshold)
467  {
468  dot = dot_product(p2->GetDir(),p->GetDir()) * ep;
469  if (dot > m_CurvatureThreshold)
470  {
471  float en = m_EnergyComputer->ComputeInternalEnergyConnection(p,ep,p2,-1);
472  m_SimpSamp.add(exp(en/m_TractProb),EndPoint(p2,-1));
473  }
474  }
475  }
476  if (p2->pID == -1)
477  {
478  dist = (p2->GetPos() + p2->GetDir() * m_ParticleLength - R).squared_magnitude();
479  if (dist < m_DistanceThreshold)
480  {
481  dot = dot_product(p2->GetDir(),p->GetDir()) * (-ep);
482  if (dot > m_CurvatureThreshold)
483  {
484  float en = m_EnergyComputer->ComputeInternalEnergyConnection(p,ep,p2,+1);
485  m_SimpSamp.add(exp(en/m_TractProb),EndPoint(p2,+1));
486  }
487  }
488  }
489  }
490  }
491 }
492 
493 // return number of accepted proposals
495 {
496  return m_AcceptedProposals;
497 }
498 
499 
void CreateConnection(Particle *P1, int ep1, Particle *P2, int ep2)
void add(float p, EndPoint obj)
Definition: mitkSimpSamp.h:59
Track m_ProposalTrack
stores proposal track
A particle is the basic element of the Gibbs fiber tractography method.
Definition: mitkParticle.h:30
vnl_vector_fixed< float, 3 > & GetDir()
Definition: mitkParticle.h:56
void SetProbabilities(float birth, float death, float shift, float optShift, float connect)
update the probabilities of the single proposals
virtual float ComputeInternalEnergy(Particle *dp)=0
void ComputeNeighbors(vnl_vector_fixed< float, 3 > &R)
Contains and manages particles.
std::vector< EndPoint > track
float m_ConnectionProb
probability for particle connection proposal
Calculates internal and external energy of the new particle configuration proposal.
SimpSamp m_SimpSamp
neighbouring particles and their probabilities for the local tracking
float m_DeathProb
probability for particle death
float probFor(int idx)
Definition: mitkSimpSamp.h:104
DataCollection - Class to facilitate loading/accessing structured data.
void PrintProposalTimes()
print the state of the proposal time probes
float m_OptShiftProb
probability for optimal particle shift
vnl_vector_fixed< float, 3 > & GetPos()
Definition: mitkParticle.h:51
ParticleGrid * m_ParticleGrid
stores all particles
EndPoint * objs
Definition: mitkSimpSamp.h:39
float m_BirthProb
probability for particle birth
int draw(float prob)
Definition: mitkSimpSamp.h:66
Particle * GetParticle(int ID)
float m_ExTemp
simulated annealing temperature
unsigned int m_AcceptedProposals
counts accepted proposals
virtual float ComputeExternalEnergy(vnl_vector_fixed< float, 3 > &R, vnl_vector_fixed< float, 3 > &N, Particle *dp)=0
EnergyComputer * m_EnergyComputer
computes internal and external energy of particles
void DrawRandomPosition(vnl_vector_fixed< float, 3 > &R)
vnl_vector_fixed< float, 3 > GetRandomDirection()
MetropolisHastingsSampler(ParticleGrid *grid, EnergyComputer *enComp, ItkRandGenType *randGen, float curvThres)
itk::Statistics::MersenneTwisterRandomVariateGenerator ItkRandGenType
float m_InTemp
simulated annealing temperature
float m_ShiftProb
probability for particle shift
float m_CurvatureThreshold
threshold for maximum angle between connected particles
MITKCORE_EXPORT const ScalarType eps
virtual float ComputeInternalEnergyConnection(Particle *p1, int ep1)=0
ItkRandGenType * m_RandGen
random generator
Track m_BackupTrack
stores track removed for new proposal track
Particle * GetNextNeighbor()
void MakeProposal()
make proposal for birth/death/shift/connection of particles
Particle * p
Definition: mitkParticle.h:87
float m_DistanceThreshold
threshold for maximum distance between connected particles
unsigned char label
Definition: mitkParticle.h:49
Particle * NewParticle(vnl_vector_fixed< float, 3 > R)
void DistortVector(float sigma, vnl_vector_fixed< float, 3 > &vec)