Medical Imaging Interaction Toolkit  2016.11.0
Medical Imaging Interaction Toolkit
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
mitkLibSVMClassifier.cpp
Go to the documentation of this file.
1 /*===================================================================
2 
3 The Medical Imaging Interaction Toolkit (MITK)
4 
5 Copyright (c) German Cancer Research Center,
6 Division of Medical and Biological Informatics.
7 All rights reserved.
8 
9 This software is distributed WITHOUT ANY WARRANTY; without
10 even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 A PARTICULAR PURPOSE.
12 
13 See LICENSE.txt or http://www.mitk.org for details.
14 
15 ===================================================================*/
16 
17 #include <mitkLibSVMClassifier.h>
18 
20 namespace LibSVM
21 {
22 #include "svm.h"
23 }
24 #include <mitkExceptionMacro.h>
25 
27  m_Model(nullptr),m_Parameter(nullptr)
28 {
29  this->m_Parameter = new LibSVM::svm_parameter();
30 }
31 
33 {
34  if (m_Model)
35  {
37  }
38  if( m_Parameter)
39  LibSVM::svm_destroy_param(m_Parameter);
40 }
41 
42 void mitk::LibSVMClassifier::Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y)
43 {
44  this->SetPointWiseWeight(Eigen::MatrixXd(Y.rows(),1));
45  this->UsePointWiseWeight(false);
46 
47  LibSVM::svm_node *xSpace;
48  LibSVM::svm_problem problem;
49 
50  ConvertParameter();
51  ReadYValues(&problem, Y);
52  ReadXValues(&problem, &xSpace,X);
53  ReadWValues(&problem);
54 
55  const char * error_msg = nullptr;
56  error_msg = LibSVM::svm_check_parameter(&problem, m_Parameter);
57  if (error_msg)
58  {
59  LibSVM::svm_destroy_param(m_Parameter);
60  free(problem.y);
61  free(problem.x);
62  free(xSpace);
63  mitkThrow() << "Error: " << error_msg;
64  }
65 
66  m_Model = LibSVM::svm_train(&problem, m_Parameter);
67 
68  // free(problem.y);
69  // free(problem.x);
70  // free(xSpace);
71 }
72 
73 Eigen::MatrixXi mitk::LibSVMClassifier::Predict(const Eigen::MatrixXd &X)
74 {
75  if ( ! m_Model)
76  {
77  mitkThrow() << "No Model is trained. Train or load a model before predicting new values.";
78  }
79  int noOfPoints = static_cast<int>(X.rows());
80  int noOfFeatures = static_cast<int>(X.cols());
81 
82  Eigen::MatrixXi result(noOfPoints,1);
83 
84  LibSVM::svm_node * xVector = static_cast<LibSVM::svm_node *>(malloc(sizeof(LibSVM::svm_node) * (noOfFeatures+1)));
85  for (int point = 0; point < noOfPoints; ++point)
86  {
87  for (int feature = 0; feature < noOfFeatures; ++feature)
88  {
89  xVector[feature].index = feature+1;
90  xVector[feature].value = X(point, feature);
91  }
92  xVector[noOfFeatures].index = -1;
93  result(point,0) = LibSVM::svm_predict(m_Model,xVector);
94  }
95 
96  free(xVector);
97  return result;
98 }
99 
101 {
102  // Get the property // Some defaults
103  if(!this->GetPropertyList()->Get("classifier.svm.svm-type",this->m_Parameter->svm_type)) this->m_Parameter->svm_type = 0;
104  if(!this->GetPropertyList()->Get("classifier.svm.kernel-type",this->m_Parameter->kernel_type)) this->m_Parameter->kernel_type = 2;
105  if(!this->GetPropertyList()->Get("classifier.svm.degree",this->m_Parameter->degree)) this->m_Parameter->degree = 3;
106  if(!this->GetPropertyList()->Get("classifier.svm.gamma",this->m_Parameter->gamma)) this->m_Parameter->gamma = 0; // 1/n_features;
107  if(!this->GetPropertyList()->Get("classifier.svm.coef0",this->m_Parameter->coef0)) this->m_Parameter->coef0 = 0;
108  if(!this->GetPropertyList()->Get("classifier.svm.nu",this->m_Parameter->nu)) this->m_Parameter->nu = 0.5;
109  if(!this->GetPropertyList()->Get("classifier.svm.cache-size",this->m_Parameter->cache_size)) this->m_Parameter->cache_size = 100.0;
110  if(!this->GetPropertyList()->Get("classifier.svm.c",this->m_Parameter->C)) this->m_Parameter->C = 1.0;
111  if(!this->GetPropertyList()->Get("classifier.svm.eps",this->m_Parameter->eps)) this->m_Parameter->eps = 1e-3;
112  if(!this->GetPropertyList()->Get("classifier.svm.p",this->m_Parameter->p)) this->m_Parameter->p = 0.1;
113  if(!this->GetPropertyList()->Get("classifier.svm.shrinking",this->m_Parameter->shrinking)) this->m_Parameter->shrinking = 1;
114  if(!this->GetPropertyList()->Get("classifier.svm.probability",this->m_Parameter->probability)) this->m_Parameter->probability = 0;
115  if(!this->GetPropertyList()->Get("classifier.svm.nr-weight",this->m_Parameter->nr_weight)) this->m_Parameter->nr_weight = 0;
116 
117  //options:
118  //-s svm_type : set type of SVM (default 0)
119  // 0 -- C-SVC
120  // 1 -- nu-SVC
121  // 2 -- one-class SVM
122  // 3 -- epsilon-SVR
123  // 4 -- nu-SVR
124  //-t kernel_type : set type of kernel function (default 2)
125  // 0 -- linear: u'*v
126  // 1 -- polynomial: (gamma*u'*v + coef0)^degree
127  // 2 -- radial basis function: exp(-gamma*|u-v|^2)
128  // 3 -- sigmoid: tanh(gamma*u'*v + coef0)
129  //-d degree : set degree in kernel function (default 3)
130  //-g gamma : set gamma in kernel function (default 1/num_features)
131  //-r coef0 : set coef0 in kernel function (default 0)
132  //-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
133  //-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
134  //-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
135  //-m cachesize : set cache memory size in MB (default 100)
136  //-e epsilon : set tolerance of termination criterion (default 0.001)
137  //-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
138  //-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
139  //-wiweight: set the parameter C of class i to weight*C, for C-SVC (default 1)
140 
141  // this->m_Parameter->weight_label = nullptr;
142  // this->m_Parameter->weight = 1;
143 }
144 
145 /* these are for training only */
146 //int *weight_label; /* for C_SVC */
147 //double* weight; /* for C_SVC */
148 
150 {
151  this->GetPropertyList()->SetIntProperty("classifier.svm.probability",val);
152 }
153 
155 {
156  this->GetPropertyList()->SetIntProperty("classifier.svm.shrinking",val);
157 }
158 
160 {
161  this->GetPropertyList()->SetIntProperty("classifier.svm.nr_weight",val);
162 }
163 
165 {
166  this->GetPropertyList()->SetDoubleProperty("classifier.svm.nu",val);
167 }
168 
170 {
171  this->GetPropertyList()->SetDoubleProperty("classifier.svm.p",val);
172 }
173 
175 {
176  this->GetPropertyList()->SetDoubleProperty("classifier.svm.eps",val);
177 }
178 
180 {
181  this->GetPropertyList()->SetDoubleProperty("classifier.svm.c",val);
182 }
183 
185 {
186  this->GetPropertyList()->SetDoubleProperty("classifier.svm.cache-size",val);
187 }
188 
190 {
191  this->GetPropertyList()->SetIntProperty("classifier.svm.svm-type",val);
192 }
193 
195 {
196  this->GetPropertyList()->SetIntProperty("classifier.svm.kernel-type",val);
197 }
198 
200 {
201  this->GetPropertyList()->SetIntProperty("classifier.svm.degree",val);
202 }
203 
205 {
206  this->GetPropertyList()->SetDoubleProperty("classifier.svm.gamma",val);
207 }
208 
210 {
211  this->GetPropertyList()->SetDoubleProperty("classifier.svm.coef0",val);
212 }
213 
215 {
216  if(this->m_Parameter == nullptr)
217  {
218  MITK_WARN("LibSVMClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!";
219  return;
220  }
221 
222  this->ConvertParameter();
223 
224  // Get the property // Some defaults
225  if(!this->GetPropertyList()->Get("classifier.svm.svm-type",this->m_Parameter->svm_type))
226  str << "classifier.svm.svm-type\tNOT SET (default " << this->m_Parameter->svm_type << ")" << "\n";
227  else
228  str << "classifier.svm.svm-type\t" << this->m_Parameter->svm_type << "\n";
229 
230  if(!this->GetPropertyList()->Get("classifier.svm.kernel-type",this->m_Parameter->kernel_type))
231  str << "classifier.svm.kernel-type\tNOT SET (default " << this->m_Parameter->kernel_type << ")" << "\n";
232  else
233  str << "classifier.svm.kernel-type\t" << this->m_Parameter->kernel_type << "\n";
234 
235  if(!this->GetPropertyList()->Get("classifier.svm.degree",this->m_Parameter->degree))
236  str << "classifier.svm.degree\t\tNOT SET (default " << this->m_Parameter->degree << ")" << "\n";
237  else
238  str << "classifier.svm.degree\t\t" << this->m_Parameter->degree << "\n";
239 
240  if(!this->GetPropertyList()->Get("classifier.svm.gamma",this->m_Parameter->gamma))
241  str << "classifier.svm.gamma\t\tNOT SET (default " << this->m_Parameter->gamma << ")" << "\n";
242  else
243  str << "classifier.svm.gamma\t\t" << this->m_Parameter->gamma << "\n";
244 
245  if(!this->GetPropertyList()->Get("classifier.svm.coef0",this->m_Parameter->coef0))
246  str << "classifier.svm.coef0\t\tNOT SET (default " << this->m_Parameter->coef0 << ")" << "\n";
247  else
248  str << "classifier.svm.coef0\t\t" << this->m_Parameter->coef0 << "\n";
249 
250  if(!this->GetPropertyList()->Get("classifier.svm.nu",this->m_Parameter->nu))
251  str << "classifier.svm.nu\t\tNOT SET (default " << this->m_Parameter->nu << ")" << "\n";
252  else
253  str << "classifier.svm.nu\t\t" << this->m_Parameter->nu << "\n";
254 
255  if(!this->GetPropertyList()->Get("classifier.svm.cache-size",this->m_Parameter->cache_size))
256  str << "classifier.svm.cache-size\tNOT SET (default " << this->m_Parameter->cache_size << ")" << "\n";
257  else
258  str << "classifier.svm.cache-size\t" << this->m_Parameter->cache_size << "\n";
259 
260  if(!this->GetPropertyList()->Get("classifier.svm.c",this->m_Parameter->C))
261  str << "classifier.svm.c\t\tNOT SET (default " << this->m_Parameter->C << ")" << "\n";
262  else
263  str << "classifier.svm.c\t\t" << this->m_Parameter->C << "\n";
264 
265  if(!this->GetPropertyList()->Get("classifier.svm.eps",this->m_Parameter->eps))
266  str << "classifier.svm.eps\t\tNOT SET (default " << this->m_Parameter->eps << ")" << "\n";
267  else
268  str << "classifier.svm.eps\t\t" << this->m_Parameter->eps << "\n";
269 
270  if(!this->GetPropertyList()->Get("classifier.svm.p",this->m_Parameter->p))
271  str << "classifier.svm.p\t\tNOT SET (default " << this->m_Parameter->p << ")" << "\n";
272  else
273  str << "classifier.svm.p\t\t" << this->m_Parameter->p << "\n";
274 
275  if(!this->GetPropertyList()->Get("classifier.svm.shrinking",this->m_Parameter->shrinking))
276  str << "classifier.svm.shrinking\tNOT SET (default " << this->m_Parameter->shrinking << ")" << "\n";
277  else
278  str << "classifier.svm.shrinking\t" << this->m_Parameter->shrinking << "\n";
279 
280  if(!this->GetPropertyList()->Get("classifier.svm.probability",this->m_Parameter->probability))
281  str << "classifier.svm.probability\tNOT SET (default " << this->m_Parameter->probability << ")" << "\n";
282  else
283  str << "classifier.svm.probability\t" << this->m_Parameter->probability << "\n";
284 
285  if(!this->GetPropertyList()->Get("classifier.svm.nr-weight",this->m_Parameter->nr_weight))
286  str << "classifier.svm.nr-weight\tNOT SET (default " << this->m_Parameter->nr_weight << ")" << "\n";
287  else
288  str << "classifier.svm.nr-weight\t" << this->m_Parameter->nr_weight << "\n";
289 }
290 
291 // Trying to assign from matrix to noOfPoints
292 void mitk::LibSVMClassifier::ReadXValues(LibSVM::svm_problem * problem, LibSVM::svm_node** xSpace, const Eigen::MatrixXd &X)
293 {
294  int noOfPoints = static_cast<int>(X.rows());
295  int features = static_cast<int>(X.cols());
296 
297  problem->x = static_cast<LibSVM::svm_node **>(malloc(sizeof(LibSVM::svm_node *) * noOfPoints));
298  (*xSpace) = static_cast<LibSVM::svm_node *> (malloc(sizeof(LibSVM::svm_node) * noOfPoints * (features+1)));
299 
300  for (int row = 0; row < noOfPoints; ++row)
301  {
302  for (int col = 0; col < features; ++col)
303  {
304  (*xSpace)[row*features + col].index = col;
305  (*xSpace)[row*features + col].value = X(row,col);
306  }
307  (*xSpace)[row*features + features].index = -1;
308 
309  problem->x[row] = &((*xSpace)[row*features]);
310  }
311 }
312 
313 void mitk::LibSVMClassifier::ReadYValues(LibSVM::svm_problem * problem, const Eigen::MatrixXi &Y)
314 {
315  problem->l = static_cast<int>(Y.rows());
316  problem->y = static_cast<double *>(malloc(sizeof(double) * problem->l));
317 
318  for (int i = 0; i < problem->l; ++i)
319  {
320  problem->y[i] = Y(i,0);
321  }
322 }
323 
324 void mitk::LibSVMClassifier::ReadWValues(LibSVM::svm_problem * problem)
325 {
326  Eigen::MatrixXd & W = this->GetPointWiseWeight();
327  int noOfPoints = problem->l;
328  problem->W = static_cast<double *>(malloc(sizeof(double) * noOfPoints));
329 
330  if (IsUsingPointWiseWeight())
331  {
332  for (int i = 0; i < noOfPoints; ++i)
333  {
334  problem->W[i] = W(i,0);
335  }
336  } else
337  {
338  for (int i = 0; i < noOfPoints; ++i)
339  {
340  problem->W[i] = 1;
341  }
342  }
343 }
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override
void svm_destroy_param(struct svm_parameter *param)
Definition: svm.cpp:3166
const char * svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param)
#define MITK_WARN
Definition: mitkLogMacros.h:23
#define mitkThrow()
void PrintParameter(std::ostream &str)
const char features[]
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override
struct svm_model * svm_train(const struct svm_problem *prob, const struct svm_parameter *param)
double svm_predict(const struct svm_model *model, const struct svm_node *x)
void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr)
Definition: svm.cpp:3156