LINE Solver
MATLAB API documentation
Loading...
Searching...
No Matches
m3pp22_fitc_approx_cov_multiclass.m
function [FIT] = m3pp22_fitc_approx_cov_multiclass(mmpp, ai, st3, t3)
% Fits a M3PP(2,2) given the underlying MMPP(2).
%
% The per-class marking probabilities (q1 in state 1, q2 in state 2) are
% chosen so that the class-1 rate equals ai(1) and the count covariance
% between the two classes at time scale t3 is as close as possible to st3
% while keeping q1 and q2 inside [0,1].
%
% INPUT
% - mmpp: underlying mmpp(2), cell array {D0, D1}
% - ai:   rates of the two classes; the number of classes is taken as the
%         number of ROWS of ai, so a column vector is expected
% - st3:  count covariance between the two classes at scale t3
% - t3:   third time scale
% OUTPUT
% - FIT:  cell array {D0, D1, D1_class1, ..., D1_classm}
%
% ERRORS
% - more than two classes requested
% - both roots of the covariance equation are infeasible
% - the computed q1 or q2 lies outside [0,1] by more than the tolerance
%
% NOTE(review): the formulas divide by w2, l2*r1 and l1*r2; inputs making
% any of these zero are not handled explicitly here.

% number of classes
m = size(ai,1);

if m > 2
    error('No more than two classes supported');
end

FIT = mmpp;

% degenerate case: poisson process
if size(FIT{1},1) == 1
    D0 = FIT{1};
    D1 = FIT{2};
    FIT = cell(1,2+m);
    FIT{1} = D0;
    FIT{2} = D1;
    % split the arrival matrix proportionally to the class rates
    % (named "frac" to avoid shadowing the built-in constant pi)
    a = sum(ai);
    frac = ai/a;
    for i = 1:m
        FIT{2+i} = frac(i)*D1;
    end
    return;
end

% just a single class!
if m == 1
    FIT = {FIT{1},FIT{2},FIT{2}};
    return;
end

% parameters of the underlying MMPP(2): arrival rates on the diagonal of
% D1, switching rates on the off-diagonal of D0
l1 = FIT{2}(1,1);
l2 = FIT{2}(2,2);
r1 = FIT{1}(1,2);
r2 = FIT{1}(2,1);
t = t3;
a1 = ai(1);

% The code below inverts sigma = w2*q2^2 + w1*q2 + w0 for q2; q1 then
% follows from q2 through the class-1 rate constraint.
w0 = (2*r1*(1 - exp(-(r1+r2)*t)- (r1+r2)*t)*(a1^2*(r1+r2)-a1*r2*(l1-l2)))/(r2*(r1+r2)^3);
w1 = -(2*r1*(1 - exp(-(r1+r2)*t)- (r1+r2)*t)*(2*a1*l2*(r1+r2)-l2*r2*(l1-l2)))/(r2*(r1 + r2)^3);
w2 = ((2*l2^2*r2*t)*(r1+r2) + 2*l2^2*r1*(1 - exp(-(r1+r2)*t)))/(r2*(r1 + r2)^2) - (2*l2^2*t)/r2;
w3 = (r1+r2)/(l2*r1);
w4 = (l1*r2)/(l2*r1);

% bounds for the first and the second root
L1 = -inf; L2 = -inf;
U1 = +inf; U2 = +inf;

% if set to 1, the first (second) root is never feasible
infeasible1 = 0;
infeasible2 = 0;

% defined for convenience
z = w0 - w1^2/(4*w2);

% impose square root argument is >= 0
if w2 > 0
    L1 = max(L1, z);
    L2 = max(L2, z);
elseif w2 < 0
    U1 = min(U1, z);
    U2 = min(U2, z);
end

% impose q2 >= 0
% first root
if w1 >= 0
    L1 = max(L1, w0);
elseif w2 < 0
    infeasible1 = 1;
end
% second root
if w1 <= 0
    U2 = min(U2, w0);
elseif w2 > 0
    infeasible2 = 1;
end

% impose q1 >= 0
tmp = 2*a1*w3*w2 + w1;
% first root
if tmp >= 0
    U1 = min(U1, z + tmp^2/(4*w2));
elseif w2 > 0
    infeasible1 = 1;
end
% second root
if tmp <= 0
    L2 = max(L2, z + tmp^2/(4*w2));
elseif w2 < 0
    infeasible2 = 1;
end

% impose q2 <= 1
tmp = 2*w2+w1;
% first root
if tmp >= 0
    U1 = min(U1, z + tmp^2/(4*w2));
elseif w2 > 0
    infeasible1 = 1;
end
% second root
if tmp <= 0
    L2 = max(L2, z + tmp^2/(4*w2));
elseif w2 < 0
    infeasible2 = 1;
end

% impose q1 <= 1
tmp = (2*a1*w2*w3-2*w2*w4+w1);
% first root
if tmp >= 0
    L1 = max(L1, z + tmp^2/(4*w2));
elseif w2 < 0
    infeasible1 = 1;
end
% second root
if tmp <= 0
    U2 = min(U2, z + tmp^2/(4*w2));
elseif w2 > 0
    infeasible2 = 1;
end

if infeasible1 && infeasible2
    error('Empty feasibility region. This should not happen.');
end

% compute feasible covariance: clamp the requested value into the interval
% of the usable root; when both roots are usable keep the one whose clamped
% value is strictly closer to st3 (ties go to the second root)
if infeasible2
    sigma = max(min(st3,U1), L1);
    root = 1;
elseif infeasible1
    sigma = max(min(st3,U2), L2);
    root = 2;
else
    sigma1 = max(min(st3,U1), L1);
    sigma2 = max(min(st3,U2), L2);
    if abs(sigma1-st3) < abs(sigma2-st3)
        sigma = sigma1;
        root = 1;
    else
        sigma = sigma2;
        root = 2;
    end
end

% compute parameters
if root == 1
    q2 = (-w1 + sqrt(w1^2 - 4*w2*(w0-sigma)) )/(2*w2);
else
    q2 = (-w1 - sqrt(w1^2 - 4*w2*(w0-sigma)) )/(2*w2);
end
q1 = ( a1*(r1 + r2) - l2*q2*r1 )/(l1*r2);

% check feasibility just to be sure
% The lower bound must be -tol, not +tol: a solution sitting on the q = 0
% boundary (which happens whenever sigma was clamped onto that bound) is
% valid and may come out as 0 or a tiny negative number due to rounding.
% Values within tol of [0,1] are accepted and then clamped into [0,1].
tol = 1e-8;
if (q1 >= -tol && q1 <= 1+tol) && (q2 >= -tol && q2 <= 1+tol)
    q1 = min(max(q1,0),1);
    q2 = min(max(q2,0),1);
else
    error('Parameters are infeasible. This should not happen.');
end

% assemble M3PP[2]: class 1 takes fraction q_i of the arrivals in state i,
% class 2 takes the remaining 1-q_i
D0 = FIT{1};
D1 = FIT{2};
FIT = {D0, D1, D1 .* [q1 0; 0 q2], D1 .* [(1-q1) 0; 0 (1-q2)]};

% print per-class rates and covariance
% fai = mmap_count_mean(FIT,1);
% for i = 1:m
%     fprintf('Rate class %d: input = %.3f, output = %.3f\n', ...
%         i, ai(i), fai(i));
% end
% fsigma = 1/2*( map_count_var(FIT,t3) - sum(mmap_count_var(FIT,t3)) );
% fprintf('Covariance(t3): input = %.4f, output = %.4f\n', st3, fsigma);

end