
Commit 7e477c7

asynchronous A2C

1 parent 6b184f1 · commit 7e477c7

8 files changed: +205 -2 lines

Hard_A2C.py

Lines changed: 4 additions & 2 deletions

@@ -14,7 +14,7 @@
 from tensorflow.keras.layers import Input, Dense, Concatenate
 from tensorflow.keras.models import Model, load_model
 from tensorflow.keras.optimizers import Adam
-
+import threading
 from utils import Portfolio

 # Tensorflow GPU configuration
@@ -143,10 +143,12 @@ def experience_replay(self):
         y_batch = np.vstack(y_batch)
         states_batch = np.vstack([tup[0] for tup in mini_batch])  # batch_size * state_dim
         actions_batch = np.vstack([tup[1] for tup in mini_batch])  # batch_size * action_dim
-
+        lock=threading.Lock()
+        lock.acquire()
         # update critic by minimizing the loss
         loss = self.critic.model.train_on_batch([states_batch, actions_batch], y_batch)
         print("Critic Loss", loss)
+        lock.release()
         # update actor using the sampled policy gradients
         action_grads_batch = self.critic.gradients(states_batch, self.actor.model.predict(states_batch))  # batch_size * action_dim
         self.actor.train(states_batch, action_grads_batch)
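
One note on the guard added above: threading.Lock() is created inside experience_replay() itself, so every call (and every worker thread) gets a fresh lock object. If the intent is to serialize the critic update across the worker threads started by train_asynchronous_A2C.py below, the update would instead be wrapped in a lock shared between the threads, for example one created once in the agent. A minimal sketch, assuming a hypothetical self.lock = threading.Lock() set up in __init__:

        # sketch: self.lock is assumed to be created once in __init__ and shared by all worker threads
        with self.lock:
            # update critic by minimizing the loss
            loss = self.critic.model.train_on_batch([states_batch, actions_batch], y_batch)
            print("Critic Loss", loss)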

__pycache__/Hard_A2C.cpython-38.pyc

86 Bytes
Binary file not shown.

train_asynchronous_A2C.py (new file)

Lines changed: 201 additions & 0 deletions

# -*- coding: utf-8 -*-
"""
Created on Sat Jul 2 13:12:25 2022

@author: Abhilash
"""

import argparse
import importlib
import logging
import sys
import time
import numpy as np
from utils import *
from Agent import *
from DDQN_Agent import *
from DuelingDDQN_Agent import *
from AC_Agent import *
from Hard_A2C import *
import threading
from threading import Lock,Thread

parser = argparse.ArgumentParser(description='command line options')
parser.add_argument('--stock_name', action="store", dest="stock_name", default='S&P_2010-2015', help="stock name")
parser.add_argument('--window_size', action="store", dest="window_size", default=10, type=int, help="span (days) of observation")
parser.add_argument('--num_episode', action="store", dest="num_episode", default=10, type=int, help='episode number')
parser.add_argument('--initial_balance', action="store", dest="initial_balance", default=50000, type=int, help='initial balance')
inputs = parser.parse_args()

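# Example invocation (the values shown are the argparse defaults above):
#   python train_asynchronous_A2C.py --stock_name 'S&P_2010-2015' --window_size 10 --num_episode 10 --initial_balance 50000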
#model_name="DQN"
#model_name="DDQN"
#model_name="DuelingDDQN"
#model_name="AC"
model_name="Hard_A2C"
#model_name="A3C"

stock_name = inputs.stock_name
window_size = inputs.window_size
num_episode = inputs.num_episode
initial_balance = inputs.initial_balance
stock_prices = stock_close_prices(stock_name)
trading_period = len(stock_prices) - 1
returns_across_episodes = []
num_experience_replay = 0
delta=1e-7
action_dict = {0: 'Hold', 1: 'Buy', 2: 'Sell'}
# configure logging
logging.basicConfig(filename=f'logs/{model_name}_training_{stock_name}.log', filemode='w',
                    format='[%(asctime)s.%(msecs)03d %(filename)s:%(lineno)3s] %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)

logging.info(f'Trading Object: {stock_name}')
logging.info(f'Trading Period: {trading_period} days')
logging.info(f'Window Size: {window_size} days')
logging.info(f'Training Episode: {num_episode}')
logging.info(f'Model Name: {model_name}')
logging.info('Initial Portfolio Value: ${:,}'.format(initial_balance))


#agent = DQN_Agent(state_dim=window_size + 3, balance=initial_balance)
#agent=DDQN_Agent(state_dim=window_size + 3, balance=initial_balance)
#agent=DuelingDDQN_Agent(state_dim=window_size + 3, balance=initial_balance)
#agent=AC_Agent(state_dim=window_size + 3, balance=initial_balance)
agent=Hard_A2C_Agent(state_dim=window_size + 3, balance=initial_balance)
lock=Lock()
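# a single shared agent (actor and critic networks) and a module-level lock are used by
# every worker thread started below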
def train(n_threads):
    # each worker thread runs train_() against the single shared agent defined above

    # Create threads
    threads = [threading.Thread(
        target=train_,
        daemon=True,
        args=()) for _ in range(n_threads)]

    for t in threads:
        time.sleep(2)
        t.start()

    # the workers are daemon threads, so keep the main thread alive until they finish
    for t in threads:
        t.join()
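
# Asynchronous A2C: every worker executes the same episode loop (train_) and calls
# agent.experience_replay() on the shared actor/critic, so updates from the workers are
# applied to one model without an explicit gradient-accumulation step.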


def hold(actions, t):
    # encourage selling for profit and liquidity
    next_probable_action = np.argsort(actions)[1]
    if next_probable_action == 2 and len(agent.inventory) > 0:
        max_profit = stock_prices[t] - min(agent.inventory)
        if max_profit > 0:
            sell(t)
            actions[next_probable_action] = 1 # reset this action's value to the highest
            return 'Hold', actions

def buy(t):
    if agent.balance > stock_prices[t]:
        agent.balance -= stock_prices[t]
        agent.inventory.append(stock_prices[t])
        return 'Buy: ${:.2f}'.format(stock_prices[t])

def sell(t):
    if len(agent.inventory) > 0:
        agent.balance += stock_prices[t]
        bought_price = agent.inventory.pop(0)
        profit = stock_prices[t] - bought_price
        global reward
        reward = profit
        return 'Sell: ${:.2f} | Profit: ${:.2f}'.format(stock_prices[t], profit)
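# note: sell() records the realized profit in the module-level `reward`; the per-step
# reward used inside train_() below is its own local variable, driven by the
# unrealized-profit term and the treasury-bond opportunity cost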

def train_(num_experience_replay=0):
    start_time = time.time()
    for e in range(1, num_episode + 1):
        logging.info(f'\nEpisode: {e}/{num_episode}')

        agent.reset() # reset to initial balance and hyperparameters
        state = generate_combined_state(0, window_size, stock_prices, agent.balance, len(agent.inventory))

        for t in range(1, trading_period + 1):
            if t % 100 == 0:
                logging.info(f'\n-------------------Period: {t}/{trading_period}-------------------')

            reward = 0
            next_state = generate_combined_state(t, window_size, stock_prices, agent.balance, len(agent.inventory))
            previous_portfolio_value = len(agent.inventory) * stock_prices[t] + agent.balance

            if model_name == 'AC' or model_name=='Hard_A2C':
                actions = agent.act(state, t)
                action = np.argmax(actions)
            else:
                actions = agent.model.predict(state)[0]
                action = agent.act(state)
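            # AC / Hard_A2C agents return the action scores and the greedy arg-max is taken
            # here; the other agents choose the action themselves, and a choice that differs
            # from the arg-max is logged as an exploration step below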

            # execute position
            logging.info('Step: {}\tHold signal: {:.4} \tBuy signal: {:.4} \tSell signal: {:.4}'.format(t, actions[0], actions[1], actions[2]))
            if action != np.argmax(actions): logging.info(f"\t\t'{action_dict[action]}' is an exploration.")
            if action == 0: # hold
                execution_result = hold(actions, t)
            if action == 1: # buy
                execution_result = buy(t)
            if action == 2: # sell
                execution_result = sell(t)

            # check execution result
            if execution_result is None:
                reward -= treasury_bond_daily_return_rate() * agent.balance # missing opportunity
            else:
                if isinstance(execution_result, tuple): # if execution_result is 'Hold'
                    actions = execution_result[1]
                    execution_result = execution_result[0]
                logging.info(execution_result)

            # calculate reward
            current_portfolio_value = len(agent.inventory) * stock_prices[t] + agent.balance
            unrealized_profit = current_portfolio_value - agent.initial_portfolio_value
            reward += unrealized_profit+delta
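            # the step reward is the unrealized profit relative to the initial portfolio
            # value, plus the tiny constant delta (1e-7) defined at the top of the script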

            agent.portfolio_values.append(current_portfolio_value)
            agent.return_rates.append((current_portfolio_value - previous_portfolio_value) / previous_portfolio_value)

            done = True if t == trading_period else False
            agent.remember(state, actions, reward, next_state, done)

            # update state
            state = next_state

            # experience replay
            if len(agent.memory) > agent.buffer_size:
                num_experience_replay += 1
                print("Getting Loss")
                #lock.acquire()
                loss = agent.experience_replay()
                #lock.release()
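                # the acquire/release here are left commented out: in this commit the
                # critic update is guarded by a lock inside Hard_A2C.experience_replay()
                # (see the Hard_A2C.py diff above)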
                logging.info('Episode: {}\tLoss: {:.2f}\tAction: {}\tReward: {:.2f}\tBalance: {:.2f}\tNumber of Stocks: {}'.format(e, loss, action_dict[action], reward, agent.balance, len(agent.inventory)))
                agent.tensorboard.on_batch_end(num_experience_replay, {'loss': loss, 'portfolio value': current_portfolio_value})

            if done:
                portfolio_return = evaluate_portfolio_performance(agent, logging)
                returns_across_episodes.append(portfolio_return)

        # save models periodically
        if e % 5 == 0:
            if model_name == 'DQN':
                agent.model.save('saved_models/DQN_ep' + str(e) + '.h5')
            elif model_name=='DDQN':
                agent.model.save('saved_models/DDQN_ep' + str(e) + '.h5')
            elif model_name=='DuelingDDQN':
                agent.model.save('saved_models/DuelingDDQN_ep' + str(e) + '.h5')

            #tbd-> on policy
            elif model_name == 'AC':
                agent.actor.model.save_weights('saved_models/AC_ep{}_actor.h5'.format(str(e)))
                agent.critic.model.save_weights('saved_models/AC_ep{}_critic.h5'.format(str(e)))
            elif model_name == 'Hard_A2C':
                agent.actor.model.save_weights('saved_models/A2C_ep{}_actor.h5'.format(str(e)))
                agent.critic.model.save_weights('saved_models/A2C_ep{}_critic.h5'.format(str(e)))

            logging.info('model saved')

    logging.info('total training time: {0:.2f} min'.format((time.time() - start_time)/60))
    plot_portfolio_returns_across_episodes(model_name, returns_across_episodes)


if __name__=='__main__':
    train(10)
