LINE Solver
MATLAB API documentation
Loading...
Searching...
No Matches
rl_td_agent.m
classdef rl_td_agent < handle
    % RL_TD_AGENT Tabular temporal-difference (TD) agent for routing jobs to queues.
    %
    % The agent learns a relative value function over joint queue-length
    % states by interacting with an environment object ENV. Each arrival
    % from a source node is routed to one of ENV.actionSize queues, either
    % epsilon-greedily with respect to the learned values or, outside the
    % action space, by join-the-shortest-queue (JSQ).
    %
    % NOTE(review): the ENV interface used here (reset, sample, update,
    % isInActionSpace, isInStateSpace, actionSize, stateSize, gamma,
    % idxOfSourceInNodes, idxOfQueueInNodes, model.nodes) is inferred from
    % the calls in solve(); its exact contract is defined elsewhere.

    properties
        v;                 % value table, indexed by (queue lengths + 1)
        Q;                 % Q table (allocated by solve, not updated by it)
        vSize;             % dimensions of v
        QSize;             % dimensions of Q
        epsilon = 1;       % initial explore-exploit rate
        eps_decay = 0.99;  % multiplicative epsilon decay applied every loop iteration
        lr = 0.05;         % learning rate (TD step size)
    end

    methods
        function obj = rl_td_agent(lr, epsilon, epsDecay)
            % RL_TD_AGENT Construct an agent.
            %   obj = rl_td_agent(lr, epsilon, epsDecay) sets the learning
            %   rate, the initial exploration rate and its decay factor.
            %   Any argument that is omitted or passed as [] keeps the
            %   default declared in the properties block (0.05, 1, 0.99).
            if nargin >= 1 && ~isempty(lr)
                obj.lr = lr;
            end
            if nargin >= 2 && ~isempty(epsilon)
                obj.epsilon = epsilon;
            end
            if nargin >= 3 && ~isempty(epsDecay)
                obj.eps_decay = epsDecay;
            end
            obj.clearTables();
        end

        function reset(obj, env)
            % RESET Clear the learned tables and reset the environment.
            obj.clearTables();
            env.reset();
        end

        function v = getValueFunction(obj)
            % GETVALUEFUNCTION Return the current value table.
            v = obj.v;
        end

        function Q = getQFunction(obj)
            % GETQFUNCTION Return the current Q table.
            Q = obj.Q;
        end

%         function p = getPolicy(obj)
%             p = zeros(size(obj.v));
%
%         end

        function solve(obj, env, num_episodes)
            % SOLVE Learn the relative value function by simulating ENV.
            %   solve(obj, env) performs 1e4 TD updates.
            %   solve(obj, env, num_episodes) performs num_episodes updates.
            %   One update ("episode") happens on every event after which
            %   env.isInStateSpace(env.model.nodes) is true. The agent and
            %   the environment are reset first.
            if nargin < 3 || isempty(num_episodes)
                num_episodes = 1e4;
            end

            obj.reset(env);

            % One table dimension per queue, each of extent stateSize + 5.
            dims = zeros(1, env.actionSize) + env.stateSize + 5;
            % The trailing 1 makes a single-queue model a column vector
            % rather than a square matrix (zeros(k) is k-by-k).
            obj.v = zeros([dims, 1]);
            obj.Q = rand([dims, env.actionSize]);
            % Store the intended dimensions explicitly so that
            % get_state_from_loc sees one extent per queue.
            obj.vSize = dims;
            obj.QSize = [dims, env.actionSize];

            x = zeros(1, env.actionSize);   % current queue lengths
            n = zeros(1, env.actionSize);   % queue lengths at the previous update
            t = 0;   % time elapsed since the previous update
            c = 0;   % cost accumulated since the previous update
            T = 0;   % total discounted elapsed time
            C = 0;   % total discounted cost

            epsCur = obj.epsilon;   % decayed locally; obj.epsilon is not modified
            j = 0;                  % number of updates performed so far
            lastReported = -1;      % last j for which progress was printed

            while j < num_episodes
                % j only advances on state-space visits, so guard against
                % printing the same progress line on consecutive iterations.
                if mod(j, 1e3) == 0 && j ~= lastReported
                    line_printf(mfilename,sprintf('running episode #%d .\n',j));
                    lastReported = j;
                end

                epsCur = epsCur * obj.eps_decay;

                [dt, depNode] = env.sample();
                t = t + dt;
                % Cost over the interval: total queue length times its duration.
                c = c + sum(x) * dt;

                if ismember(depNode, env.idxOfSourceInNodes)
                    % New job: choose a queue to route it to.
                    if env.isInActionSpace(env.model.nodes)
                        % Epsilon-greedy over the values of the states reached
                        % by adding one job to each queue (row i = queue i).
                        next_locs = zeros(env.actionSize, env.actionSize) + x + 1 + eye(env.actionSize);
                        next_states = obj.get_state_from_locs(obj.vSize, next_locs);
                        policy = obj.createGreedyPolicy(obj.v(next_states), epsCur, env.actionSize);
                        action = obj.sampleFromPolicy(policy);
                    else
                        % Join-the-shortest-queue, ties broken uniformly at random.
                        action = find(x == min(x));
                        if numel(action) > 1
                            action = action(randi(numel(action)));
                        end
                    end
                    x(action) = x(action) + 1;
                    env.update(x);
                elseif ismember(depNode, env.idxOfQueueInNodes)
                    % Departure from a queue: decrement its length, never below 0.
                    isDep = env.idxOfQueueInNodes == depNode;
                    x(isDep) = max(0, x(isDep) - 1);
                    env.update(x);
                end

                if env.isInStateSpace(env.model.nodes)
                    j = j + 1;
                    T = env.gamma * T + t;
                    C = env.gamma * C + c;
                    mean_cost_rate = C / T;

                    % TD update of the relative value of the previous state.
                    % NOTE(review): original author asked whether the bootstrap
                    % term should be env.gamma * obj.v(cur_state).
                    prev_state = obj.get_state_from_loc(obj.vSize, n + 1);
                    cur_state = obj.get_state_from_loc(obj.vSize, x + 1);
                    obj.v(prev_state) = (1 - obj.lr) * obj.v(prev_state) ...
                        + obj.lr * (c - t * mean_cost_rate + obj.v(cur_state));
                    % Normalize so that the all-queues-empty state has value 0.
                    obj.v = obj.v - obj.v(1);

                    t = 0;
                    c = 0;
                    n = x;
                end
            end
        end

        function s = get_state_from_locs(obj, objSize, locs)
            % GET_STATE_FROM_LOCS Linear index for each row of LOCS.
            %   s(i) = get_state_from_loc(objSize, locs(i,:)); s is a row vector.
            nRows = size(locs, 1);
            s = zeros(1, nRows);
            for i = 1:nRows
                s(i) = obj.get_state_from_loc(objSize, locs(i, :));
            end
        end
    end

    methods (Access = private)
        function clearTables(obj)
            % CLEARTABLES Reset the value/Q tables and their recorded sizes to 0.
            obj.v = 0;
            obj.vSize = 0;
            obj.Q = 0;
            obj.QSize = 0;
        end
    end

    methods (Static)
        function policy = createGreedyPolicy(state_Q, epsilon, nA)
            % CREATEGREEDYPOLICY Epsilon-greedy distribution favouring the minimum.
            %   Every action receives epsilon/nA. The remaining 1-epsilon is
            %   split evenly among actions whose value is within
            %   GlobalConstants.FineTol of min(state_Q), so lower values
            %   (costs) are preferred.
            policy = ones(1, nA) * epsilon / nA;
            argmin = find(state_Q - min(state_Q) < GlobalConstants.FineTol);
            policy(argmin) = policy(argmin) + (1 - epsilon) / length(argmin);
        end

        function action = sampleFromPolicy(policy)
            % SAMPLEFROMPOLICY Draw an action index from a discrete distribution.
            %   Uses inverse-CDF sampling with a single call to rand. If
            %   rounding leaves sum(policy) below the drawn value, the last
            %   action with positive probability is returned instead of an
            %   out-of-range index.
            action = find(rand < cumsum(policy), 1);
            if isempty(action)
                action = find(policy > 0, 1, 'last');
                if isempty(action)
                    action = numel(policy);
                end
            end
        end

        function s = get_state_from_loc(objSize, loc)
            % GET_STATE_FROM_LOC Column-major linear index of subscript vector LOC.
            %   For a table with dimensions OBJSIZE returns
            %   loc(1) + sum_{i>1} (loc(i)-1) * prod(objSize(1:i-1)),
            %   the same ordering MATLAB uses for linear indexing.
            %   Returns 0 when objSize and loc differ in length.
            %   Subscripts are not bounds-checked.
            s = 0;
            if size(objSize, 2) == size(loc, 2)
                for i = 1:size(objSize, 2)
                    if i == 1
                        s = s + loc(i);
                    else
                        s = s + (loc(i) - 1) * prod(objSize(1:(i - 1)));
                    end
                end
            end
        end
    end

end
158
159