Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
mitkLibSVMClassifier.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 #include <mitkLibSVMClassifier.h>
18 
20 namespace LibSVM
21 {
22 #include "svm.h"
23 }
24 #include <mitkExceptionMacro.h>
25 
27  m_Model(nullptr),m_Parameter(nullptr)
28 {
29  this->m_Parameter = new LibSVM::svm_parameter();
30 }
31 
33 {
34  if (m_Model)
35  {
37  }
38  if( m_Parameter)
39  LibSVM::svm_destroy_param(m_Parameter);
40 }
41 
42 void mitk::LibSVMClassifier::Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
43 {
44  this->SetPointWiseWeight(Eigen::MatrixXd(Y.rows(),1));
45  this->UsePointWiseWeight(false);
46 
47  LibSVM::svm_node *xSpace;
48  LibSVM::svm_problem problem;
49 
50  ConvertParameter();
51  ReadYValues(&problem, Y);
52  ReadXValues(&problem, &xSpace,X);
53  ReadWValues(&problem);
54 
55  const char * error_msg = nullptr;
56  error_msg = LibSVM::svm_check_parameter(&problem, m_Parameter);
57  if (error_msg)
58  {
59  LibSVM::svm_destroy_param(m_Parameter);
60  free(problem.y);
61  free(problem.x);
62  free(xSpace);
63  mitkThrow() << "Error: " << error_msg;
64  }
65 
66  m_Model = LibSVM::svm_train(&problem, m_Parameter);
67 
68  // free(problem.y);
69  // free(problem.x);
70  // free(xSpace);
71 }
72 
73 Eigen::MatrixXi mitk::LibSVMClassifier::Predict(const Eigen::MatrixXd &X)
74 {
75  if ( ! m_Model)
76  {
77  mitkThrow() << "No Model is trained. Train or load a model before predicting new values.";
78  }
79  int noOfPoints = static_cast<int>(X.rows());
80  int noOfFeatures = static_cast<int>(X.cols());
81 
82  Eigen::MatrixXi result(noOfPoints,1);
83 
84  LibSVM::svm_node * xVector = static_cast<LibSVM::svm_node *>(malloc(sizeof(LibSVM::svm_node) * (noOfFeatures+1)));
85  for (int point = 0; point < noOfPoints; ++point)
86  {
87  for (int feature = 0; feature < noOfFeatures; ++feature)
88  {
89  xVector[feature].index = feature+1;
90  xVector[feature].value = X(point, feature);
91  }
92  xVector[noOfFeatures].index = -1;
93  result(point,0) = LibSVM::svm_predict(m_Model,xVector);
94  }
95 
96  free(xVector);
97  return result;
98 }
99 
101 {
102  // Read each property; fall back to the default below when it is not set
103  if(!this->GetPropertyList()->Get("classifier.svm.svm-type",this->m_Parameter->svm_type)) this->m_Parameter->svm_type = 0;
104  if(!this->GetPropertyList()->Get("classifier.svm.kernel-type",this->m_Parameter->kernel_type)) this->m_Parameter->kernel_type = 2;
105  if(!this->GetPropertyList()->Get("classifier.svm.degree",this->m_Parameter->degree)) this->m_Parameter->degree = 3;
106  if(!this->GetPropertyList()->Get("classifier.svm.gamma",this->m_Parameter->gamma)) this->m_Parameter->gamma = 0; // 1/n_features;
107  if(!this->GetPropertyList()->Get("classifier.svm.coef0",this->m_Parameter->coef0)) this->m_Parameter->coef0 = 0;
108  if(!this->GetPropertyList()->Get("classifier.svm.nu",this->m_Parameter->nu)) this->m_Parameter->nu = 0.5;
109  if(!this->GetPropertyList()->Get("classifier.svm.cache-size",this->m_Parameter->cache_size)) this->m_Parameter->cache_size = 100.0;
110  if(!this->GetPropertyList()->Get("classifier.svm.c",this->m_Parameter->C)) this->m_Parameter->C = 1.0;
111  if(!this->GetPropertyList()->Get("classifier.svm.eps",this->m_Parameter->eps)) this->m_Parameter->eps = 1e-3;
112  if(!this->GetPropertyList()->Get("classifier.svm.p",this->m_Parameter->p)) this->m_Parameter->p = 0.1;
113  if(!this->GetPropertyList()->Get("classifier.svm.shrinking",this->m_Parameter->shrinking)) this->m_Parameter->shrinking = 1;
114  if(!this->GetPropertyList()->Get("classifier.svm.probability",this->m_Parameter->probability)) this->m_Parameter->probability = 0;
115  if(!this->GetPropertyList()->Get("classifier.svm.nr-weight",this->m_Parameter->nr_weight)) this->m_Parameter->nr_weight = 0;
116 
117  //options:
118  //-s svm_type : set type of SVM (default 0)
119  // 0 -- C-SVC
120  // 1 -- nu-SVC
121  // 2 -- one-class SVM
122  // 3 -- epsilon-SVR
123  // 4 -- nu-SVR
124  //-t kernel_type : set type of kernel function (default 2)
125  // 0 -- linear: u'*v
126  // 1 -- polynomial: (gamma*u'*v + coef0)^degree
127  // 2 -- radial basis function: exp(-gamma*|u-v|^2)
128  // 3 -- sigmoid: tanh(gamma*u'*v + coef0)
129  //-d degree : set degree in kernel function (default 3)
130  //-g gamma : set gamma in kernel function (default 1/num_features)
131  //-r coef0 : set coef0 in kernel function (default 0)
132  //-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
133  //-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
134  //-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
135  //-m cachesize : set cache memory size in MB (default 100)
136  //-e epsilon : set tolerance of termination criterion (default 0.001)
137  //-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
138  //-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
139  //-wiweight: set the parameter C of class i to weight*C, for C-SVC (default 1)
140 
141  // this->m_Parameter->weight_label = nullptr;
142  // this->m_Parameter->weight = 1;
143 }
144 
145 /* these are for training only */
146 //int *weight_label; /* for C_SVC */
147 //double* weight; /* for C_SVC */
148 
150 {
151  this->GetPropertyList()->SetIntProperty("classifier.svm.probability",val);
152 }
153 
155 {
156  this->GetPropertyList()->SetIntProperty("classifier.svm.shrinking",val);
157 }
158 
160 {
161  this->GetPropertyList()->SetIntProperty("classifier.svm.nr_weight",val);
162 }
163 
165 {
166  this->GetPropertyList()->SetDoubleProperty("classifier.svm.nu",val);
167 }
168 
170 {
171  this->GetPropertyList()->SetDoubleProperty("classifier.svm.p",val);
172 }
173 
175 {
176  this->GetPropertyList()->SetDoubleProperty("classifier.svm.eps",val);
177 }
178 
180 {
181  this->GetPropertyList()->SetDoubleProperty("classifier.svm.c",val);
182 }
183 
185 {
186  this->GetPropertyList()->SetDoubleProperty("classifier.svm.cache-size",val);
187 }
188 
190 {
191  this->GetPropertyList()->SetIntProperty("classifier.svm.svm-type",val);
192 }
193 
195 {
196  this->GetPropertyList()->SetIntProperty("classifier.svm.kernel-type",val);
197 }
198 
200 {
201  this->GetPropertyList()->SetIntProperty("classifier.svm.degree",val);
202 }
203 
205 {
206  this->GetPropertyList()->SetDoubleProperty("classifier.svm.gamma",val);
207 }
208 
210 {
211  this->GetPropertyList()->SetDoubleProperty("classifier.svm.coef0",val);
212 }
213 
215 {
216  if(this->m_Parameter == nullptr)
217  {
218  MITK_WARN("LibSVMClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!";
219  return;
220  }
221 
222  this->ConvertParameter();
223 
224  // Report each property, marking the ones that are not set (default shown instead)
225  if(!this->GetPropertyList()->Get("classifier.svm.svm-type",this->m_Parameter->svm_type))
226  str << "classifier.svm.svm-type\tNOT SET (default " << this->m_Parameter->svm_type << ")" << "\n";
227  else
228  str << "classifier.svm.svm-type\t" << this->m_Parameter->svm_type << "\n";
229 
230  if(!this->GetPropertyList()->Get("classifier.svm.kernel-type",this->m_Parameter->kernel_type))
231  str << "classifier.svm.kernel-type\tNOT SET (default " << this->m_Parameter->kernel_type << ")" << "\n";
232  else
233  str << "classifier.svm.kernel-type\t" << this->m_Parameter->kernel_type << "\n";
234 
235  if(!this->GetPropertyList()->Get("classifier.svm.degree",this->m_Parameter->degree))
236  str << "classifier.svm.degree\t\tNOT SET (default " << this->m_Parameter->degree << ")" << "\n";
237  else
238  str << "classifier.svm.degree\t\t" << this->m_Parameter->degree << "\n";
239 
240  if(!this->GetPropertyList()->Get("classifier.svm.gamma",this->m_Parameter->gamma))
241  str << "classifier.svm.gamma\t\tNOT SET (default " << this->m_Parameter->gamma << ")" << "\n";
242  else
243  str << "classifier.svm.gamma\t\t" << this->m_Parameter->gamma << "\n";
244 
245  if(!this->GetPropertyList()->Get("classifier.svm.coef0",this->m_Parameter->coef0))
246  str << "classifier.svm.coef0\t\tNOT SET (default " << this->m_Parameter->coef0 << ")" << "\n";
247  else
248  str << "classifier.svm.coef0\t\t" << this->m_Parameter->coef0 << "\n";
249 
250  if(!this->GetPropertyList()->Get("classifier.svm.nu",this->m_Parameter->nu))
251  str << "classifier.svm.nu\t\tNOT SET (default " << this->m_Parameter->nu << ")" << "\n";
252  else
253  str << "classifier.svm.nu\t\t" << this->m_Parameter->nu << "\n";
254 
255  if(!this->GetPropertyList()->Get("classifier.svm.cache-size",this->m_Parameter->cache_size))
256  str << "classifier.svm.cache-size\tNOT SET (default " << this->m_Parameter->cache_size << ")" << "\n";
257  else
258  str << "classifier.svm.cache-size\t" << this->m_Parameter->cache_size << "\n";
259 
260  if(!this->GetPropertyList()->Get("classifier.svm.c",this->m_Parameter->C))
261  str << "classifier.svm.c\t\tNOT SET (default " << this->m_Parameter->C << ")" << "\n";
262  else
263  str << "classifier.svm.c\t\t" << this->m_Parameter->C << "\n";
264 
265  if(!this->GetPropertyList()->Get("classifier.svm.eps",this->m_Parameter->eps))
266  str << "classifier.svm.eps\t\tNOT SET (default " << this->m_Parameter->eps << ")" << "\n";
267  else
268  str << "classifier.svm.eps\t\t" << this->m_Parameter->eps << "\n";
269 
270  if(!this->GetPropertyList()->Get("classifier.svm.p",this->m_Parameter->p))
271  str << "classifier.svm.p\t\tNOT SET (default " << this->m_Parameter->p << ")" << "\n";
272  else
273  str << "classifier.svm.p\t\t" << this->m_Parameter->p << "\n";
274 
275  if(!this->GetPropertyList()->Get("classifier.svm.shrinking",this->m_Parameter->shrinking))
276  str << "classifier.svm.shrinking\tNOT SET (default " << this->m_Parameter->shrinking << ")" << "\n";
277  else
278  str << "classifier.svm.shrinking\t" << this->m_Parameter->shrinking << "\n";
279 
280  if(!this->GetPropertyList()->Get("classifier.svm.probability",this->m_Parameter->probability))
281  str << "classifier.svm.probability\tNOT SET (default " << this->m_Parameter->probability << ")" << "\n";
282  else
283  str << "classifier.svm.probability\t" << this->m_Parameter->probability << "\n";
284 
285  if(!this->GetPropertyList()->Get("classifier.svm.nr-weight",this->m_Parameter->nr_weight))
286  str << "classifier.svm.nr-weight\tNOT SET (default " << this->m_Parameter->nr_weight << ")" << "\n";
287  else
288  str << "classifier.svm.nr-weight\t" << this->m_Parameter->nr_weight << "\n";
289 }
290 
291 // Trying to assign from matrix to noOfPoints
292 void mitk::LibSVMClassifier::ReadXValues(LibSVM::svm_problem * problem, LibSVM::svm_node** xSpace, const Eigen::MatrixXd &X)
293 {
294  int noOfPoints = static_cast<int>(X.rows());
295  int features = static_cast<int>(X.cols());
296 
297  problem->x = static_cast<LibSVM::svm_node **>(malloc(sizeof(LibSVM::svm_node *) * noOfPoints));
298  (*xSpace) = static_cast<LibSVM::svm_node *> (malloc(sizeof(LibSVM::svm_node) * noOfPoints * (features+1)));
299 
300  for (int row = 0; row < noOfPoints; ++row)
301  {
302  for (int col = 0; col < features; ++col)
303  {
304  (*xSpace)[row*features + col].index = col;
305  (*xSpace)[row*features + col].value = X(row,col);
306  }
307  (*xSpace)[row*features + features].index = -1;
308 
309  problem->x[row] = &((*xSpace)[row*features]);
310  }
311 }
312 
313 void mitk::LibSVMClassifier::ReadYValues(LibSVM::svm_problem * problem, const Eigen::MatrixXi &Y)
314 {
315  problem->l = static_cast<int>(Y.rows());
316  problem->y = static_cast<double *>(malloc(sizeof(double) * problem->l));
317 
318  for (int i = 0; i < problem->l; ++i)
319  {
320  problem->y[i] = Y(i,0);
321  }
322 }
323 
324 void mitk::LibSVMClassifier::ReadWValues(LibSVM::svm_problem * problem)
325 {
326  Eigen::MatrixXd & W = this->GetPointWiseWeight();
327  int noOfPoints = problem->l;
328  problem->W = static_cast<double *>(malloc(sizeof(double) * noOfPoints));
329 
330  if (IsUsingPointWiseWeight())
331  {
332  for (int i = 0; i < noOfPoints; ++i)
333  {
334  problem->W[i] = W(i,0);
335  }
336  } else
337  {
338  for (int i = 0; i < noOfPoints; ++i)
339  {
340  problem->W[i] = 1;
341  }
342  }
343 }
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override
void svm_destroy_param(struct svm_parameter *param)
Definition: svm.cpp:3166
const char * svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param)
#define MITK_WARN
Definition: mitkLogMacros.h:23
#define mitkThrow()
void PrintParameter(std::ostream &str)
const char features[]
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override
struct svm_model * svm_train(const struct svm_problem *prob, const struct svm_parameter *param)
double svm_predict(const struct svm_model *model, const struct svm_node *x)
void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr)
Definition: svm.cpp:3156