-
Notifications
You must be signed in to change notification settings - Fork 28
/
trpo.py
27 lines (19 loc) · 889 Bytes
/
trpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# -*- coding: utf8 -*-
import tensorflow as tf
from yarll.agents.tf2.ppo.ppo import PPO, PPOContinuous, PPODiscrete, PPODiscreteCNN
from yarll.misc.network_ops import kl_divergence
def trpo_loss(old_log, new_log, beta, advantage):
    """Return the KL-penalized surrogate objective used by TRPO.

    The objective is the importance-sampling ratio ``exp(new_log - old_log)``
    times the advantage, minus a KL-divergence penalty scaled by ``beta``.
    """
    # Importance-sampling ratio pi_new / pi_old, computed from log-probs.
    ratio = tf.exp(new_log - old_log)
    surrogate = ratio * advantage
    # Penalize divergence from the old policy; kl_divergence is the
    # project helper imported at the top of the file.
    penalty = beta * kl_divergence(old_log, new_log)
    return surrogate - penalty
class TRPO(PPO):
    """Trust Region Policy Optimization agent.

    Implemented on top of PPO: the only change is that the actor loss is
    the KL-penalized surrogate (``trpo_loss``) instead of PPO's clipped
    objective. The penalty weight is exposed as the ``kl_coef`` config key.
    """

    def __init__(self, env, monitor_path: str, video=False, **usercfg) -> None:
        # Default KL penalty coefficient (beta). Use setdefault so a
        # caller-supplied kl_coef in usercfg is respected; the previous
        # unconditional assignment silently overwrote it with 1.0.
        usercfg.setdefault("kl_coef", 1.0)
        super().__init__(env, monitor_path, video=video, **usercfg)

    def _actor_loss(self, old_logprob, new_logprob, advantage):
        # Override PPO's clipped surrogate with the KL-penalized one.
        return trpo_loss(old_logprob, new_logprob, self.config["kl_coef"], advantage)
class TRPODiscrete(TRPO, PPODiscrete):
    """TRPO agent for a discrete action space (TRPO loss + PPODiscrete machinery)."""
    pass
class TRPODiscreteCNN(TRPO, PPODiscreteCNN):
    """TRPO agent for a discrete action space with a CNN policy/value network."""
    pass
class TRPOContinuous(TRPO, PPOContinuous):
    """TRPO agent for a continuous action space (TRPO loss + PPOContinuous machinery)."""
    pass