gpt4 book ai didi

pytorch - 如何使用 PyTorch DataLoader 进行强化学习?

转载 作者:行者123 更新时间:2023-12-04 15:42:40 27 4
gpt4 key购买 nike

我正在尝试在 PyTorch 中建立一个通用的强化学习框架,以利用所有利用 PyTorch DataSet 和 DataLoader 的高级实用程序,例如 Ignite 或 FastAI,但我遇到了具有动态特性的阻碍强化学习数据:

  • 数据项是从代码生成的,而不是从文件中读取的,它们依赖于之前的操作和模型结果,因此每个 nextItem 调用都需要访问模型状态。
  • 训练集的长度不是固定的,所以我需要一个动态的批量大小以及一个动态的总数据集大小。我更喜欢使用终止条件函数而不是数字。我可以“可能”用填充来做到这一点,就像在 NLP 句子处理中一样,但这是一个真正的黑客。

  • 到目前为止,我的 Google 和 StackOverflow 搜索已经产生了 zilch。这里有人知道将 DataLoader 或 DataSet 与强化学习结合使用的现有解决方案或变通方法吗?我讨厌失去对所有依赖于这些库的现有库的访问。

    最佳答案

    Here是一个基于 PyTorch 的框架和 here是来自 Facebook 的东西。

    当谈到您的问题时(毫无疑问,这是崇高的追求):

    您可以轻松创建 torch.utils.data.Dataset依赖于任何东西,包括模型,像这样的东西(原谅弱抽象,这只是为了证明一点):

    import typing

    import torch
    from torch.utils.data import Dataset


    class Environment(Dataset):
    def __init__(self, initial_state, actor: torch.nn.Module, max_interactions: int):
    self.current_state = initial_state
    self.actor: torch.nn.Module = actor
    self.max_interactions: int = max_interactions

    # Just ignore the index
    def __getitem__(self, _):
    self.current_state = self.actor.update(self.current_state)
    return self.current_state.get_data()

    def __len__(self):
    return self.max_interactions

    假设, torch.nn.Module - 类似网络有某种 update环境状态的变化。总而言之,它只是一个 Python 结构,所以你可以用它来建模很多东西。

    您可以指定 max_interactions差不多 infinite或者,如果需要,您可以在训练期间通过一些回调即时更改它(因为 __len__ 可能会在整个代码中多次调用)。环境还可以提供 batches而不是 sample 。

    torch.utils.data.DataLoader batch_sampler参数,在那里你可以生成不同长度的批次。由于网络不依赖于第一维,您也可以从那里返回任何您想要的批量大小。

    顺便提一句。如果每个样本的长度不同,则应使用填充,不同的批次大小与此无关。

    关于pytorch - 如何使用 PyTorch DataLoader 进行强化学习?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57258323/

    27 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com