gpt4 book ai didi

python - 强化学习如何通过高斯策略进行连续控制?

转载 作者:行者123 更新时间:2023-12-01 08:22:06 24 4
gpt4 key购买 nike

我正在实现 Soft-Actor-Critic 算法,但我无法理解随机策略的工作原理。我在网上搜索过,但没有找到任何有趣的网站可以很好地解释以下实现。我唯一理解的是,在随机策略的情况下,我们将其建模为高斯,并将均值和对数标准差参数化(我认为标准差是标准差),但例如:为什么我们需要对数标准差和不只是标准?

class ActorNetwork(object):
def __init__(self, act_dim, name):
self.act_dim = act_dim
self.name = name

def step(self, obs, log_std_min=-20, log_std_max=2):
with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):

h1 = tf.layers.dense(obs, 256, tf.nn.relu)
h2 = tf.layers.dense(h1, 256, tf.nn.relu)
mu = tf.layers.dense(h2, self.act_dim, None)
log_std = tf.layers.dense(h2, self.act_dim, tf.tanh)
'''
at the start we could have extremely large values for the log_stds, which could result in some actions
being either entirely deterministic or too random. To protect against that,
we'll constrain the output range of the log_stds, to be within [LOG_STD_MIN, LOG_STD_MAX]
'''
log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)

std = tf.exp(log_std)
pi = mu + tf.random_normal(tf.shape(mu)) * std

#gaussian likelihood
pre_sum = -0.5 * (((pi - mu) / (tf.exp(log_std) + EPS)) ** 2 + 2 * log_std + np.log(2 * np.pi))
logp_pi = tf.reduce_sum(pre_sum, axis=1)

mu = tf.tanh(mu)
pi = tf.tanh(pi)

clip_pi = 1 - tf.square(pi) #pi^2
clip_up = tf.cast(clip_pi > 1, tf.float32)
clip_low = tf.cast(clip_pi < 0, tf.float32)
clip_pi = clip_pi + tf.stop_gradient((1 - clip_pi) * clip_up + (0 - clip_pi) * clip_low)

logp_pi -= tf.reduce_sum(tf.log(clip_pi + 1e-6), axis=1)

return mu, pi, logp_pi

def evaluate(self, obs): #Choose action
mu, pi, logp_pi = self.step(obs)
action_scale = 2.0 # env.action_space.high[0]

mu *= action_scale
pi *= action_scale
return mu, pi, logp_pi

最佳答案

你是对的。在高斯策略中,您将观察值(使用策略网络)映射到平均值 mu 和标准差的对数 log_std 操作。这是因为你有一个持续行动的空间。一旦您训练模型在操作空间中分配 mulog_std,您就可以计算采取由 pi 采样的操作的对数似然。

在高斯策略中,log_std 优于 std,只是因为 log_std 取 (-inf,+inf) 中的任何值,而 std 仅限于非负值。摆脱这种非负约束可以使训练变得更加容易,而且您也不会因这种转换而丢失任何信息。

关于python - 强化学习如何通过高斯策略进行连续控制?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54569726/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com