Medical Imaging Interaction Toolkit  2018.4.99-389bf124
Medical Imaging Interaction Toolkit
mitkLibSVMClassifier.cpp
Go to the documentation of this file.
1 /*============================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center (DKFZ)
6 All rights reserved.
7 
8 Use of this source code is governed by a 3-clause BSD license that can be
9 found in the LICENSE file.
10 
11 ============================================================================*/
12 
13 #include <mitkLibSVMClassifier.h>
14 
16 #include <svm.h>
17 #include <mitkExceptionMacro.h>
18 
20  m_Model(nullptr),m_Parameter(nullptr)
21 {
22  this->m_Parameter = new svm_parameter();
23 }
24 
26 {
27  if (m_Model)
28  {
30  }
31  if( m_Parameter)
32  svm_destroy_param(m_Parameter);
33 }
34 
35 void mitk::LibSVMClassifier::Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
36 {
37  this->SetPointWiseWeight(Eigen::MatrixXd(Y.rows(),1));
38  this->UsePointWiseWeight(false);
39 
40  svm_node *xSpace;
41  svm_problem problem;
42 
44  ReadYValues(&problem, Y);
45  ReadXValues(&problem, &xSpace,X);
46  ReadWValues(&problem);
47 
48  const char * error_msg = nullptr;
49  error_msg = svm_check_parameter(&problem, m_Parameter);
50  if (error_msg)
51  {
52  svm_destroy_param(m_Parameter);
53  free(problem.y);
54  free(problem.x);
55  free(xSpace);
56  mitkThrow() << "Error: " << error_msg;
57  }
58 
59  m_Model = svm_train(&problem, m_Parameter);
60 
61  // free(problem.y);
62  // free(problem.x);
63  // free(xSpace);
64 }
65 
66 Eigen::MatrixXi mitk::LibSVMClassifier::Predict(const Eigen::MatrixXd &X)
67 {
68  if ( ! m_Model)
69  {
70  mitkThrow() << "No Model is trained. Train or load a model before predicting new values.";
71  }
72  auto noOfPoints = static_cast<int>(X.rows());
73  auto noOfFeatures = static_cast<int>(X.cols());
74 
75  Eigen::MatrixXi result(noOfPoints,1);
76 
77  auto * xVector = static_cast<svm_node *>(malloc(sizeof(svm_node) * (noOfFeatures+1)));
78  for (int point = 0; point < noOfPoints; ++point)
79  {
80  for (int feature = 0; feature < noOfFeatures; ++feature)
81  {
82  xVector[feature].index = feature+1;
83  xVector[feature].value = X(point, feature);
84  }
85  xVector[noOfFeatures].index = -1;
86  result(point,0) = svm_predict(m_Model,xVector);
87  }
88 
89  free(xVector);
90  return result;
91 }
92 
94 {
95  // Read each property from the property list; fall back to a default if it is not set
96  if(!this->GetPropertyList()->Get("classifier.svm.svm-type",this->m_Parameter->svm_type)) this->m_Parameter->svm_type = 0;
97  if(!this->GetPropertyList()->Get("classifier.svm.kernel-type",this->m_Parameter->kernel_type)) this->m_Parameter->kernel_type = 2;
98  if(!this->GetPropertyList()->Get("classifier.svm.degree",this->m_Parameter->degree)) this->m_Parameter->degree = 3;
99  if(!this->GetPropertyList()->Get("classifier.svm.gamma",this->m_Parameter->gamma)) this->m_Parameter->gamma = 0; // 1/n_features;
100  if(!this->GetPropertyList()->Get("classifier.svm.coef0",this->m_Parameter->coef0)) this->m_Parameter->coef0 = 0;
101  if(!this->GetPropertyList()->Get("classifier.svm.nu",this->m_Parameter->nu)) this->m_Parameter->nu = 0.5;
102  if(!this->GetPropertyList()->Get("classifier.svm.cache-size",this->m_Parameter->cache_size)) this->m_Parameter->cache_size = 100.0;
103  if(!this->GetPropertyList()->Get("classifier.svm.c",this->m_Parameter->C)) this->m_Parameter->C = 1.0;
104  if(!this->GetPropertyList()->Get("classifier.svm.eps",this->m_Parameter->eps)) this->m_Parameter->eps = 1e-3;
105  if(!this->GetPropertyList()->Get("classifier.svm.p",this->m_Parameter->p)) this->m_Parameter->p = 0.1;
106  if(!this->GetPropertyList()->Get("classifier.svm.shrinking",this->m_Parameter->shrinking)) this->m_Parameter->shrinking = 1;
107  if(!this->GetPropertyList()->Get("classifier.svm.probability",this->m_Parameter->probability)) this->m_Parameter->probability = 0;
108  if(!this->GetPropertyList()->Get("classifier.svm.nr-weight",this->m_Parameter->nr_weight)) this->m_Parameter->nr_weight = 0;
109 
110  //options:
111  //-s svm_type : set type of SVM (default 0)
112  // 0 -- C-SVC
113  // 1 -- nu-SVC
114  // 2 -- one-class SVM
115  // 3 -- epsilon-SVR
116  // 4 -- nu-SVR
117  //-t kernel_type : set type of kernel function (default 2)
118  // 0 -- linear: u'*v
119  // 1 -- polynomial: (gamma*u'*v + coef0)^degree
120  // 2 -- radial basis function: exp(-gamma*|u-v|^2)
121  // 3 -- sigmoid: tanh(gamma*u'*v + coef0)
122  //-d degree : set degree in kernel function (default 3)
123  //-g gamma : set gamma in kernel function (default 1/num_features)
124  //-r coef0 : set coef0 in kernel function (default 0)
125  //-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
126  //-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
127  //-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
128  //-m cachesize : set cache memory size in MB (default 100)
129  //-e epsilon : set tolerance of termination criterion (default 0.001)
130  //-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
131  //-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
132  //-wiweight: set the parameter C of class i to weight*C, for C-SVC (default 1)
133 
134  // this->m_Parameter->weight_label = nullptr;
135  // this->m_Parameter->weight = 1;
136 }
137 
138 /* these are for training only */
139 //int *weight_label; /* for C_SVC */
140 //double* weight; /* for C_SVC */
141 
143 {
144  this->GetPropertyList()->SetIntProperty("classifier.svm.probability",val);
145 }
146 
148 {
149  this->GetPropertyList()->SetIntProperty("classifier.svm.shrinking",val);
150 }
151 
153 {
154  this->GetPropertyList()->SetIntProperty("classifier.svm.nr_weight",val);
155 }
156 
158 {
159  this->GetPropertyList()->SetDoubleProperty("classifier.svm.nu",val);
160 }
161 
163 {
164  this->GetPropertyList()->SetDoubleProperty("classifier.svm.p",val);
165 }
166 
168 {
169  this->GetPropertyList()->SetDoubleProperty("classifier.svm.eps",val);
170 }
171 
173 {
174  this->GetPropertyList()->SetDoubleProperty("classifier.svm.c",val);
175 }
176 
178 {
179  this->GetPropertyList()->SetDoubleProperty("classifier.svm.cache-size",val);
180 }
181 
183 {
184  this->GetPropertyList()->SetIntProperty("classifier.svm.svm-type",val);
185 }
186 
188 {
189  this->GetPropertyList()->SetIntProperty("classifier.svm.kernel-type",val);
190 }
191 
193 {
194  this->GetPropertyList()->SetIntProperty("classifier.svm.degree",val);
195 }
196 
198 {
199  this->GetPropertyList()->SetDoubleProperty("classifier.svm.gamma",val);
200 }
201 
203 {
204  this->GetPropertyList()->SetDoubleProperty("classifier.svm.coef0",val);
205 }
206 
208 {
209  if(this->m_Parameter == nullptr)
210  {
211  MITK_WARN("LibSVMClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!";
212  return;
213  }
214 
215  this->ConvertParameter();
216 
217  // Print each property; values that are not set are reported with their default
218  if(!this->GetPropertyList()->Get("classifier.svm.svm-type",this->m_Parameter->svm_type))
219  str << "classifier.svm.svm-type\tNOT SET (default " << this->m_Parameter->svm_type << ")" << "\n";
220  else
221  str << "classifier.svm.svm-type\t" << this->m_Parameter->svm_type << "\n";
222 
223  if(!this->GetPropertyList()->Get("classifier.svm.kernel-type",this->m_Parameter->kernel_type))
224  str << "classifier.svm.kernel-type\tNOT SET (default " << this->m_Parameter->kernel_type << ")" << "\n";
225  else
226  str << "classifier.svm.kernel-type\t" << this->m_Parameter->kernel_type << "\n";
227 
228  if(!this->GetPropertyList()->Get("classifier.svm.degree",this->m_Parameter->degree))
229  str << "classifier.svm.degree\t\tNOT SET (default " << this->m_Parameter->degree << ")" << "\n";
230  else
231  str << "classifier.svm.degree\t\t" << this->m_Parameter->degree << "\n";
232 
233  if(!this->GetPropertyList()->Get("classifier.svm.gamma",this->m_Parameter->gamma))
234  str << "classifier.svm.gamma\t\tNOT SET (default " << this->m_Parameter->gamma << ")" << "\n";
235  else
236  str << "classifier.svm.gamma\t\t" << this->m_Parameter->gamma << "\n";
237 
238  if(!this->GetPropertyList()->Get("classifier.svm.coef0",this->m_Parameter->coef0))
239  str << "classifier.svm.coef0\t\tNOT SET (default " << this->m_Parameter->coef0 << ")" << "\n";
240  else
241  str << "classifier.svm.coef0\t\t" << this->m_Parameter->coef0 << "\n";
242 
243  if(!this->GetPropertyList()->Get("classifier.svm.nu",this->m_Parameter->nu))
244  str << "classifier.svm.nu\t\tNOT SET (default " << this->m_Parameter->nu << ")" << "\n";
245  else
246  str << "classifier.svm.nu\t\t" << this->m_Parameter->nu << "\n";
247 
248  if(!this->GetPropertyList()->Get("classifier.svm.cache-size",this->m_Parameter->cache_size))
249  str << "classifier.svm.cache-size\tNOT SET (default " << this->m_Parameter->cache_size << ")" << "\n";
250  else
251  str << "classifier.svm.cache-size\t" << this->m_Parameter->cache_size << "\n";
252 
253  if(!this->GetPropertyList()->Get("classifier.svm.c",this->m_Parameter->C))
254  str << "classifier.svm.c\t\tNOT SET (default " << this->m_Parameter->C << ")" << "\n";
255  else
256  str << "classifier.svm.c\t\t" << this->m_Parameter->C << "\n";
257 
258  if(!this->GetPropertyList()->Get("classifier.svm.eps",this->m_Parameter->eps))
259  str << "classifier.svm.eps\t\tNOT SET (default " << this->m_Parameter->eps << ")" << "\n";
260  else
261  str << "classifier.svm.eps\t\t" << this->m_Parameter->eps << "\n";
262 
263  if(!this->GetPropertyList()->Get("classifier.svm.p",this->m_Parameter->p))
264  str << "classifier.svm.p\t\tNOT SET (default " << this->m_Parameter->p << ")" << "\n";
265  else
266  str << "classifier.svm.p\t\t" << this->m_Parameter->p << "\n";
267 
268  if(!this->GetPropertyList()->Get("classifier.svm.shrinking",this->m_Parameter->shrinking))
269  str << "classifier.svm.shrinking\tNOT SET (default " << this->m_Parameter->shrinking << ")" << "\n";
270  else
271  str << "classifier.svm.shrinking\t" << this->m_Parameter->shrinking << "\n";
272 
273  if(!this->GetPropertyList()->Get("classifier.svm.probability",this->m_Parameter->probability))
274  str << "classifier.svm.probability\tNOT SET (default " << this->m_Parameter->probability << ")" << "\n";
275  else
276  str << "classifier.svm.probability\t" << this->m_Parameter->probability << "\n";
277 
278  if(!this->GetPropertyList()->Get("classifier.svm.nr-weight",this->m_Parameter->nr_weight))
279  str << "classifier.svm.nr-weight\tNOT SET (default " << this->m_Parameter->nr_weight << ")" << "\n";
280  else
281  str << "classifier.svm.nr-weight\t" << this->m_Parameter->nr_weight << "\n";
282 }
283 
284 // Trying to assign from matrix to noOfPoints
285 void mitk::LibSVMClassifier::ReadXValues(svm_problem * problem, svm_node** xSpace, const Eigen::MatrixXd &X)
286 {
287  auto noOfPoints = static_cast<int>(X.rows());
288  auto features = static_cast<int>(X.cols());
289 
290  problem->x = static_cast<svm_node **>(malloc(sizeof(svm_node *) * noOfPoints));
291  (*xSpace) = static_cast<svm_node *> (malloc(sizeof(svm_node) * noOfPoints * (features+1)));
292 
293  for (int row = 0; row < noOfPoints; ++row)
294  {
295  for (int col = 0; col < features; ++col)
296  {
297  (*xSpace)[row*features + col].index = col;
298  (*xSpace)[row*features + col].value = X(row,col);
299  }
300  (*xSpace)[row*features + features].index = -1;
301 
302  problem->x[row] = &((*xSpace)[row*features]);
303  }
304 }
305 
306 void mitk::LibSVMClassifier::ReadYValues(svm_problem * problem, const Eigen::MatrixXi &Y)
307 {
308  problem->l = static_cast<int>(Y.rows());
309  problem->y = static_cast<double *>(malloc(sizeof(double) * problem->l));
310 
311  for (int i = 0; i < problem->l; ++i)
312  {
313  problem->y[i] = Y(i,0);
314  }
315 }
316 
317 void mitk::LibSVMClassifier::ReadWValues(svm_problem * problem)
318 {
319  Eigen::MatrixXd & W = this->GetPointWiseWeight();
320  int noOfPoints = problem->l;
321  problem->W = static_cast<double *>(malloc(sizeof(double) * noOfPoints));
322 
324  {
325  for (int i = 0; i < noOfPoints; ++i)
326  {
327  problem->W[i] = W(i,0);
328  }
329  } else
330  {
331  for (int i = 0; i < noOfPoints; ++i)
332  {
333  problem->W[i] = 1;
334  }
335  }
336 }
virtual bool IsUsingPointWiseWeight()
IsUsingPointWiseWeight.
double * W
Definition: svm.h:70
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override
Train an SVM model from the training set (X, Y).
void svm_destroy_param(struct svm_parameter *param)
Definition: svm.cpp:3152
int nr_weight
Definition: svm.h:88
const char * svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param)
double p
Definition: svm.h:92
double cache_size
Definition: svm.h:85
virtual Eigen::MatrixXd & GetPointWiseWeight()
GetPointWiseWeightCopy.
#define MITK_WARN
Definition: mitkLogMacros.h:19
double eps
Definition: svm.h:86
#define mitkThrow()
void PrintParameter(std::ostream &str)
int shrinking
Definition: svm.h:93
struct svm_node ** x
Definition: svm.h:69
virtual void UsePointWiseWeight(bool value)
UsePointWiseWeight.
int index
Definition: svm.h:61
mitk::PropertyList::Pointer GetPropertyList() const
Get the data's property list.
const char features[]
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override
Predict class for X.
int probability
Definition: svm.h:94
int degree
Definition: svm.h:80
struct svm_model * svm_train(const struct svm_problem *prob, const struct svm_parameter *param)
Definition: svm.h:59
double * y
Definition: svm.h:68
double svm_predict(const struct svm_model *model, const struct svm_node *x)
double gamma
Definition: svm.h:81
int l
Definition: svm.h:67
virtual void SetPointWiseWeight(const Eigen::MatrixXd &W)
SetPointWiseWeight.
double C
Definition: svm.h:87
int svm_type
Definition: svm.h:78
double nu
Definition: svm.h:91
double coef0
Definition: svm.h:82
int kernel_type
Definition: svm.h:79
void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr)
Definition: svm.cpp:3142