LINE Solver
MATLAB API documentation
Loading...
Searching...
No Matches
QueueNetworkLearningRNNLayer.m
% Adapted from:
% Garbi, G. et al. (2020). Learning Queuing Networks by Recurrent Neural Networks.
3
classdef QueueNetworkLearningRNNLayer < nnet.layer.Layer % ...
        % & nnet.layer.Formattable ... % (Optional)
        % & nnet.layer.Acceleratable % (Optional)
    % QUEUENETWORKLEARNINGRNNLAYER Stateful custom layer that learns a rate
    % vector mu and an M-by-M matrix P from observed state trajectories.
    %
    % At each time step the state row vector x is advanced by one explicit
    % Euler step:
    %     x <- x + deltaT * min(x, concurrency) * pm
    %     pm  = abs(mu)' .* (abs(P) - I)
    % where deltaT is the difference between consecutive input time stamps.
    % When deltaT is zero, the observed state is used instead of a prediction.

    properties
        M            % size of the model: mu is M-by-1, P and I are M-by-M
        R            % stored by the constructor; not read anywhere in this class
        N            % never assigned in this class - NOTE(review): confirm it is still needed
        I            % M-by-M identity matrix, eye(M)
        concurrency  % cap applied as min(x, concurrency) - presumably servers per station; confirm
    end

    properties (Learnable)
        % Layer learnable parameters (both are used through abs()).
        mu           % M-by-1 rates
        P            % M-by-M matrix, zero diagonal at initialisation
    end

    properties (State)
        hiddenState  % 1-by-(M+1) row [t, x_1, ..., x_M]: last time stamp and state
    end

    methods
        function layer = QueueNetworkLearningRNNLayer(M,R,concurrency)
            % QUEUENETWORKLEARNINGRNNLAYER Create the layer.
            %   layer = QueueNetworkLearningRNNLayer(M, R, concurrency) stores
            %   M, R and concurrency, sets I = eye(M), draws mu uniformly between
            %   0.01 and 10, and builds a random M-by-M matrix P with a zero
            %   diagonal. It then resets the hidden state.
            %
            %   NOTE(review): P's random weights are normalised per row of an
            %   M-by-(M-1) draw, but they are written into P column by column.
            %   As a result, the off-diagonal entries of each *column* of P sum
            %   to 1 at initialisation, whereas routing matrices are usually
            %   row-stochastic. Confirm which orientation is intended.

            layer.M = M;
            layer.R = R;
            layer.concurrency = concurrency;
            layer.I = eye(M);

            layer.mu = layer.initializeUniformNonNeg([M,1]);

            % M-1 off-diagonal weights per row, each row normalised to sum to 1.
            Pfree = layer.initializeUniformNonNeg([M,M-1]);
            Pfree = Pfree ./ sum(Pfree, 2);

            % Place the weights into the off-diagonal positions of an M-by-M
            % matrix, leaving the diagonal at zero:
            %   - the weights are read row by row (column-major order of Pfree.');
            %   - the target positions are filled in MATLAB's column-major order.
            % This matches multiplying by a one-hot placement matrix, without
            % materialising that (M^2-M)-by-M^2 matrix.
            Pt = Pfree.';
            Pfull = zeros(M, M, 'single');
            Pfull(~eye(M)) = extractdata(Pt(:));
            layer.P = dlarray(Pfull);

            layer = layer.resetState();
        end

        function parameter = initializeUniformNonNeg(~, sz)
            % INITIALIZEUNIFORMNONNEG Single-precision dlarray of size sz with
            % entries drawn uniformly between 0.01 and 10 (strictly positive).
            %   The object argument is unused. It is kept so that existing
            %   layer.initializeUniformNonNeg(sz) calls continue to work.
            lo = 0.01;
            hi = 10.0;
            parameter = dlarray(lo + (hi - lo) .* rand(sz, 'single'));
        end

        function [Z,state] = predict(layer, X)
            % PREDICT Roll the recurrence forward over the time steps of X.
            %   [Z, state] = predict(layer, X)
            %
            %   Time steps run along dimension 4 of X. From the indexing below:
            %     - X(1,1,1,t) is read as the time stamp of step t;
            %     - X(:,2,1,t) is read as the observed state vector.
            %   Presumably X is M-by-2-by-1-by-T; TODO confirm against the
            %   data pipeline.
            %
            %   The hidden state is reset at the start of every call. For each
            %   step:
            %     - if the time stamp did not advance (deltaT == 0), the
            %       observed state is taken as-is;
            %     - otherwise one Euler step is taken from the previous
            %       prediction.
            %
            %   Outputs:
            %     Z     - size [size(X,1), size(X,2), 1, T]. Column 1 holds the
            %             time stamp and column 2 the predicted state.
            %     state - final 1-by-(M+1) hidden state [t, x]. When X has no
            %             time steps, it is the reset (all-zero) state.
            numTimeSteps = size(X,4);
            layer = layer.resetState();
            Z = dlarray(zeros([size(X,1), size(X,2), 1, numTimeSteps]));

            % pm depends only on the learnable parameters, so form it once.
            % NOTE(review): abs(mu)' is 1-by-M, so implicit expansion scales the
            % *columns* of (abs(P) - I) by mu. A population-conserving fluid
            % model would instead scale the rows, i.e. diag(mu) * (P - I).
            % Confirm against the reference model.
            pm = abs(layer.mu)' .* (abs(layer.P) - layer.I);

            % Running [t, x] row. It starts from the reset state, so it is
            % defined even when there are no time steps.
            xh_pred = layer.hiddenState;
            for t = 1:numTimeSteps
                currentT = X(1,1,1,t);
                deltaT = currentT - xh_pred(1);
                if deltaT == 0
                    % No elapsed time: take the observed state directly.
                    pred = X(:,2,1,t)';
                else
                    x = xh_pred(2:end);
                    pred = x + (deltaT * min(x, layer.concurrency)) * pm;
                end

                xh_pred = cat(2, currentT, pred);
                Z(:,1,:,t) = currentT;
                Z(:,2,:,t) = xh_pred(2:end);
            end
            state = xh_pred;
        end

        function layer = resetState(layer)
            % RESETSTATE Set the hidden state to a 1-by-(M+1) row of zeros
            % (time stamp 0 and an all-zero state vector).
            layer.hiddenState = zeros(1, layer.M+1);
        end
    end
end