LINE Solver
MATLAB API documentation
Loading...
Searching...
No Matches
MultivariateNormal.m
classdef MultivariateNormal < ContinuousDistribution
    % MultivariateNormal Multivariate Normal (Gaussian) Distribution
    %
    % Represents a d-dimensional normal distribution with mean vector mu and
    % covariance matrix Sigma. The distribution can be used standalone or
    % within a Prior for mixture models.
    %
    % Constructor:
    %   mvn = MultivariateNormal(mu, Sigma)
    %
    % Args:
    %   mu    - d-dimensional mean vector (row or column; stored as a column)
    %   Sigma - d x d symmetric positive definite covariance matrix
    %
    % Examples:
    %   % Create 2D normal distribution
    %   mu = [1; 2];
    %   Sigma = [1, 0.5; 0.5, 1];
    %   mvn = MultivariateNormal(mu, Sigma);
    %
    %   % Generate samples
    %   samples = mvn.sample(100);   % 100 x 2 matrix
    %
    %   % Evaluate PDF
    %   pdf_val = mvn.evalPDF([1; 2]);
    %
    %   % Extract a univariate marginal (returned as a 1-D MultivariateNormal)
    %   mvn1 = mvn.getMarginalUniv(1);
    %
    % Copyright (c) 2012-2026, Imperial College London

    properties
        dimension; % Dimensionality d of the distribution
    end

    methods
        function self = MultivariateNormal(mu, Sigma)
            % MULTIVARIATENORMAL Construct a d-dimensional normal distribution.
            %
            % Args:
            %   mu    - mean vector (any orientation; stored as d x 1)
            %   Sigma - d x d covariance matrix
            %
            % Raises (via line_error) if mu is not a vector, Sigma is not a
            % square matrix matching length(mu), Sigma is not symmetric (up to
            % a relative round-off tolerance), or Sigma is not positive definite.

            % Validate inputs
            if ~isvector(mu)
                line_error(mfilename, 'mu must be a vector');
            end
            if ~ismatrix(Sigma) || size(Sigma,1) ~= size(Sigma,2)
                line_error(mfilename, 'Sigma must be square matrix');
            end

            % Force column vector
            mu = mu(:);
            d = length(mu);

            % Check dimension consistency
            if size(Sigma,1) ~= d
                line_error(mfilename, 'Sigma dimensions must match mu length');
            end

            % chol() below only reads the upper triangle, so it would accept a
            % non-symmetric matrix. Reject genuine asymmetry, and remove pure
            % round-off asymmetry so every method (and mvnpdf/mvnrnd, which
            % require symmetry) sees the same valid covariance matrix.
            asymDiff = Sigma - Sigma.';
            asym = max(abs(asymDiff(:)));
            if asym > sqrt(eps) * max(abs(Sigma(:)))
                line_error(mfilename, 'Sigma must be symmetric');
            end
            if asym > 0
                Sigma = (Sigma + Sigma.') / 2;
            end

            % Check positive definite
            [~, p] = chol(Sigma);
            if p ~= 0
                line_error(mfilename, 'Sigma must be positive definite');
            end

            % Call superclass constructor
            self@ContinuousDistribution('MultivariateNormal', 2, [-Inf, Inf]);

            % Store dimension and parameters
            self.dimension = d;
            self.setParam(1, 'mu', mu);
            self.setParam(2, 'Sigma', Sigma);

            % Set mean of first component for Prior compatibility
            self.mean = mu(1);
        end

        % =================== ACCESSORS ===================

        function d = getDimension(self)
            % Get the dimensionality of the distribution
            d = self.dimension;
        end

        function mu = getMeanVector(self)
            % Get the mean vector (d x 1)
            mu = self.getParam(1).paramValue;
        end

        function Sigma = getCovariance(self)
            % Get the covariance matrix (d x d)
            Sigma = self.getParam(2).paramValue;
        end

        function R = getCorrelation(self)
            % GETCORRELATION Get the d x d correlation matrix.
            %
            % R(i,j) = Sigma(i,j) / (s_i * s_j) with s_i = sqrt(Sigma(i,i)).
            % Entries involving a component whose standard deviation is not
            % above eps fall back to the identity value double(i == j).
            Sigma = self.getCovariance();
            d = self.dimension;

            s = sqrt(diag(Sigma));                   % per-component std devs
            valid = bsxfun(@and, s > eps, (s > eps).');
            S = s * s.';                             % S(i,j) = s_i * s_j

            R = eye(d);                              % fallback values
            R(valid) = Sigma(valid) ./ S(valid);
        end

        % =================== DISTRIBUTION METHODS ===================

        function MEAN = getMean(self)
            % Get the mean of the first component (for Prior compatibility)
            mu = self.getMeanVector();
            MEAN = mu(1);
        end

        function VAR = getVar(self)
            % Get the variance of the first component
            Sigma = self.getCovariance();
            VAR = Sigma(1,1);
        end

        function SCV = getSCV(self)
            % Get squared coefficient of variation (not meaningful for multivariate)
            SCV = NaN;
            line_warning(mfilename, 'SCV not defined for multivariate distributions');
        end

        function SKEW = getSkewness(self)
            % Get skewness (multivariate normal is symmetric)
            SKEW = 0;
        end

        % =================== SAMPLING ===================

        function X = sample(self, n)
            % SAMPLE Generate n samples from the multivariate normal distribution.
            %
            % Args:
            %   n - number of samples (default 1)
            %
            % Returns:
            %   X - n x d matrix, one sample per row

            if nargin < 2
                n = 1;
            end

            mu = self.getMeanVector();
            Sigma = self.getCovariance();
            d = self.dimension;

            % Use mvnrnd if available (Statistics Toolbox)
            if exist('mvnrnd', 'file') == 2
                X = mvnrnd(mu', Sigma, n);
            else
                % X = mu + L*Z where Sigma = L*L' and Z ~ N(0, I)
                L = chol(Sigma, 'lower');
                Z = randn(d, n);
                X = bsxfun(@plus, mu, L*Z)';
            end
        end

        % =================== PDF EVALUATION ===================

        function p = evalPDF(self, x)
            % EVALPDF Evaluate the multivariate normal PDF.
            %
            % Args:
            %   x - point(s) at which to evaluate the density, given as
            %         * a single point: a length-d row or column vector,
            %         * an n x d matrix with one point per row, or
            %         * a d x n matrix with one point per column.
            %       When d == 1, any vector is read as n scalar points. A
            %       d x d matrix is ambiguous and is read as one point per row.
            %
            % Returns:
            %   p - n x 1 column vector of PDF values

            mu = self.getMeanVector();
            Sigma = self.getCovariance();
            d = self.dimension;

            % Normalize x to n x d (one point per row)
            if isvector(x) && (d == 1 || numel(x) == d)
                if d == 1
                    x = x(:);        % n scalar points
                else
                    x = x(:)';       % a single d-dimensional point
                end
            elseif size(x, 2) ~= d && size(x, 1) == d
                x = x.';             % documented d x n layout
            end
            if size(x, 2) ~= d
                line_error(mfilename, 'x must be a length-d vector, an n x d matrix, or a d x n matrix');
            end

            if exist('mvnpdf', 'file') == 2
                % Use Statistics Toolbox if available
                p = mvnpdf(x, mu', Sigma);
            else
                % f(x) = (2*pi)^(-d/2) |Sigma|^(-1/2) exp(-0.5 (x-mu)' Sigma^-1 (x-mu)),
                % evaluated in log space via Sigma = R'*R (R upper triangular) to
                % avoid forming inv(Sigma) and to keep det(Sigma) from over/underflowing.
                R = chol(Sigma);
                Z = R' \ bsxfun(@minus, x.', mu);        % d x n whitened deviations
                maha = sum(Z.^2, 1).';                   % squared Mahalanobis distances
                logNorm = -0.5 * d * log(2*pi) - sum(log(diag(R)));
                p = exp(logNorm - 0.5 * maha);
            end
        end

        function Ft = evalCDF(self, t)
            % CDF is not well-defined for multivariate distributions
            line_error(mfilename, 'CDF is not defined for multivariate distributions');
        end

        function L = evalLST(self, s)
            % LST is not well-defined for multivariate distributions
            line_error(mfilename, 'LST is not defined for multivariate distributions');
        end

        % =================== MARGINAL DISTRIBUTIONS ===================

        function mvn_marg = getMarginal(self, indices)
            % GETMARGINAL Extract a marginal distribution for a subset of dimensions.
            %
            % Args:
            %   indices - vector of dimension indices to keep (1-based)
            %
            % Returns:
            %   mvn_marg - MultivariateNormal for the marginal

            mu = self.getMeanVector();
            Sigma = self.getCovariance();

            % The marginal of a Gaussian keeps the selected mean entries and
            % the corresponding covariance sub-block
            mu_marg = mu(indices);
            Sigma_marg = Sigma(indices, indices);

            mvn_marg = MultivariateNormal(mu_marg, Sigma_marg);
        end

        function mvn_univ = getMarginalUniv(self, index)
            % GETMARGINALUNIV Extract a univariate marginal distribution.
            %
            % Args:
            %   index - dimension index (1-based)
            %
            % Returns:
            %   mvn_univ - 1-D MultivariateNormal distribution for that dimension
            mvn_univ = self.getMarginal(index);
        end

        % =================== SERIALIZATION ===================

        function s = toString(self)
            % Convert to string representation
            s = sprintf('jline.MultivariateNormal(d=%d)', self.dimension);
        end
    end

    methods (Static)
        function mvn = fitMeanAndCovariance(mu, Sigma)
            % FITMEANANDCOVARIANCE Create a multivariate normal from mean and covariance.
            %
            % Args:
            %   mu    - mean vector
            %   Sigma - covariance matrix
            %
            % Returns:
            %   mvn - MultivariateNormal distribution
            mvn = MultivariateNormal(mu, Sigma);
        end
    end
end