1function estVal = estimator_rnn(self,
nodes)
2 % Explainable RNN technique to learn queue structure and parameters
3 % with the aim of service demand estimation.
5 % Delegates training to a PyTorch-based RNN via a Python bridge script,
6 % removing the dependency on the MATLAB Deep Learning Toolbox.
9 % Garbi, G et al. (2020). Learning Queueing Networks by Recurrent Neural Networks
10 % Copyright (c) 2012-2026, Imperial College London
11 % All rights reserved.
12 % This code
is released under the 3-Clause BSD License.
%
% Summary: collects queue-length samples (via self.getQLen) for every node
% and job class of self.model, stacks them into a trace array together with
% their timestamps, and passes that array plus the per-node server counts to
% the local helper rnn_data, whose result is returned as estVal.
%
% Inputs:
%   self  - estimator object; self.model and self.getQLen are used here.
%   nodes - accepted for interface compatibility but overwritten below with
%           self.model.getNodes(), so the caller's value is ignored.
% Output:
%   estVal - whatever rnn_data returns.
%
% NOTE(review): in this copy of the file `node`, `r`, `qLen`,
% `minSampleCount` and `traces` are used without a visible definition, the
% loops have no visible `end`, and several comments are split so that their
% tail (e.g. "nodes for RNN estimation") is no longer behind a `%`. Lines
% appear to be missing; restore them from version control before relying
% on this function.
14 sn = self.model.getStruct;
16 % Use whole model
nodes for RNN estimation
% The `nodes` input is replaced here, so estimation always uses every node.
17 nodes = self.model.getNodes();
20 % Obtain per-
class metrics
% For every node n: record its server count, then (for each class index r,
% whose loop header is not visible here) fetch the queue-length samples.
22 for n = 1:size(
nodes, 1)
24 numServers(n) = node.getNumberOfServers();
26 samples = self.getQLen(node, self.model.classes{r});
% Raised when no queue-length data is available for this node/class (the
% guarding condition is not visible here; presumably an isempty check).
28 error(
'Queue-length data for node %d in class %d is missing.', self.model.getNodeIndex(node), r);
% Each sample d becomes one trace: its time vector and data vector are
% stored in cell d at position {node n, class r}.
30 for d = 1:length(samples)
38 qLenTs{d}{n, r} = qLen.t;
39 qLenTrace{d}{n, r} = qLen.data;
% Track the smallest number of rows seen in any qLen.data, so that all
% traces can be truncated to a common length below.
41 if size(qLen.data, 1) < minSampleCount
42 minSampleCount = size(qLen.data, 1);
49 % Build 4D traces
array: (traceCount x S x M x R+1)
% NOTE(review): `n` is reused here as the trace index, not a node index.
% For each trace, cell2mat stacks the {node, class} cells and reshape turns
% the result into (rows x nodes x nclasses), assuming every qLen.data /
% qLen.t is a column vector (TODO confirm). Timestamps get the same
% treatment and are concatenated in front along dimension 3. Each entry
% therefore has 2*nclasses slices (timestamps, then queue lengths), so
% traces ends up (traceCount x minSampleCount x nodes x 2*nclasses), not
% "R+1" as stated above. Confirm which layout rnn_bridge.py expects.
% NOTE(review): cell2mat/reshape require all {node, class} series of one
% trace to have equal length; truncation to minSampleCount happens only
% afterwards, across traces.
51 for n = 1:length(qLenTs)
52 traceQL = cell2mat(qLenTrace{n});
53 traceQL = reshape(traceQL, size(traceQL, 1) / size(qLenTrace{n}, 1), size(qLenTrace{n}, 1), sn.nclasses);
54 traceTs = cell2mat(qLenTs{n});
55 traceTs = reshape(traceTs, size(traceTs, 1) / size(qLenTs{n}, 1), size(qLenTs{n}, 1), sn.nclasses);
56 traceQL = cat(3, traceTs, traceQL);
57 traces(end + 1, :, :, :) = traceQL(1:minSampleCount, :, :);
% Hand the trace array and per-node server counts to the Python bridge.
60 estVal = rnn_data(traces, numServers);
% Call Python PyTorch RNN via bridge script
function demEst = rnn_data(avgQL, numServers)
    % RNN_DATA Run the PyTorch RNN estimator through the Python bridge script.
    %
    %   demEst = rnn_data(avgQL, numServers)
    %
    %   avgQL      - trace array to estimate from; written to the exchange
    %                file as the variable 'traces'.
    %   numServers - per-node server counts; written as 'numServers'.
    %   demEst     - row vector made from the 'demandEst' variable that the
    %                bridge writes to its output MAT-file.
    %
    %   The bridge is <repoRoot>/python/line_inference/api/rnn_bridge.py,
    %   where repoRoot is four directory levels above this file, and
    %   <repoRoot>/python is prepended to PYTHONPATH for the child process.
    %   Data is exchanged through temporary -v7 (non-HDF5) MAT-files that are
    %   deleted when this function exits, whether it returns or throws.
    %
    %   Throws an error if the bridge exits with a non-zero status, writes no
    %   output file, or its output has no 'demandEst' variable.

    % Locate the Python bridge script relative to this file
    thisDir = fileparts(mfilename('fullpath'));
    repoRoot = fullfile(thisDir, '..', '..', '..', '..');
    bridgeScript = fullfile(repoRoot, 'python', 'line_inference', 'api', 'rnn_bridge.py');
    pythonPkgDir = fullfile(repoRoot, 'python');

    % Temporary files for data exchange
    inputFile = [tempname, '.mat'];
    outputFile = [tempname, '.mat'];

    % Save the inputs under the variable names the bridge reads. Using
    % '-struct' avoids a throw-away local that only exists to be saved.
    payload.traces = avgQL;
    payload.numServers = numServers;
    save(inputFile, '-struct', 'payload', '-v7');
    removeInput = onCleanup(@() delete(inputFile)); %#ok<NASGU>

    % Prepend python/ to PYTHONPATH with setenv, which processes started by
    % system() inherit. Unlike a "VAR=... cmd" shell prefix this also works
    % under Windows cmd.exe, uses the platform path separator, and leaves no
    % dangling separator when PYTHONPATH was unset. Restored on exit.
    oldPythonPath = getenv('PYTHONPATH');
    if isempty(oldPythonPath)
        setenv('PYTHONPATH', pythonPkgDir);
    else
        setenv('PYTHONPATH', [pythonPkgDir, pathsep, oldPythonPath]);
    end
    restorePythonPath = onCleanup(@() setenv('PYTHONPATH', oldPythonPath)); %#ok<NASGU>

    % Quote every path so directories containing spaces survive the shell.
    % (Paths containing '"' or '$' would still need escaping.)
    pythonCmd = sprintf('python3 "%s" "%s" "%s"', bridgeScript, inputFile, outputFile);
    [status, cmdout] = system(pythonCmd);

    % Schedule deletion of the output as soon as it exists, so a partial file
    % left behind by a failed run is removed as well.
    if exist(outputFile, 'file') == 2
        removeOutput = onCleanup(@() delete(outputFile)); %#ok<NASGU>
    end
    if status ~= 0 || exist(outputFile, 'file') ~= 2
        error('RNN Python bridge failed (exit code %d):\n%s', status, cmdout);
    end

    result = load(outputFile);
    if ~isfield(result, 'demandEst')
        error('RNN Python bridge output has no ''demandEst'' variable:\n%s', cmdout);
    end
    demEst = result.demandEst(:)';