gpt4 book ai didi

python - 训练 keras 模型时维度如何工作?

转载 作者:太空宇宙 更新时间:2023-11-03 13:55:11 24 4
gpt4 key购买 nike

获得:

    assert q_values.shape == (len(state_batch), self.nb_actions)
AssertionError
q_values.shape <class 'tuple'>: (1, 1, 10)
(len(state_batch), self.nb_actions) <class 'tuple'>: (1, 10)

来自 sarsa 代理的 keras-rl 库:

rl.agents.sarsa.SARSAAgent#compute_batch_q_values

    batch = self.process_state_batch(state_batch)
q_values = self.model.predict_on_batch(batch)
assert q_values.shape == (len(state_batch), self.nb_actions)

这是我的代码:

class MyEnv(Env):

def __init__(self):
self._reset()

def _reset(self) -> None:
self.i = 0

def _get_obs(self) -> List[float]:
return [1] * 20

def reset(self) -> List[float]:
self._reset()
return self._get_obs()



model = Sequential()
model.add(Dense(units=20, activation='relu', input_shape=(1, 20)))
model.add(Dense(units=10, activation='softmax'))
logger.info(model.summary())

policy = BoltzmannQPolicy()
agent = SARSAAgent(model=model, nb_actions=10, policy=policy)

optimizer = Adam(lr=1e-3)
agent.compile(optimizer, metrics=['mae'])

env = MyEnv()
agent.fit(env, 1, verbose=2, visualize=True)

想知道是否有人可以向我解释应该如何设置维度以及它如何与库一起使用?我输入了一个包含 20 个输入的列表,并希望得到 10 个输出。

最佳答案

此特定错误是由您的输入形状 (1, 20) 引起的。如果您使用 (20,) 的输入形状,错误将会消失。

换句话说,SARSAAgent 需要一个输出二维张量(batch_size、nb_actions)的模型。您的模型正在输出 (batch_size, 1, 10) 的形状。您可以减少模型输入的维度或展平输出。

关于python - 训练 keras 模型时维度如何工作?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57104436/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com