27 m_Model(nullptr),m_Parameter(nullptr)
29 this->m_Parameter =
new LibSVM::svm_parameter();
44 this->SetPointWiseWeight(Eigen::MatrixXd(Y.rows(),1));
45 this->UsePointWiseWeight(
false);
47 LibSVM::svm_node *xSpace;
48 LibSVM::svm_problem problem;
51 ReadYValues(&problem, Y);
52 ReadXValues(&problem, &xSpace,X);
53 ReadWValues(&problem);
55 const char * error_msg =
nullptr;
77 mitkThrow() <<
"No Model is trained. Train or load a model before predicting new values.";
79 int noOfPoints =
static_cast<int>(X.rows());
80 int noOfFeatures =
static_cast<int>(X.cols());
82 Eigen::MatrixXi result(noOfPoints,1);
84 LibSVM::svm_node * xVector =
static_cast<LibSVM::svm_node *
>(malloc(
sizeof(LibSVM::svm_node) * (noOfFeatures+1)));
85 for (
int point = 0; point < noOfPoints; ++point)
87 for (
int feature = 0; feature < noOfFeatures; ++feature)
89 xVector[feature].index = feature+1;
90 xVector[feature].value = X(point, feature);
92 xVector[noOfFeatures].index = -1;
103 if(!this->GetPropertyList()->Get(
"classifier.svm.svm-type",this->m_Parameter->svm_type)) this->m_Parameter->svm_type = 0;
104 if(!this->GetPropertyList()->Get(
"classifier.svm.kernel-type",this->m_Parameter->kernel_type)) this->m_Parameter->kernel_type = 2;
105 if(!this->GetPropertyList()->Get(
"classifier.svm.degree",this->m_Parameter->degree)) this->m_Parameter->degree = 3;
106 if(!this->GetPropertyList()->Get(
"classifier.svm.gamma",this->m_Parameter->gamma)) this->m_Parameter->gamma = 0;
107 if(!this->GetPropertyList()->Get(
"classifier.svm.coef0",this->m_Parameter->coef0)) this->m_Parameter->coef0 = 0;
108 if(!this->GetPropertyList()->Get(
"classifier.svm.nu",this->m_Parameter->nu)) this->m_Parameter->nu = 0.5;
109 if(!this->GetPropertyList()->Get(
"classifier.svm.cache-size",this->m_Parameter->cache_size)) this->m_Parameter->cache_size = 100.0;
110 if(!this->GetPropertyList()->Get(
"classifier.svm.c",this->m_Parameter->C)) this->m_Parameter->C = 1.0;
111 if(!this->GetPropertyList()->Get(
"classifier.svm.eps",this->m_Parameter->eps)) this->m_Parameter->eps = 1e-3;
112 if(!this->GetPropertyList()->Get(
"classifier.svm.p",this->m_Parameter->p)) this->m_Parameter->p = 0.1;
113 if(!this->GetPropertyList()->Get(
"classifier.svm.shrinking",this->m_Parameter->shrinking)) this->m_Parameter->shrinking = 1;
114 if(!this->GetPropertyList()->Get(
"classifier.svm.probability",this->m_Parameter->probability)) this->m_Parameter->probability = 0;
115 if(!this->GetPropertyList()->Get(
"classifier.svm.nr-weight",this->m_Parameter->nr_weight)) this->m_Parameter->nr_weight = 0;
151 this->GetPropertyList()->SetIntProperty(
"classifier.svm.probability",val);
156 this->GetPropertyList()->SetIntProperty(
"classifier.svm.shrinking",val);
161 this->GetPropertyList()->SetIntProperty(
"classifier.svm.nr_weight",val);
166 this->GetPropertyList()->SetDoubleProperty(
"classifier.svm.nu",val);
171 this->GetPropertyList()->SetDoubleProperty(
"classifier.svm.p",val);
176 this->GetPropertyList()->SetDoubleProperty(
"classifier.svm.eps",val);
181 this->GetPropertyList()->SetDoubleProperty(
"classifier.svm.c",val);
186 this->GetPropertyList()->SetDoubleProperty(
"classifier.svm.cache-size",val);
191 this->GetPropertyList()->SetIntProperty(
"classifier.svm.svm-type",val);
196 this->GetPropertyList()->SetIntProperty(
"classifier.svm.kernel-type",val);
201 this->GetPropertyList()->SetIntProperty(
"classifier.svm.degree",val);
206 this->GetPropertyList()->SetDoubleProperty(
"classifier.svm.gamma",val);
211 this->GetPropertyList()->SetDoubleProperty(
"classifier.svm.coef0",val);
216 if(this->m_Parameter ==
nullptr)
218 MITK_WARN(
"LibSVMClassifier") <<
"Parameters are not initialized. Please call ConvertParameter() first!";
222 this->ConvertParameter();
225 if(!this->GetPropertyList()->Get(
"classifier.svm.svm-type",this->m_Parameter->svm_type))
226 str <<
"classifier.svm.svm-type\tNOT SET (default " << this->m_Parameter->svm_type <<
")" <<
"\n";
228 str <<
"classifier.svm.svm-type\t" << this->m_Parameter->svm_type <<
"\n";
230 if(!this->GetPropertyList()->Get(
"classifier.svm.kernel-type",this->m_Parameter->kernel_type))
231 str <<
"classifier.svm.kernel-type\tNOT SET (default " << this->m_Parameter->kernel_type <<
")" <<
"\n";
233 str <<
"classifier.svm.kernel-type\t" << this->m_Parameter->kernel_type <<
"\n";
235 if(!this->GetPropertyList()->Get(
"classifier.svm.degree",this->m_Parameter->degree))
236 str <<
"classifier.svm.degree\t\tNOT SET (default " << this->m_Parameter->degree <<
")" <<
"\n";
238 str <<
"classifier.svm.degree\t\t" << this->m_Parameter->degree <<
"\n";
240 if(!this->GetPropertyList()->Get(
"classifier.svm.gamma",this->m_Parameter->gamma))
241 str <<
"classifier.svm.gamma\t\tNOT SET (default " << this->m_Parameter->gamma <<
")" <<
"\n";
243 str <<
"classifier.svm.gamma\t\t" << this->m_Parameter->gamma <<
"\n";
245 if(!this->GetPropertyList()->Get(
"classifier.svm.coef0",this->m_Parameter->coef0))
246 str <<
"classifier.svm.coef0\t\tNOT SET (default " << this->m_Parameter->coef0 <<
")" <<
"\n";
248 str <<
"classifier.svm.coef0\t\t" << this->m_Parameter->coef0 <<
"\n";
250 if(!this->GetPropertyList()->Get(
"classifier.svm.nu",this->m_Parameter->nu))
251 str <<
"classifier.svm.nu\t\tNOT SET (default " << this->m_Parameter->nu <<
")" <<
"\n";
253 str <<
"classifier.svm.nu\t\t" << this->m_Parameter->nu <<
"\n";
255 if(!this->GetPropertyList()->Get(
"classifier.svm.cache-size",this->m_Parameter->cache_size))
256 str <<
"classifier.svm.cache-size\tNOT SET (default " << this->m_Parameter->cache_size <<
")" <<
"\n";
258 str <<
"classifier.svm.cache-size\t" << this->m_Parameter->cache_size <<
"\n";
260 if(!this->GetPropertyList()->Get(
"classifier.svm.c",this->m_Parameter->C))
261 str <<
"classifier.svm.c\t\tNOT SET (default " << this->m_Parameter->C <<
")" <<
"\n";
263 str <<
"classifier.svm.c\t\t" << this->m_Parameter->C <<
"\n";
265 if(!this->GetPropertyList()->Get(
"classifier.svm.eps",this->m_Parameter->eps))
266 str <<
"classifier.svm.eps\t\tNOT SET (default " << this->m_Parameter->eps <<
")" <<
"\n";
268 str <<
"classifier.svm.eps\t\t" << this->m_Parameter->eps <<
"\n";
270 if(!this->GetPropertyList()->Get(
"classifier.svm.p",this->m_Parameter->p))
271 str <<
"classifier.svm.p\t\tNOT SET (default " << this->m_Parameter->p <<
")" <<
"\n";
273 str <<
"classifier.svm.p\t\t" << this->m_Parameter->p <<
"\n";
275 if(!this->GetPropertyList()->Get(
"classifier.svm.shrinking",this->m_Parameter->shrinking))
276 str <<
"classifier.svm.shrinking\tNOT SET (default " << this->m_Parameter->shrinking <<
")" <<
"\n";
278 str <<
"classifier.svm.shrinking\t" << this->m_Parameter->shrinking <<
"\n";
280 if(!this->GetPropertyList()->Get(
"classifier.svm.probability",this->m_Parameter->probability))
281 str <<
"classifier.svm.probability\tNOT SET (default " << this->m_Parameter->probability <<
")" <<
"\n";
283 str <<
"classifier.svm.probability\t" << this->m_Parameter->probability <<
"\n";
285 if(!this->GetPropertyList()->Get(
"classifier.svm.nr-weight",this->m_Parameter->nr_weight))
286 str <<
"classifier.svm.nr-weight\tNOT SET (default " << this->m_Parameter->nr_weight <<
")" <<
"\n";
288 str <<
"classifier.svm.nr-weight\t" << this->m_Parameter->nr_weight <<
"\n";
292 void mitk::LibSVMClassifier::ReadXValues(LibSVM::svm_problem * problem, LibSVM::svm_node** xSpace,
const Eigen::MatrixXd &X)
294 int noOfPoints =
static_cast<int>(X.rows());
295 int features =
static_cast<int>(X.cols());
297 problem->x =
static_cast<LibSVM::svm_node **
>(malloc(
sizeof(LibSVM::svm_node *) * noOfPoints));
298 (*xSpace) =
static_cast<LibSVM::svm_node *
> (malloc(
sizeof(LibSVM::svm_node) * noOfPoints * (features+1)));
300 for (
int row = 0; row < noOfPoints; ++row)
302 for (
int col = 0; col <
features; ++col)
304 (*xSpace)[row*features + col].index = col;
305 (*xSpace)[row*features + col].value = X(row,col);
307 (*xSpace)[row*features +
features].index = -1;
309 problem->x[row] = &((*xSpace)[row*
features]);
313 void mitk::LibSVMClassifier::ReadYValues(LibSVM::svm_problem * problem,
const Eigen::MatrixXi &Y)
315 problem->l =
static_cast<int>(Y.rows());
316 problem->y =
static_cast<double *
>(malloc(
sizeof(
double) * problem->l));
318 for (
int i = 0; i < problem->l; ++i)
320 problem->y[i] = Y(i,0);
324 void mitk::LibSVMClassifier::ReadWValues(LibSVM::svm_problem * problem)
326 Eigen::MatrixXd & W = this->GetPointWiseWeight();
327 int noOfPoints = problem->l;
328 problem->W =
static_cast<double *
>(malloc(
sizeof(
double) * noOfPoints));
330 if (IsUsingPointWiseWeight())
332 for (
int i = 0; i < noOfPoints; ++i)
334 problem->W[i] = W(i,0);
338 for (
int i = 0; i < noOfPoints; ++i)
void SetProbability(int val)
void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) override
void svm_destroy_param(struct svm_parameter *param)
const char * svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param)
void SetCoef0(double val)
void PrintParameter(std::ostream &str)
void SetNrWeight(int val)
void SetShrinking(int val)
void SetKernelType(int val)
Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) override
struct svm_model * svm_train(const struct svm_problem *prob, const struct svm_parameter *param)
double svm_predict(const struct svm_model *model, const struct svm_node *x)
void SetGamma(double val)
void SetCacheSize(double val)
void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr)