classdef rl_td_agent < handle
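    % rl_td_agent: tabular temporal-difference (TD) learning agent that keeps
    % a value function v and a Q function indexed by queue lengths, explores
    % with an epsilon-greedy policy whose exploration rate decays over
    % episodes, and is trained by calling solve(env) on a queueing environment.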
        vSize; % size of value function
        QSize; % size of Q function
        epsilon = 1; % explore-exploit rate
        eps_decay = 0.99; % explore-exploit rate decay
        lr = 0.05; % learning rate
        function obj = rl_td_agent(lr, eps, epsDecay)
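            % Construct an agent with learning rate lr, initial exploration
            % rate eps, and per-episode exploration decay epsDecay.
            % Illustrative usage (env is assumed to be a compatible
            % environment object):
            %   agent = rl_td_agent(0.05, 1.0, 0.99);
            %   agent.solve(env);
            %   v = agent.getValueFunction();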
            obj.eps_decay = epsDecay;
        function reset(obj, env)
        function v = getValueFunction(obj)
        function Q = getQFunction(obj)
%         function p = getPolicy(obj)
%             p = zeros(size(obj.v));
        function solve(obj, env)
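            % Event-driven TD learning: obj.v and obj.Q are tabular, with one
            % dimension per queue (env.actionSize dimensions of
            % env.stateSize + 5 entries each). Arrivals are routed using an
            % epsilon-greedy policy over the learned values or a
            % join-the-shortest-queue rule, and the value table is updated
            % with an average-cost TD step after each sampled event.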
            obj.v = zeros(zeros(1, env.actionSize) + env.stateSize + 5); % value function
            obj.Q = rand([zeros(1, env.actionSize) + env.stateSize + 5, env.actionSize]); % Q function
            obj.vSize = size(obj.v);
            obj.QSize = size(obj.Q);
            x = zeros(1, env.actionSize); % initial state
            n = zeros(1, env.actionSize); % initial previous state
            % t_prev = 0; % time of last event
            t = 0; % time of current event
            % dt = 0; % time period between two successive events
            c = 0; % incurred costs between the visits
            T = 0; % total discounted elapsed time
            C = 0; % total discounted costs
            while j < num_episodes
                line_printf(mfilename, sprintf('running episode #%d .\n', j));
                eps = eps * obj.eps_decay;
                [dt, depNode] = env.sample(); % how to do successive sampling?
                if ismember(depNode, env.idxOfSourceInNodes) % new job
                    if env.isInActionSpace(env.model.nodes)
                        % create an exploit-explore policy
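                        % row i of next_locs is the 1-based location vector
                        % obtained by routing the new job to queue i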
                        next_locs = zeros(env.actionSize, env.actionSize) + x + 1 + eye(env.actionSize);
                        next_states = obj.get_state_from_locs(obj.vSize, next_locs);
                        policy = obj.createGreedyPolicy(obj.v(next_states), eps, env.actionSize);
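                        % draw an action from the policy distribution by
                        % inverse-CDF sampling: count the cumulative
                        % thresholds that fall at or below the uniform draw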
                        action = sum(rand >= cumsum([0, policy]));
                    else
                        action = find(x==min(x)); % JSQ
                        action = randomsample(action, 1);
                    end
                    x(action) = x(action) + 1;
                    % env.model.nodes{action+1}.state = State.fromMarginal(model, model.stations{action+1}, sum(model.stations{action+1}.state)+1);
                    % State.afterEvent(model, queue1, queue1.space, sample.event{1}, 1)
                    % State.afterEvent(sn, ind, inspace, event, class, isSimulation)
                elseif ismember(depNode, env.idxOfQueueInNodes) % dep from Queue, idx: Node{depNode}
                    x(env.idxOfQueueInNodes == depNode) = max(0, x(env.idxOfQueueInNodes == depNode) - 1);
                    % State.afterEvent(sn, ind, inspace, event, class, isSimulation)
                end
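                % Average-cost TD update: T and C accumulate (geometrically
                % weighted by env.gamma) the elapsed time and incurred cost, so
                % C/T estimates the mean cost rate; the previous state's value
                % moves toward the excess cost c - t*mean_cost_rate plus the
                % current state's value, and subtracting obj.v(1) pins the
                % empty-system state at zero so values stay relative.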
                if env.isInStateSpace(env.model.nodes)
                    T = env.gamma * T + t;
                    C = env.gamma * C + c;
                    mean_cost_rate = C/T;
                    prev_state = obj.get_state_from_loc(obj.vSize, n+1);
                    cur_state = obj.get_state_from_loc(obj.vSize, x+1);
                    obj.v(prev_state) = (1-obj.lr)*obj.v(prev_state) + obj.lr*(c - t*mean_cost_rate + obj.v(cur_state)); % here "obj.v(cur_state) * env.gamma" ?
                    obj.v = obj.v - obj.v(1);
        function s = get_state_from_locs(obj, objSize, locs)
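            % Map each row of locs (a 1-based location vector) to its linear
            % index in a table of size objSize; returns one index per row.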
            s = zeros(1, size(locs,1));
            for i = 1:size(locs,1)
                s(i) = obj.get_state_from_loc(objSize, locs(i,:));
            end
        function policy = createGreedyPolicy(state_Q, epsilon, nA)
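            % Epsilon-greedy distribution over nA actions that favours the
            % minimum (lowest-cost) entries of state_Q: every action receives
            % epsilon/nA probability, and the remaining (1-epsilon) is split
            % evenly among the near-minimal actions (within FineTol of the
            % minimum). Illustrative example: state_Q = [2 1 1], epsilon = 0.1
            % gives policy ~ [0.0333 0.4833 0.4833].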
            policy = ones(1, nA) * epsilon / nA;
            argmin = find(state_Q - min(state_Q) < GlobalConstants.FineTol);
            policy(argmin) = policy(argmin) + (1-epsilon)/length(argmin);
        function s = get_state_from_loc(objSize, loc)
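            % Convert a 1-based location vector loc into a column-major linear
            % index for a table of size objSize, in the style of sub2ind:
            %   s = 1 + sum_i (loc(i)-1) * prod(objSize(1:i-1))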
            s = 1; % 1-based linear index accumulator
            if size(objSize,2) == size(loc, 2)
                for i = 1:size(objSize,2)
                    s = s + (loc(i)-1) * prod(objSize(1:(i-1)));