gpt4 book ai didi

machine-learning - Chainer如何保存和加载DQN模型

转载 作者:行者123 更新时间:2023-11-30 08:40:32 25 4
gpt4 key购买 nike

我正在学习深度强化学习框架 Chainer。

我按照教程操作并获得了以下代码:

def train_dddqn(env):
    """Train a dueling double-DQN agent on ``env``.

    Args:
        env: Environment object. This function reads ``env.history_t`` and
            ``env.data`` and calls ``env.reset()`` (returns the first
            observation) and ``env.step(action)`` (returns
            ``(obs, reward, done)``).

    Returns:
        Tuple ``(Q, total_losses, total_rewards)``: the trained online
        network, and two lists holding the summed loss and summed reward
        of every epoch.
    """

    # NOTE: the class is local to this function, so code outside cannot name
    # it; rebuilding a network to load saved parameters into requires an
    # identical definition that is reachable from the loading code.
    class Q_Network(chainer.Chain):
        """Dueling Q-network.

        ``fc1``/``fc2`` form a shared trunk; ``fc3 -> state_value`` produces
        one value per sample and ``fc4 -> advantage_value`` produces one
        advantage per action. The output is
        ``state_value + (advantage - mean(advantage))``.
        """

        def __init__(self, input_size, hidden_size, output_size):
            super(Q_Network, self).__init__(
                fc1=L.Linear(input_size, hidden_size),
                fc2=L.Linear(hidden_size, hidden_size),
                fc3=L.Linear(hidden_size, hidden_size // 2),
                fc4=L.Linear(hidden_size, hidden_size // 2),
                state_value=L.Linear(hidden_size // 2, 1),
                advantage_value=L.Linear(hidden_size // 2, output_size)
            )
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.output_size = output_size

        def __call__(self, x):
            """Return the Q-values for the batch of observations ``x``."""
            h = F.relu(self.fc1(x))
            h = F.relu(self.fc2(h))
            hs = F.relu(self.fc3(h))
            ha = F.relu(self.fc4(h))
            state_value = self.state_value(hs)
            advantage_value = self.advantage_value(ha)
            # Per-sample mean advantage, kept as a column so it can be tiled.
            advantage_mean = (F.sum(advantage_value, axis=1) / float(self.output_size)).reshape(-1, 1)
            # Tile the single-column tensors across all actions before combining.
            q_value = F.concat([state_value for _ in range(self.output_size)], axis=1) + (
                advantage_value - F.concat([advantage_mean for _ in range(self.output_size)], axis=1))
            return q_value

        def reset(self):
            """Clear the accumulated gradients of every parameter."""
            self.cleargrads()

    Q = Q_Network(input_size=env.history_t + 1, hidden_size=100, output_size=3)
    Q_ast = copy.deepcopy(Q)  # target network, re-synchronised periodically
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(Q)

    epoch_num = 50
    step_max = len(env.data) - 1
    memory_size = 200           # replay buffer capacity
    batch_size = 50
    epsilon = 1.0               # probability of keeping the random action
    epsilon_decrease = 1e-3     # subtracted from epsilon on each eligible step
    epsilon_min = 0.1
    start_reduce_epsilon = 200  # total_step after which epsilon starts decaying
    train_freq = 10             # train every this many total steps
    update_q_freq = 20          # copy Q into Q_ast every this many total steps
    gamma = 0.97                # discount factor used in the TD target
    show_log_freq = 5           # print a log line every this many epochs

    memory = []
    total_step = 0
    total_rewards = []
    total_losses = []

    start = time.time()
    for epoch in range(epoch_num):

        pobs = env.reset()
        step = 0
        done = False
        total_reward = 0
        total_loss = 0

        while not done and step < step_max:

            # select act: draw a random action, then replace it with the
            # greedy one when the epsilon test says to exploit
            pact = np.random.randint(3)
            if np.random.rand() > epsilon:
                pact = Q(np.array(pobs, dtype=np.float32).reshape(1, -1))
                pact = np.argmax(pact.data)

            # act
            obs, reward, done = env.step(pact)

            # add memory (drop the oldest transition once over capacity)
            memory.append((pobs, pact, reward, obs, done))
            if len(memory) > memory_size:
                memory.pop(0)

            # train or update q (only once the replay buffer is full)
            if len(memory) == memory_size:
                if total_step % train_freq == 0:
                    # Shuffle indices instead of the transitions themselves:
                    # the tuples mix sequences and scalars, and turning such
                    # ragged data into an ndarray fails on current NumPy.
                    order = np.random.permutation(len(memory))
                    for i in range(0, len(memory), batch_size):
                        batch = [memory[k] for k in order[i:i + batch_size]]
                        s_pobs, s_pact, s_reward, s_obs, s_done = zip(*batch)
                        b_pobs = np.array(s_pobs, dtype=np.float32).reshape(batch_size, -1)
                        b_pact = np.array(s_pact, dtype=np.int32)
                        # NOTE(review): int32 truncates fractional rewards;
                        # presumably env rewards are integers - confirm.
                        b_reward = np.array(s_reward, dtype=np.int32)
                        b_obs = np.array(s_obs, dtype=np.float32).reshape(batch_size, -1)
                        # Builtin bool: the np.bool alias no longer exists
                        # in current NumPy releases.
                        b_done = np.array(s_done, dtype=bool)

                        q = Q(b_pobs)

                        # NOTE(review): the selecting argmax is taken over
                        # Q(b_pobs); textbook Double DQN selects from the
                        # online network evaluated on the next observation
                        # (b_obs). Kept as written - confirm intent.
                        indices = np.argmax(q.data, axis=1)
                        maxqs = Q_ast(b_obs).data
                        target = copy.deepcopy(q.data)
                        for j in range(batch_size):
                            # TD target for the action actually taken; the
                            # bootstrap term is dropped on terminal steps.
                            target[j, b_pact[j]] = (
                                b_reward[j] + gamma * maxqs[j, indices[j]] * (not b_done[j]))
                        Q.reset()
                        loss = F.mean_squared_error(q, target)
                        total_loss += loss.data
                        loss.backward()
                        optimizer.update()

                if total_step % update_q_freq == 0:
                    Q_ast = copy.deepcopy(Q)

            # epsilon
            if epsilon > epsilon_min and total_step > start_reduce_epsilon:
                epsilon -= epsilon_decrease

            # next step
            total_reward += reward
            pobs = obs
            step += 1
            total_step += 1

        total_rewards.append(total_reward)
        total_losses.append(total_loss)

        if (epoch + 1) % show_log_freq == 0:
            # Average reward/loss over the last show_log_freq epochs.
            log_reward = sum(total_rewards[((epoch + 1) - show_log_freq):]) / show_log_freq
            log_loss = sum(total_losses[((epoch + 1) - show_log_freq):]) / show_log_freq
            elapsed_time = time.time() - start
            print('\t'.join(map(str, [epoch + 1, epsilon, total_step, log_reward, log_loss, elapsed_time])))
            start = time.time()

    return Q, total_losses, total_rewards


# Train on an Environment1 built from `train`; keep the returned network
# together with the per-epoch loss and reward histories.
Q, total_losses, total_rewards = train_dddqn(Environment1(train))

我的问题是:如何保存和加载这个已经训练好的模型?我知道 Keras 提供了相应的功能,例如 model.save 和 load_model。

那么在 Chainer 中,具体需要怎样的代码来实现保存和加载?

最佳答案

您可以使用 chainer 的 serializers 模块来保存和加载模型(Chain 类)的参数。

from chainer import serializers

# Build two networks with identical constructor arguments so that the
# parameters saved from one can be loaded into the other.
Q = Q_Network(input_size=env.history_t + 1, hidden_size=100, output_size=3)
Q_ast = Q_Network(input_size=env.history_t + 1, hidden_size=100, output_size=3)

# --- train Q here... ---

# Copy Q's parameters into Q_ast: save them to the file 'my.model',
# then load that file into Q_ast.
serializers.save_npz('my.model', Q)
serializers.load_npz('my.model', Q_ast)

详情请参见 Chainer 官方文档中关于 serializers 的章节。

此外,您还可以参考 chainerrl,这是一个基于 chainer 的强化学习库。

chainerrl 提供了一个实用函数 copy_param,可以将网络 source_link 的参数复制到 target_link。

关于machine-learning - Chainer如何保存和加载DQN模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54053848/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com