LINE Solver
MATLAB API documentation
Loading...
Searching...
No Matches
estimator_rnn.m
function estVal = estimator_rnn(self, nodes)
% ESTIMATOR_RNN Estimate service demands with an explainable RNN.
%
%   estVal = estimator_rnn(self, nodes)
%
%   Explainable RNN technique to learn queue structure and parameters
%   with the aim of service demand estimation.
%
%   Delegates training to a PyTorch-based RNN via a Python bridge script,
%   removing the dependency on the MATLAB Deep Learning Toolbox.
%
%   Inputs:
%     self  - estimator object; self.model is the network whose per-node,
%             per-class queue-length samples are read via self.getQLen.
%     nodes - accepted for interface compatibility but ignored: all nodes
%             returned by self.model.getNodes() are used.
%
%   Output:
%     estVal - row vector of demand estimates returned by the RNN bridge.
%
%   Errors if queue-length data is missing for any (node, class) pair, if
%   some trace set lacks data for a (node, class) pair, or if no samples
%   are available at all.
%
% Adapted from:
% Garbi, G et al. (2020). Learning Queueing Networks by Recurrent Neural Networks
% Copyright (c) 2012-2026, Imperial College London
% All rights reserved.
% This code is released under the 3-Clause BSD License.

sn = self.model.getStruct;
numClasses = sn.nclasses;

% Use whole model nodes for RNN estimation (overrides the 'nodes' argument).
nodes = self.model.getNodes();
% numel is orientation-agnostic: works for row or column cell arrays.
numNodes = numel(nodes);
numServers = zeros(1, numNodes);

% qLenTs{d}{n, r} / qLenTrace{d}{n, r}: timestamps / queue lengths of the
% d-th sample set for node n, class r.
qLenTs = {};
qLenTrace = {};
% Shortest sample sequence over all sets, nodes and classes; every sequence
% is truncated to this length so that all traces can be stacked.
minSampleCount = inf;
for n = 1:numNodes
    node = nodes{n};
    numServers(n) = node.getNumberOfServers();
    for r = 1:numClasses
        samples = self.getQLen(node, self.model.classes{r});
        if isempty(samples)
            error('Queue-length data for node %d in class %d is missing.', self.model.getNodeIndex(node), r);
        end
        for d = 1:length(samples)
            qLen = samples{d};
            if length(qLenTs) < d
                qLenTs{d} = {};
                qLenTrace{d} = {};
            end
            qLenTs{d}{n, r} = qLen.t;
            qLenTrace{d}{n, r} = qLen.data;
            minSampleCount = min(minSampleCount, size(qLen.data, 1));
        end
    end
end

numTraces = length(qLenTs);
if numTraces == 0 || isinf(minSampleCount)
    error('No queue-length samples available for RNN estimation.');
end

% Build the 4D traces array: (traceCount x S x M x 2R), S = minSampleCount:
%   traces(d, :, n, r)              - timestamps of node n, class r
%   traces(d, :, n, numClasses + r) - queue lengths of node n, class r
% Each sequence is truncated BEFORE stacking, so sequences of different
% lengths across nodes cannot be mixed up.
% NOTE(review): timestamps are stored once per class (R slices); confirm
% that rnn_bridge.py expects 2R trailing slices.
traces = zeros(numTraces, minSampleCount, numNodes, 2 * numClasses);
for d = 1:numTraces
    for n = 1:numNodes
        for r = 1:numClasses
            if n > size(qLenTrace{d}, 1) || r > size(qLenTrace{d}, 2) || isempty(qLenTrace{d}{n, r})
                error('Queue-length trace %d for node %d in class %d is missing.', d, n, r);
            end
            ts = qLenTs{d}{n, r};
            ql = qLenTrace{d}{n, r};
            traces(d, :, n, r) = ts(1:minSampleCount);
            traces(d, :, n, numClasses + r) = ql(1:minSampleCount);
        end
    end
end

estVal = rnn_data(traces, numServers);
end
62
% Call Python PyTorch RNN via bridge script
function demEst = rnn_data(avgQL, numServers)
% RNN_DATA Run the PyTorch RNN bridge script and return demand estimates.
%
%   demEst = rnn_data(avgQL, numServers)
%
%   Saves avgQL (as variable 'traces') and numServers to a temporary
%   MAT-file, runs python/line_inference/api/rnn_bridge.py (resolved four
%   directories above this file) with python3, and loads variable
%   'demandEst' from the output MAT-file the script writes.
%
%   Inputs:
%     avgQL      - trace array built by the caller.
%     numServers - per-node server counts.
%
%   Output:
%     demEst - 'demandEst' from the bridge output, as a row vector.
%
%   Errors if the bridge exits with a non-zero status, or if its output
%   file or 'demandEst' variable is missing. Temporary files are removed
%   and PYTHONPATH is restored on every exit path.
%
%   NOTE(review): on Windows the interpreter may be named 'python'/'py'
%   rather than 'python3' — confirm for the supported platforms.

% Locate the Python bridge script relative to this file
thisDir = fileparts(mfilename('fullpath'));
repoRoot = fullfile(thisDir, '..', '..', '..', '..');
bridgeScript = fullfile(repoRoot, 'python', 'line_inference', 'api', 'rnn_bridge.py');
pythonPath = fullfile(repoRoot, 'python');

% Temporary files for data exchange; deleted whenever this function exits,
% including when save/system/load throws.
inputFile = [tempname, '.mat'];
outputFile = [tempname, '.mat'];
cleanupFiles = onCleanup(@() rnn_delete_files({inputFile, outputFile})); %#ok<NASGU>

% Save input data
traces = avgQL; %#ok<NASGU>
save(inputFile, 'traces', 'numServers', '-v7');

% Add the Python package to PYTHONPATH through the MATLAB process
% environment (inherited by system()). Unlike an inline "VAR=value cmd"
% prefix this also works under Windows cmd; pathsep picks ':' or ';'.
oldPythonPath = getenv('PYTHONPATH');
restoreEnv = onCleanup(@() setenv('PYTHONPATH', oldPythonPath)); %#ok<NASGU>
if isempty(oldPythonPath)
    % No trailing separator: an empty entry would add the current working
    % directory to Python's module search path.
    setenv('PYTHONPATH', pythonPath);
else
    setenv('PYTHONPATH', [pythonPath, pathsep, oldPythonPath]);
end

% Quote every path so directories containing spaces survive the shell.
pythonCmd = sprintf('python3 %s %s %s', rnn_shell_quote(bridgeScript), ...
    rnn_shell_quote(inputFile), rnn_shell_quote(outputFile));
[status, cmdout] = system(pythonCmd);

if status ~= 0
    error('RNN Python bridge failed (exit code %d):\n%s', status, cmdout);
end
if ~exist(outputFile, 'file')
    error('RNN Python bridge produced no output file:\n%s', cmdout);
end

% Load results
result = load(outputFile);
if ~isfield(result, 'demandEst')
    error('RNN Python bridge output has no ''demandEst'' variable.');
end
demEst = result.demandEst(:)';
end

function q = rnn_shell_quote(p)
% RNN_SHELL_QUOTE Quote a path for the shell used by system() on this platform.
if ispc
    % cmd.exe: wrap in double quotes (Windows paths cannot contain '"').
    q = ['"', p, '"'];
else
    % POSIX sh: single quotes disable all expansion; embed ' as '\''.
    q = ['''', strrep(p, '''', '''\'''''), ''''];
end
end

function rnn_delete_files(files)
% RNN_DELETE_FILES Delete each file in the cell array that currently exists.
for k = 1:numel(files)
    if exist(files{k}, 'file')
        delete(files{k});
    end
end
end
Definition mmt.m:124