classdef rl_td_agent_general < handle % class for TD learning and TD control
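% Example usage (a sketch; "env" stands for any environment object exposing the
% members referenced below, e.g. nqueues, stateSize, gamma, sample(),
% actionSpace(), idxOfQueueInNodes, isInStateSpace(), isInActionSpace()):
%   agent = rl_td_agent_general(0.1, 1, 0.9999);       % lr, epsilon, eps_decay
%   V = agent.solve(env, 1e4);                          % tabular TD control
%   [X, Y, coeff] = agent.solve_by_linear(env, 1e4);    % linear value fn fit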
v; % value function
vSize; % size of value function
epsilon = 1; % explore-exploit rate
eps_decay = 0.9999; % explore-exploit rate decay
lr = 0.1; % learning rate
function obj = rl_td_agent_general(lr, eps, epsDecay)
obj.lr = lr;
obj.epsilon = eps;
obj.eps_decay = epsDecay;
function reset(obj, env)
function v = getValueFunction(obj)
v = obj.v;
% TD learning for value function with a heuristic routing strategy
function v = solve_for_fixed_policy(obj, env, num_episodes) % num_episodes = 10^4 usually
obj.v = zeros(zeros(1, env.nqueues) + env.stateSize + 1); % value function
obj.vSize = size(obj.v);
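% obj.v is an env.nqueues-dimensional table with env.stateSize+1 entries per
% dimension, i.e. one value for every vector of queue lengths 0..stateSize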
t = 0; % time of current event
c = 0; % incurred costs between the visits
T = 0; % total discounted elapsed time
C = 0; % total discounted costs
x = zeros(1, env.nqueues); % initial state
n = zeros(1, env.nqueues); % initial previous state
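% x follows the simulated queue lengths; n remembers the previous state whose
% value is backed up in the TD update below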
while j < num_episodes
line_printf('running episode #%d \n', j);
[dt, depNode, arvNode, sample] = env.sample();
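% by the usage below: dt is the elapsed time to the sampled event, and
% depNode/arvNode are the departure and arrival nodes of that transition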
if ismember(depNode, env.idxOfQueueInNodes) % Event involves departure from server i
depServer = find(env.idxOfQueueInNodes == depNode);
x(depServer) = x(depServer) - 1;
if ismember(arvNode, env.idxOfQueueInNodes) % Event involves arrival at server j
arvServer = find(env.idxOfQueueInNodes == arvNode);
x(arvServer) = x(arvServer) + 1;
if env.isInStateSpace(x)
T = env.gamma * T + t;
C = env.gamma * C + c;
mean_cost_rate = C/T; % running estimate of the average cost rate
prev_state = num2cell(n+1); % obj.get_state_from_loc(obj.vSize, n+1);
cur_state = num2cell(x+1); % obj.get_state_from_loc(obj.vSize, x+1);
obj.v(prev_state{:}) = (1-obj.lr)*obj.v(prev_state{:}) + obj.lr*(c - t*mean_cost_rate + obj.v(cur_state{:}));
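% anchor the relative value function: pinning the empty-system entry obj.v(1)
% at zero keeps the differential values from drifting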
obj.v = obj.v - obj.v(1);
% TD control with tabular value function
function value_function = solve(obj, env, num_episodes) % num_episodes = 10^4 usually
obj.v = zeros(zeros(1, env.nqueues) + env.stateSize + 1); % value function
obj.vSize = size(obj.v);
t = 0; % time of current event
c = 0; % incurred costs between the visits
T = 0; % total discounted elapsed time
C = 0; % total discounted costs
x = zeros(1, env.nqueues); % initial state
n = zeros(1, env.nqueues); % initial previous state
while j < num_episodes
line_printf('running episode #%d .\n', j);
eps = eps * obj.eps_decay;
[dt, depNode, arvNode, sample] = env.sample();
if ismember(depNode, env.idxOfQueueInNodes) % Event involves departure from server i
depServer = find(env.idxOfQueueInNodes == depNode);
x(depServer) = max(0, x(depServer) - 1);
if ismember(depNode, env.idxOfActionNodes) && env.isInActionSpace(x) % actions wanted at server i, and in action space
actions = env.actionSpace(depNode); % dep at node i, possible actions are [k,l,m]
% create an exploit-explore policy
next_values = obj.gen_next_values(env, x, actions);
policy = obj.createGreedyPolicy(next_values, eps, length(actions));
arvNode = actions(sum(rand >= cumsum([0, policy])));
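% sample an action index from the policy: cumsum builds the CDF over actions
% and counting how many thresholds rand exceeds selects the destination node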
% update sample with new arvNode
for i = 1:length(sample.event)
if sample.event{i}.event == EventType.ARV
sample.event{i}.node = arvNode;
% update current state and model
x(env.idxOfQueueInNodes == arvNode) = x(env.idxOfQueueInNodes == arvNode) + 1;
if env.isInStateSpace(x) % in state space, update state value
T = env.gamma * T + t;
C = env.gamma * C + c;
mean_cost_rate = C/T;
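% C/T is an exponentially weighted estimate of the long-run average cost rate;
% subtracting t*mean_cost_rate below makes this behave like a differential
% (average-cost) TD update rather than a discounted one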
prev_state = num2cell(n+1);
cur_state = num2cell(x+1);
obj.v(prev_state{:}) = (1-obj.lr)*obj.v(prev_state{:}) + obj.lr*(c - t*mean_cost_rate + obj.v(cur_state{:})); % here "obj.v(cur_state) * env.gamma" ?
obj.v = obj.v - obj.v(1);
value_function = obj.v;
% TD control with hashmap value fn
function [X, Y] = solve_by_hashmap(obj, env, num_episodes)
pointValues = containers.Map; % hashmap value fn
pointValues(num2str(zeros(1, env.nqueues))) = 0;
pointValues('external') = 0;
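% the 'external' entry serves as the default value for states that have not
% been stored yet (see the isKey checks below)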
t = 0; % time of current event
c = 0; % incurred costs between the visits
T = 0; % total discounted elapsed time
C = 0; % total discounted costs
x = zeros(1, env.nqueues); % initial state
n = zeros(1, env.nqueues); % initial previous state
while j < num_episodes
line_printf('running episode #%d .\n', j);
% if mod(j, 100) == 0
eps = eps * obj.eps_decay;
[dt, depNode, arvNode, sample] = env.sample();
if ismember(depNode, env.idxOfQueueInNodes) % Event involves departure from server i
depServer = find(env.idxOfQueueInNodes == depNode);
x(depServer) = max(0, x(depServer) - 1);
if ismember(depNode, env.idxOfActionNodes) && env.isInActionSpace(x) % actions wanted at server i, and in action space
actions = env.actionSpace(depNode); % dep at node i, possible actions are [k,l,m]
% create an exploit-explore policy
nextPointValues = zeros(1, length(actions));
for act_i = 1 : length(actions)
q_idx = find(env.idxOfQueueInNodes == actions(act_i));
tmp_next_state = x; % candidate next state if the job is routed to this action
tmp_next_state(q_idx) = tmp_next_state(q_idx) + 1;
if pointValues.isKey(num2str(tmp_next_state))
nextPointValues(act_i) = pointValues(num2str(tmp_next_state));
else
nextPointValues(act_i) = pointValues('external');
policy = obj.createGreedyPolicy(nextPointValues, eps, length(actions));
arvNode = actions(sum(rand >= cumsum([0, policy])));
for i = 1:length(sample.event)
if sample.event{i}.event == EventType.ARV
sample.event{i}.node = arvNode;
% update current state and model
x(env.idxOfQueueInNodes == arvNode) = x(env.idxOfQueueInNodes == arvNode) + 1;
if env.isInStateSpace(x) % in state space, update state value
T = env.gamma * T + t;
C = env.gamma * C + c;
mean_cost_rate = C/T;
if ~pointValues.isKey(num2str(n))
pointValues(num2str(n)) = pointValues('external');
end
if pointValues.isKey(num2str(x))
pointValues(num2str(n)) = (1-obj.lr)*pointValues(num2str(n)) + obj.lr*(c - t*mean_cost_rate + pointValues(num2str(x)));
else
pointValues(num2str(n)) = (1-obj.lr)*pointValues(num2str(n)) + obj.lr*(c - t*mean_cost_rate + pointValues('external'));
subtractor = pointValues(num2str(n));
for k = keys(pointValues)
pointValues(k{1}) = pointValues(k{1}) - subtractor;
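% re-anchor the relative value function: shifting every stored value by the
% value of state n mirrors the obj.v - obj.v(1) normalization of the tabular
% solvers above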
pointValues.remove('external');
X = zeros(pointValues.Count, 1 + env.nqueues);
Y = zeros(pointValues.Count, 1);
iterator = 1;
for k = keys(pointValues)
X(iterator, :) = [1 str2num(k{1})];
Y(iterator, :) = pointValues(k{1});
iterator = iterator + 1;
% TD control using linear value fn approximator:
% v(q1,q2,...,qn) = w1*q1 + w2*q2 + ... + wn*qn (linear fn)
function [X, Y, coeff] = solve_by_linear(obj, env, num_episodes)
[X, Y] = obj.solve_by_hashmap(env, num_episodes);
coeff = regress(Y, X);
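% regress (Statistics and Machine Learning Toolbox) returns the least-squares
% coefficients; since the first column of X is all ones, coeff(1) is the
% intercept and a state x can be scored as, e.g., v_hat = [1 x] * coeff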
% TD control using quadratic value fn approximator:
% v(q1,q2,...,qn) = sum_{i,j} w_{ij} * q_i * q_j (quadratic fn)
function [X, Y, coeff] = solve_by_quad(obj, env, num_episodes)
[X, Y] = obj.solve_by_hashmap(env, num_episodes);
for i = 2 : 1 + env.nqueues
for j = i : 1 + env.nqueues % append pairwise products q_i*q_j of the queue-length columns
X(:,end+1) = X(:,i) .* X(:,j);
end
end
coeff = regress(Y, X);
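% Helper: look up the tabular value of each successor state reached by routing
% the departing job to one of the candidate action nodes.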
function values = gen_next_values(obj, env, cur_state, actions) % cur_state = x
values = zeros(1, length(actions));
for act_i = 1 : length(actions)
q_idx = find(env.idxOfQueueInNodes == actions(act_i));
tmp_loc = cur_state + 1;
tmp_loc(q_idx) = tmp_loc(q_idx) + 1;
tmp_idx = num2cell(tmp_loc);
values(act_i) = obj.v(tmp_idx{:});
function policy = createGreedyPolicy(state_Q, epsilon, nA)
policy = ones(1, nA) * epsilon / nA;
argmin = find(state_Q == min(state_Q));
policy(argmin) = policy(argmin) + (1-epsilon)/length(argmin);
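% epsilon-greedy over costs: each action keeps probability epsilon/nA and the
% remaining 1-epsilon mass is split evenly among the minimum-value actions, so
% the entries of policy sum to one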