22 #include <v3p_netlib.h>
23 #include <vnl/algo/vnl_qr.h>
28 static void _UpdateXMatrix(
const vnl_matrix<double> &xData,
bool addConstant, v3p_netlib_doublereal *x);
29 static void _UpdatePermXMatrix(
const vnl_matrix<double> &xData,
bool addConstant,
const vnl_vector<unsigned int> &permutation, vnl_matrix<double> &x);
31 static void _FinalizeBVector(vnl_vector<double> &b, vnl_vector<unsigned int> &perm,
int cols);
38 int cols = m_B.size();
39 for (
int i = 0; i < cols; ++i)
41 if (!m_AddConstantColumn)
54 vnl_vector<double> mu(x.rows());
55 int cols = m_B.size();
56 for (
unsigned int r = 0 ; r < mu.size(); ++r)
59 for (
int c = 0; c < cols; ++c)
61 if (!m_AddConstantColumn)
62 mu(r) += x(r,c)*m_B(c);
66 mu(r) += x(r,c-1)*m_B(c);
81 vnl_vector<double> mu(x.rows());
82 int cols = m_B.size();
83 for (
unsigned int r = 0 ; r < mu.size(); ++r)
86 for (
int c = 0; c < cols; ++c)
88 if (!m_AddConstantColumn)
89 mu(r) += x(r,c)*m_B(c);
93 mu(r) += x(r,c-1)*m_B(c);
101 m_AddConstantColumn(addConstantColumn)
103 EstimatePermutation(xData);
107 vnl_matrix<double> x;
108 int rows = xData.rows();
109 int cols = m_Permutation.size();
110 vnl_vector<double> mu(rows);
111 vnl_vector<double> eta(rows);
112 vnl_vector<double> weightedY(rows);
113 vnl_matrix<double> weightedX(rows, cols);
114 vnl_vector<double> oldB(cols);
121 double sqrtEps = sqrt(std::numeric_limits<double>::epsilon());
122 double convertCriterion =1e-6;
124 m_B.set_size(m_Permutation.size());
127 while (iter <= iterLimit)
133 for (
int r = 0; r<rows; ++r)
135 double deta = link.
DLink(mu(r));
136 double zBuffer = eta(r) + (yData(r) - mu(r))*deta;
137 double sqrtWeight = 1 / (std::abs(deta) * dist.
SqrtVariance(mu(r)));
139 weightedY(r) = zBuffer * sqrtWeight;
140 for (
int c=0; c<cols; ++c)
142 weightedX(r,c) = x(r,c) * sqrtWeight;
145 vnl_qr<double> qr(weightedX);
146 m_B = qr.solve(weightedY);
148 for (
int r = 0; r < rows; ++r)
153 bool stayInLoop =
false;
154 for(
int c= 0; c < cols; ++c)
156 stayInLoop |= std::abs( m_B(c) - oldB(c)) > convertCriterion *
std::max(sqrtEps, std::abs(oldB(c)));
164 void mitk::GeneralizedLinearModel::EstimatePermutation(
const vnl_matrix<double> &xData)
166 v3p_netlib_integer rows = xData.rows();
167 v3p_netlib_integer cols = xData.cols();
169 if (m_AddConstantColumn)
172 v3p_netlib_doublereal *x =
new v3p_netlib_doublereal[rows* cols];
174 v3p_netlib_doublereal *qraux =
new v3p_netlib_doublereal[cols];
175 v3p_netlib_integer *jpvt =
new v3p_netlib_integer[cols];
176 std::fill_n(jpvt,cols,0);
177 v3p_netlib_doublereal *work =
new v3p_netlib_doublereal[cols];
178 std::fill_n(work,cols,0);
179 v3p_netlib_integer job = 16;
183 v3p_netlib_dqrdc_(x, &rows, &rows, &cols, qraux, jpvt, work, &job);
185 double limit = std::abs(x[0]) *
std::max(cols, rows) * std::numeric_limits<double>::epsilon();
188 for (
int i = 0; i <cols; ++i)
190 m_Rank += (std::abs(x[i*rows + i]) > limit) ? 1 : 0;
193 m_Permutation.set_size(m_Rank);
194 for (
int i = 0; i < m_Rank; ++i)
196 m_Permutation(i) = jpvt[i]-1;
207 static void _UpdateXMatrix(
const vnl_matrix<double> &xData,
bool addConstant, v3p_netlib_doublereal *x)
209 v3p_netlib_integer rows = xData.rows();
210 v3p_netlib_integer cols = xData.cols();
214 for (
int r=0; r < rows; ++r)
216 for (
int c=0; c <cols; ++c)
220 x[c*rows + r] = xData(r,c);
226 x[c*rows + r] = xData(r, c-1);
234 static void _UpdatePermXMatrix(
const vnl_matrix<double> &xData,
bool addConstant,
const vnl_vector<unsigned int> &permutation, vnl_matrix<double> &x)
236 int rows = xData.rows();
237 int cols = permutation.size();
238 x.set_size(rows, cols);
239 for (
int r=0; r < rows; ++r)
241 for (
int c=0; c<cols; ++c)
243 unsigned int newCol = permutation(c);
246 x(r, c) = xData(r,newCol);
247 }
else if (newCol == 0)
252 x(r, c) = xData(r, newCol-1);
262 int rows = yData.size();
265 for (
int r = 0; r < rows; ++r)
267 mu(r) = dist->
Init(yData(r));
268 eta(r) = link->
Link(mu(r));
274 static void _FinalizeBVector(vnl_vector<double> &b, vnl_vector<unsigned int> &perm,
int cols)
276 vnl_vector<double> tempB(cols+1);
278 for (
unsigned int c = 0; c < perm.size(); ++c)
280 tempB(perm(c)) = b(c);
static void _UpdatePermXMatrix(const vnl_matrix< double > &xData, bool addConstant, const vnl_vector< unsigned int > &permutation, vnl_matrix< double > &x)
double Predict(const vnl_vector< double > &c)
Predicts the value corresponding to the given vector.
virtual double SqrtVariance(double mu)
GeneralizedLinearModel(const vnl_matrix< double > &xData, const vnl_vector< double > &yData, bool addConstantColumn=true)
Initialization of the GLM. The parameters needs to be passed at the beginning.
static void _InitMuEta(mitk::DistSimpleBinominal *dist, mitk::LogItLinking *link, const vnl_vector< double > &yData, vnl_vector< double > &mu, vnl_vector< double > &eta)
double InverseLink(double eta)
vnl_vector< double > ExpMu(const vnl_matrix< double > &x)
Estimation of the exponential factor for a given function.
static void _UpdateXMatrix(const vnl_matrix< double > &xData, bool addConstant, v3p_netlib_doublereal *x)
static void _FinalizeBVector(vnl_vector< double > &b, vnl_vector< unsigned int > &perm, int cols)
vnl_vector< double > B()
Returns the b-Vector for the estimation.
virtual double Init(double y)