LINE Solver
MATLAB API documentation
Loading...
Searching...
No Matches
maph2m_fit_multiclass.m
function [maph,fB] = maph2m_fit_multiclass(aph,p,B,classWeights)
% Performs approximate fitting of a MAPH given the underlying APH(2),
% the class probabilities (always fitted exactly) and the backward moments.
% Input
% - aph: second-order APH underlying the MAPH[m], given as {D0, D1} in
%        canonical acyclic form (D0(2,1) == 0, D1(1,2) == D1(2,2) == 0)
% - p: vector of class probabilities (length k = number of classes)
% - B: vector of target backward moments, one per class
% - classWeights: optional vector of weights for each class (default: all
%                 ones); weights the squared relative error of each class's
%                 backward moment in the fitting objective
% Output
% - maph: fitted MAPH[m], cell array {D0, D1, D11, ..., D1k}
% - fB: vector of optimal feasible backward moments (length k). In the
%       degenerate case every class gets the same value h1 + r1*h2, the
%       value for which the fitted class probabilities equal p in both phases.

if size(aph{1},1) ~= 2
    error('Underlying APH must be of second-order.');
end
if aph{1}(2,1) ~= 0
    error('Underlying APH must be acyclic');
end
if aph{2}(1,2) ~= 0 || aph{2}(2,2) ~= 0
    error('Underlying APH must be in canonical acyclic form');
end

fprintf('Fitting MAPH(2,m)\n');

% number of classes
k = length(p);

% default weights (nargin < 4 covers every call that omits the argument)
if nargin < 4 || isempty(classWeights)
    classWeights = ones(k,1);
end

% result: D0 and D1 are taken unchanged from the underlying APH
maph = cell(1,2+k);
maph{1} = aph{1};
maph{2} = aph{2};

% parameters of the underlying APH(2):
%   h1, h2 - mean holding times in phases 1 and 2 (-1 / diagonal of D0)
%   r1     - probability that phase 1 moves on to phase 2
h1 = -1/aph{1}(1,1);
h2 = -1/aph{1}(2,2);
r1 = aph{1}(1,2) * h1;

% tolerance constants
degentol = 1e-6;   % |1 - r1| below this is treated as the degenerate form
feastol = 1e-8;    % slack allowed when checking q for feasibility

if abs(1-r1) < degentol

    % DEGENERATE
    fprintf('Fitting MAPH(2,m): detected degenerate form\n');

    % only one degree of freedom: match class probabilities in both phases
    q = [p(:)'; p(:)'];

    % With q(1,c) = q(2,c) = p(c), the FULL-case relations
    % q(j,c) = F(c)*q_b(j,c) + q_0(j,c) are satisfied by F(c) = h1 + r1*h2
    % for every class. Assign it so that the second output is always defined.
    fB = (h1 + r1*h2) * ones(k,1);

else

    % FULL

    % coefficients: q(j,c) = F(c) * q_b(j,c) + q_0(j,c)
    pr = p(:)';
    q_b = [pr * ( 1/(h2*(r1 - 1)) ); ...
           pr * ( 1/(h2*r1) )];
    q_0 = [pr * ( -(h1 + h2)/(h2*(r1 - 1)) ); ...
           pr * ( -h1/(h2*r1) )];

    % inequality constraints A*F <= b encoding 0 <= q(j,c) <= 1
    % for both phases j and every class c
    A = zeros(4*k,k);
    b = zeros(4*k,1);
    for c = 1:k
        for j = 1:2
            row = (c-1)*4 + (j-1)*2;
            % q_j <= 1
            A(row+1,c) = q_b(j,c);
            b(row+1) = 1 - q_0(j,c);
            % q_j >= 0
            A(row+2,c) = -q_b(j,c);
            b(row+2) = q_0(j,c);
        end
    end

    % equality constraints: sum_c q(j,c) = 1 for each phase j
    Aeq = q_b;
    beq = 1 - sum(q_0, 2);

    % objective: sum_c w(c)*(F(c)/B(c) - 1)^2 = 0.5*F'*H*F + h'*F + sum(w)
    w = reshape(classWeights(1:k), [], 1);
    Bk = reshape(B(1:k), [], 1);
    H = diag(2 ./ Bk.^2 .* w);
    h = -2 ./ Bk .* w;

    % solve optimization problem
    fB = solve();

    for c = 1:k
        fprintf('Fitting MAPH(2,m): B(%d) = %f -> %f\n', c, B(c), fB(c));
    end

    % compute the per-phase class probabilities of D11,D12,...,D1k
    q = repmat(fB(:)', 2, 1) .* q_b + q_0;

end

% check parameter feasibility (each row of q is one phase's class distribution)
if ~(isfeasible(q(1,:)) && isfeasible(q(2,:)))
    error('Fitting MAPH(2,m): Feasibility could not be restored');
end
% parameters feasible within feastol: clip to [0,1] and renormalize
q(1,:) = clipnormalize(q(1,:));
q(2,:) = clipnormalize(q(2,:));

% compute D11,D12,...,D1k by splitting the absorbing column of D1 by class
for c = 1:k
    maph{2+c} = maph{2} .* [q(1,c) 0; q(2,c) 0];
end

    function feas = isfeasible(qj)
        % true if qj is a sub-distribution up to feastol: no entry below
        % -feastol and total mass at most 1 + feastol
        feas = min(qj) >= -feastol && sum(qj) <= (1+feastol);
    end

    function QJ = clipnormalize(qj)
        % clip (tolerated) negative entries to zero and rescale to sum 1;
        % returns a 1-by-k row vector
        QJ = max(reshape(qj(1:k), 1, []), 0);
        total = sum(QJ);
        if total <= 0
            % all entries clipped away: normalizing would produce NaN
            error('Fitting MAPH(2,m): Feasibility could not be restored');
        end
        QJ = QJ ./ total;
    end

    function x = solve()
        % solve the QP over the backward moments F using the parent's
        % H, h, A, b, Aeq, beq; returns the optimal F
        fprintf('Fitting MAPH(2,m): running quadratic programming solver...\n');
        options = optimset('Algorithm','interior-point-convex','Display','none');
        %[x,fx,xflag] = quadprog(H, h, A, b, Aeq, beq, [], [], [], options);
        lb = 1e-6*ones(size(A,2),1);
        ub = 1e6*ones(size(A,2),1);
        [x,fx,xflag] = QP(H, h, A, b, Aeq, beq, lb, ub, options);
        if xflag ~= 1
            % report the solver's exit flag
            error('Quadratic programming solver failed: %d\n', xflag);
        end
        % NOTE(review): assumes QP returns fx = 0.5*x'*H*x + h'*x like
        % quadprog; adding back the dropped constant sum(w) gives the
        % weighted squared relative error sum_c w(c)*(x(c)/B(c) - 1)^2
        fit_error = fx + sum(w);
        fprintf('Fitting MAPH(2,m): error = %f\n', fit_error);
    end

end % end function