LINE Solver
MATLAB API documentation
Loading...
Searching...
No Matches
mamap2m_fit_fb_multiclass.m
function [mmap,fF,fB] = mamap2m_fit_fb_multiclass(map,p,F,B,classWeights,fbWeights)
% Performs approximate fitting of a MMAP given the underlying MAP,
% the class probabilities (always fitted exactly), the forward moments,
% and the backward moments.
% Input
% - map: second-order AMAP underlying the MAMAP[m]; it must satisfy
%        map{1}(2,1) == 0 and be in one of two canonical forms:
%        form 1 if map{2}(1,2) == 0, form 2 if map{2}(1,1) == 0
% - p: vector of class probabilities
% - F: vector of forward moments
% - B: vector of backward moments
% - classWeights: optional vector of weights for each class
%                 (default: ones; passing [] also selects the default)
% - fbWeights: optional 2-vector of weights of forward and backward moments
%              (default: [1;1]; passing [] also selects the default)
% Output
% - mmap: fitted MAMAP[m]
% - fF: vector of forward moments of the fitted MAMAP[m]
%       (as returned by mmap_forward_moment)
% - fB: vector of backward moments of the fitted MAMAP[m]
%       (as returned by mmap_backward_moment)
%
% The entries of each class matrix D1c are fractions q(j,c) of the entries
% of D1. These fractions are written as affine functions of the per-class
% moments. A quadratic program then picks the feasible moments closest to
% F and/or B in weighted relative squared error. Degenerate underlying MAPs
% (Poisson and phase-type cases) are detected and handled separately.

if size(map{1},1) ~= 2
    error('Underlying MAP must be of second-order.');
end
if map{1}(2,1) ~= 0
    error('Underlying MAP must be acyclic');
end
if map{2}(1,2) == 0
    form = 1;
elseif map{2}(1,1) == 0
    form = 2;
else
    error('Underlying MAP must be in canonical acyclic form');
end

% number of classes
k = length(p);

% default weights to use in the objective function
if nargin < 5 || isempty(classWeights)
    classWeights = ones(k,1);
end
if nargin < 6 || isempty(fbWeights)
    % accept [] just like classWeights (fbWeights(1) would fail otherwise)
    fbWeights = ones(2,1);
end

% class probabilities as a row, so coefficients can be built for all
% classes at once (column c of each coefficient matrix belongs to class c)
pr = reshape(p, 1, k);

% result
mmap = cell(1,2+k);
mmap{1} = map{1};
mmap{2} = map{2};

h1 = -1/map{1}(1,1);
h2 = -1/map{1}(2,2);
r1 = map{1}(1,2) * h1;
r2 = map{2}(2,2) * h2;

degentol = 1e-8;

isPoisson = (form == 1 && (r1 < degentol || r2 > 1-degentol || abs(h2-h1*r2) < degentol || abs(h1 - h2 + h2*r1) < degentol )) || ...
            (form == 2 && (r2 > 1-degentol || abs(h1 - h2 + h2*r1) < degentol || abs(h1 - h2 - h1*r1 + h1*r1*r2) < degentol ));

if isPoisson

    % POISSON PROCESS: return a marked Poisson process with the mean
    % returned by map_mean, splitting the rate by class probability
    meanInt = map_mean(mmap);
    mmap = cell(1,2+k);
    mmap{1} = -1/meanInt;
    mmap{2} = 1/meanInt;
    for c = 1:k
        mmap{2+c} = mmap{2} * p(c);
    end

    fF = mmap_forward_moment(mmap,1);
    fB = mmap_backward_moment(mmap,1);
    return;

elseif form == 2 && r2 < degentol && abs(1-r1) < degentol

    % DEGENERATE PHASE-TYPE: every entry of D1 is split by class probability
    q = repmat(pr, 3, 1);

elseif form == 1 && r2 < degentol

    % CANONICAL PHASE-TYPE: convert to a phase-type representation and
    % fit only the backward moments
    aph = map;
    aph{2}(2,2) = 0;
    aph = map_normalize(aph);

    mmap = maph2m_fit_multiclass(aph, p, B, classWeights);

    fF = mmap_forward_moment(mmap, 1);
    fB = mmap_backward_moment(mmap, 1);
    return;

elseif abs(1-r1) < degentol

    % NON-CANONICAL PHASE-TYPE (r1 == 1, either form): fit forward moments
    % coefficients: q(j+1,c) = F(c) * q_f(j,c) + q_0(j,c), j = 1,2
    q_f = zeros(2,k);
    q_0 = zeros(2,k);
    % q2
    q_f(1,:) = pr * ( -1/((h1 + h2*(r1 - 1))*(r2 - 1)*(r1 + r2 - r1*r2)) );
    q_0(1,:) = pr * ( h2/((r2 - 1)*(r1 + r2 - r1*r2)*(h1 - h2 + h2*r1)) );
    % q3
    q_f(2,:) = pr * ( -1/(r2*(h1 + h2*(r1 - 1))*(r1 + r2 - r1*r2)) );
    q_0(2,:) = pr * ( (h1 + h2*r1)/(r2*(r1 + r2 - r1*r2)*(h1 - h2 + h2*r1)) );

    fF = fit_moments_qp({q_f}, q_0, as_row(F, k), as_row(classWeights, k) * fbWeights(1));

    % compute parameters of D11,D12,...,D1k
    q = zeros(3,k);
    % NOTE(review): q(1,:) presumably multiplies a zero entry of D1 when
    % r1 == 1, so its value should not matter -- confirm
    q(1,:) = 1/k;
    for c = 1:k
        q(2:3,c) = fF(c) * q_f(:,c) + q_0(:,c);
    end

elseif form == 2 && r2 < degentol

    % DEGENERATE CASE FOR gamma < 0 (r2 ~ 0): only one moment type is
    % fitted, chosen by the larger weight in fbWeights (ties -> forward)
    % coefficients: q(j,c) = M(c) * q_m(j,c) + q_0(j,c), j = 1,2, where
    % M is F (forward fit) or B (backward fit)
    q_m = zeros(2,k);
    q_0 = zeros(2,k);
    if fbWeights(1) >= fbWeights(2)
        % q1
        q_m(1,:) = pr * ( -(r1 - 2)/((h1 + h2*(r1 - 1))*(r1 - 1)) );
        q_0(1,:) = pr * ( 1 - (h1 + h2)/((r1 - 1)*(h1 - h2 + h2*r1)) );
        % q2
        q_m(2,:) = pr * ( -(r1 - 2)/(h1 + h2*(r1 - 1)) );
        q_0(2,:) = pr * ( (h2*(r1 - 2))/(h1 - h2 + h2*r1) );
        target = F;
        momentWeight = fbWeights(1);
    else
        % q1
        q_m(1,:) = pr * ( -(r1 - 2)/((h2 + h1*(r1 - 1))*(r1 - 1)) );
        q_0(1,:) = pr * ( 1 - (h1 + h2)/((r1 - 1)*(h2 - h1 + h1*r1)) );
        % q2
        q_m(2,:) = pr * ( -(r1 - 2)/(h2 + h1*(r1 - 1)) );
        q_0(2,:) = pr * ( (h1*(r1 - 2))/(h2 - h1 + h1*r1) );
        target = B;
        momentWeight = fbWeights(2);
    end

    x = fit_moments_qp({q_m}, q_0, as_row(target, k), as_row(classWeights, k) * momentWeight);

    % compute parameters of D11,D12,...,D1k
    q = zeros(3,k);
    for c = 1:k
        q(1:2,c) = x(c) * q_m(:,c) + q_0(:,c);
    end
    % NOTE(review): q(3,:) presumably multiplies D1(2,2) ~ 0 here -- confirm
    q(3,:) = 1/k;

else

    % GENERAL CASE: fit forward and backward moments jointly
    % coefficients: q(j,c) = F(c) * q_f(j,c) + B(c) * q_b(j,c) + q_0(j,c)
    q_f = zeros(3,k);
    q_b = zeros(3,k);
    q_0 = zeros(3,k);
    if form == 1
        % first canonical form (positive auto-correlation decay)
        q_b(1,:) = -(pr*(r1*r2 - r2 + 1))/((h2 - h1*r2)*(r1 - 1)*(r2 - 1));
        q_0(1,:) = (pr*(h1 + h2 - h1*r2)*(r1*r2 - r2 + 1))/((h2 - h1*r2)*(r1 - 1)*(r2 - 1));
        q_f(2,:) = -(pr*(r1*r2 - r2 + 1))/(r1*(h1 + h2*(r1 - 1))*(r2 - 1));
        q_b(2,:) = -(pr*(r1*r2 - r2 + 1))/(r1*(h2 - h1*r2)*(r2 - 1));
        q_0(2,:) = (pr*(r1*r2 - r2 + 1))/((r1 - 1)*(r2 - 1)) + (h1*pr*(r1*r2 - r2 + 1))/(r1*(h2 - h1*r2)*(r2 - 1)) - (h1*pr*(r1*r2 - r2 + 1))/(r1*(h1 + h2*(r1 - 1))*(r1 - 1)*(r2 - 1));
        q_f(3,:) = -(pr*(r1*r2 - r2 + 1))/(r1*r2*(h1 - h2 + h2*r1));
        q_0(3,:) = (pr*(h1 + h2*r1)*(r1*r2 - r2 + 1))/(r1*r2*(h1 - h2 + h2*r1));
    else
        % second canonical form (negative auto-correlation decay)
        q_b(1,:) = -(pr*(r1 + r2 - r1*r2 - 2))/((r1 - 1)*(r2 - 1)*(h1 - h2 - h1*r1 + h1*r1*r2));
        q_0(1,:) = (pr*(h2 + h1*r1 - h1*r1*r2)*(r1 + r2 - r1*r2 - 2))/((r1 - 1)*(r2 - 1)*(h1 - h2 - h1*r1 + h1*r1*r2));
        q_f(2,:) = (pr*(r1 + r2 - r1*r2 - 2))/((r2 - 1)*(h1 - h2 + h2*r1));
        q_0(2,:) = -(h2*pr*(r1 + r2 - r1*r2 - 2))/((r2 - 1)*(h1 - h2 + h2*r1));
        q_f(3,:) = (pr*(r1 + r2 - r1*r2 - 2))/(r2*(h1 + h2*(r1 - 1)));
        q_b(3,:) = (pr*(r1 + r2 - r1*r2 - 2))/(r2*(h1 - h2 - h1*r1 + h1*r1*r2));
        q_0(3,:) = (h1*pr*(r1 + r2 - r1*r2 - 2))/(r2*(h1 + h2*(r1 - 1))*(r1 - 1)) - (h1*pr*(r1 + r2 - r1*r2 - 2))/(r2*(h1 - h2 - h1*r1 + h1*r1*r2)) - (pr*(r1 + r2 - r1*r2 - 2))/(r2*(r1 - 1));
    end

    cw = as_row(classWeights, k);
    x = fit_moments_qp({q_f, q_b}, q_0, ...
                       [as_row(F, k); as_row(B, k)], ...
                       [cw * fbWeights(1); cw * fbWeights(2)]);

    % feasible set of moments
    fF = x(1,:);
    fB = x(2,:);

    % compute parameters of D11,D12,...,D1k
    q = zeros(3,k);
    for c = 1:k
        q(:,c) = fF(c) * q_f(:,c) + fB(c) * q_b(:,c) + q_0(:,c);
    end
end

% compute D11,D12,...,D1k: entry-wise fractions of D1, laid out according
% to the canonical form (the zero entry of D1 stays zero)
for c = 1:k
    if form == 1
        mmap{2+c} = mmap{2} .* [q(1,c) 0; q(2,c) q(3,c)];
    else
        mmap{2+c} = mmap{2} .* [0 q(1,c); q(2,c) q(3,c)];
    end
end

fF = mmap_forward_moment(mmap, 1);
fB = mmap_backward_moment(mmap, 1);

end

function x = fit_moments_qp(coef, q0, target, weight)
% Solves the QP that picks feasible per-class moments closest to the targets.
% With q(j,c) = sum_v x(v,c)*coef{v}(j,c) + q0(j,c), the constraints are
% 0 <= q(j,c) <= 1 and sum_c q(j,c) = 1 for every j. The objective matrices
% H, h make 0.5*x'*H*x + h'*x equal to
%   sum_{v,c} weight(v,c) * (x(v,c)/target(v,c) - 1)^2 - sum(weight(:)),
% assuming QP follows the quadprog convention 0.5*x'*H*x + h'*x -- confirm.
% Input
% - coef: cell array of nv coefficient matrices (nq x k), one per moment type
% - q0: nq x k matrix of constant terms
% - target: nv x k matrix of target moments
% - weight: nv x k matrix of objective weights
% Output
% - x: nv x k matrix of optimal feasible moments (column c = class c)
nv = numel(coef);
[nq, k] = size(q0);
n = nv*k;

A = zeros(2*nq*k, n);
b = zeros(2*nq*k, 1);
Aeq = zeros(nq, n);
beq = ones(nq, 1);
H = zeros(n, n);
h = zeros(n, 1);
for c = 1:k
    % decision variables of class c are contiguous: x((c-1)*nv + (1:nv))
    cols = (c-1)*nv + (1:nv);
    for j = 1:nq
        a = zeros(1, nv);
        for v = 1:nv
            a(v) = coef{v}(j,c);
        end
        row = (c-1)*2*nq + (j-1)*2;
        % q_j <= 1
        A(row+1, cols) = a;
        b(row+1) = 1 - q0(j,c);
        % q_j >= 0
        A(row+2, cols) = -a;
        b(row+2) = q0(j,c);
        % sum over classes of q_j equals 1
        Aeq(j, cols) = a;
        beq(j) = beq(j) - q0(j,c);
    end
    for v = 1:nv
        idx = cols(v);
        H(idx,idx) = 2/target(v,c)^2 * weight(v,c);
        h(idx) = -2/target(v,c) * weight(v,c);
    end
end

options = optimset('Algorithm','interior-point-convex','Display','none','MaxIter',3000);
% moments are kept strictly positive and bounded
lb = 1e-6*ones(n,1);
ub = 1e6*ones(n,1);
[xvec,~,xflag] = QP(H, h, A, b, Aeq, beq, lb, ub, options);
if xflag ~= 1
    error('Quadratic programming solver failed: %d\n', xflag);
end
x = reshape(xvec, nv, k);
end

function v = as_row(v, k)
% Returns the first k entries of v as a 1-by-k row vector.
v = reshape(v(1:k), 1, k);
end