18 #include <v3p_netlib.h> 19 #include <vnl/algo/vnl_qr.h> 24 static void _UpdateXMatrix(
const vnl_matrix<double> &xData,
bool addConstant, v3p_netlib_doublereal *x);
25 static void _UpdatePermXMatrix(
const vnl_matrix<double> &xData,
bool addConstant,
const vnl_vector<unsigned int> &permutation, vnl_matrix<double> &x);
27 static void _FinalizeBVector(vnl_vector<double> &b, vnl_vector<unsigned int> &perm,
int cols);
34 int cols = m_B.size();
35 for (
int i = 0; i < cols; ++i)
37 if (!m_AddConstantColumn)
50 vnl_vector<double> mu(x.rows());
51 int cols = m_B.size();
52 for (
unsigned int r = 0 ; r < mu.size(); ++r)
55 for (
int c = 0; c < cols; ++c)
57 if (!m_AddConstantColumn)
58 mu(r) += x(r,c)*m_B(c);
62 mu(r) += x(r,c-1)*m_B(c);
77 vnl_vector<double> mu(x.rows());
78 int cols = m_B.size();
79 for (
unsigned int r = 0 ; r < mu.size(); ++r)
82 for (
int c = 0; c < cols; ++c)
84 if (!m_AddConstantColumn)
85 mu(r) += x(r,c)*m_B(c);
89 mu(r) += x(r,c-1)*m_B(c);
97 m_AddConstantColumn(addConstantColumn)
99 EstimatePermutation(xData);
103 vnl_matrix<double> x;
104 int rows = xData.rows();
105 int cols = m_Permutation.size();
106 vnl_vector<double> mu(rows);
107 vnl_vector<double> eta(rows);
108 vnl_vector<double> weightedY(rows);
109 vnl_matrix<double> weightedX(rows, cols);
110 vnl_vector<double> oldB(cols);
117 double sqrtEps = sqrt(std::numeric_limits<double>::epsilon());
118 double convertCriterion =1e-6;
120 m_B.set_size(m_Permutation.size());
123 while (iter <= iterLimit)
129 for (
int r = 0; r<rows; ++r)
131 double deta = link.
DLink(mu(r));
132 double zBuffer = eta(r) + (yData(r) - mu(r))*deta;
133 double sqrtWeight = 1 / (std::abs(deta) * dist.
SqrtVariance(mu(r)));
135 weightedY(r) = zBuffer * sqrtWeight;
136 for (
int c=0; c<cols; ++c)
138 weightedX(r,c) = x(r,c) * sqrtWeight;
141 vnl_qr<double> qr(weightedX);
142 m_B = qr.solve(weightedY);
144 for (
int r = 0; r < rows; ++r)
149 bool stayInLoop =
false;
150 for(
int c= 0; c < cols; ++c)
152 stayInLoop |= std::abs( m_B(c) - oldB(c)) > convertCriterion *
std::max(sqrtEps, std::abs(oldB(c)));
160 void mitk::GeneralizedLinearModel::EstimatePermutation(
const vnl_matrix<double> &xData)
162 v3p_netlib_integer rows = xData.rows();
163 v3p_netlib_integer cols = xData.cols();
165 if (m_AddConstantColumn)
168 v3p_netlib_doublereal *x =
new v3p_netlib_doublereal[rows* cols];
170 v3p_netlib_doublereal *qraux =
new v3p_netlib_doublereal[cols];
171 v3p_netlib_integer *jpvt =
new v3p_netlib_integer[cols];
172 std::fill_n(jpvt,cols,0);
173 v3p_netlib_doublereal *work =
new v3p_netlib_doublereal[cols];
174 std::fill_n(work,cols,0);
175 v3p_netlib_integer job = 16;
179 v3p_netlib_dqrdc_(x, &rows, &rows, &cols, qraux, jpvt, work, &job);
181 double limit = std::abs(x[0]) *
std::max(cols, rows) * std::numeric_limits<double>::epsilon();
184 for (
int i = 0; i <cols; ++i)
186 m_Rank += (std::abs(x[i*rows + i]) > limit) ? 1 : 0;
189 m_Permutation.set_size(m_Rank);
190 for (
int i = 0; i < m_Rank; ++i)
192 m_Permutation(i) = jpvt[i]-1;
203 static void _UpdateXMatrix(
const vnl_matrix<double> &xData,
bool addConstant, v3p_netlib_doublereal *x)
205 v3p_netlib_integer rows = xData.rows();
206 v3p_netlib_integer cols = xData.cols();
210 for (
int r=0; r < rows; ++r)
212 for (
int c=0; c <cols; ++c)
216 x[c*rows + r] = xData(r,c);
222 x[c*rows + r] = xData(r, c-1);
230 static void _UpdatePermXMatrix(
const vnl_matrix<double> &xData,
bool addConstant,
const vnl_vector<unsigned int> &permutation, vnl_matrix<double> &x)
232 int rows = xData.rows();
233 int cols = permutation.size();
234 x.set_size(rows, cols);
235 for (
int r=0; r < rows; ++r)
237 for (
int c=0; c<cols; ++c)
239 unsigned int newCol = permutation(c);
242 x(r, c) = xData(r,newCol);
243 }
else if (newCol == 0)
248 x(r, c) = xData(r, newCol-1);
258 int rows = yData.size();
261 for (
int r = 0; r < rows; ++r)
263 mu(r) = dist->
Init(yData(r));
264 eta(r) = link->
Link(mu(r));
270 static void _FinalizeBVector(vnl_vector<double> &b, vnl_vector<unsigned int> &perm,
int cols)
272 vnl_vector<double> tempB(cols+1);
274 for (
unsigned int c = 0; c < perm.size(); ++c)
276 tempB(perm(c)) = b(c);
static void _UpdatePermXMatrix(const vnl_matrix< double > &xData, bool addConstant, const vnl_vector< unsigned int > &permutation, vnl_matrix< double > &x)
double Predict(const vnl_vector< double > &c)
Predicts the value corresponding to the given vector.
double InverseLink(double eta) override
GeneralizedLinearModel(const vnl_matrix< double > &xData, const vnl_vector< double > &yData, bool addConstantColumn=true)
Initialization of the GLM. The parameters need to be passed at the beginning.
double Link(double mu) override
static void _InitMuEta(mitk::DistSimpleBinominal *dist, mitk::LogItLinking *link, const vnl_vector< double > &yData, vnl_vector< double > &mu, vnl_vector< double > &eta)
vnl_vector< double > ExpMu(const vnl_matrix< double > &x)
Estimation of the exponential factor for a given function.
static void _UpdateXMatrix(const vnl_matrix< double > &xData, bool addConstant, v3p_netlib_doublereal *x)
static void _FinalizeBVector(vnl_vector< double > &b, vnl_vector< unsigned int > &perm, int cols)
double SqrtVariance(double mu) override
vnl_vector< double > B()
Returns the b-Vector for the estimation.
double DLink(double mu) override
double Init(double y) override