gpt4 book ai didi

python - PyTorch 不能 pickle lambda

转载 作者:行者123 更新时间:2023-12-05 03:31:43 25 4
gpt4 key购买 nike

我有一个使用自定义 LambdaLayer 的模型,如下所示:

class LambdaLayer(LightningModule):
def __init__(self, fun):
super(LambdaLayer, self).__init__()
self.fun = fun

def forward(self, x):
return self.fun(x)


class TorchCatEmbedding(LightningModule):
def __init__(self, start, end):
super(TorchCatEmbedding, self).__init__()
self.lb = LambdaLayer(lambda x: x[:, start:end])
self.embedding = torch.nn.Embedding(50, 5)

def forward(self, inputs):
o = self.lb(inputs).to(torch.int32)
o = self.embedding(o)
return o.squeeze()

该模型在 CPU 或 1 个 GPU 上运行完美。但是,当使用 PyTorch Lightning 在 2+ GPU 上运行它时,会发生此错误:

AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'

这里使用 lambda 函数的目的是给定一个 inputs 张量,我只想将 inputs[:, start:end] 传递给 嵌入层。

我的问题:

  • 在这种情况下是否有替代方法来使用 lambda?
  • 如果不是,应该怎么做才能让 lambda 函数在这种情况下工作?

最佳答案

所以问题不在于 lambda 函数本身,而是 pickle 不适用于不仅仅是模块级函数的函数(pickle 处理函数的方式就像对某些模块级名称的引用) .所以,不幸的是,如果你需要捕获 startend 参数,你将无法使用闭包,你通常只需要像这样的东西:

def function_maker(start, end):
def function(x):
return x[:, start:end]
return function

但是就 pickle 问题而言,这会让您回到起点。

所以,尝试这样的事情:

class Slicer:
def __init__(self, start, end):
self.start = start
self.end = end
def __call__(self, x):
return x[:, self.start:self.end])

然后你可以使用:

LambdaLayer(Slicer(start, end))

我不熟悉 PyTorch,我很惊讶它不提供使用不同序列化后端的能力。悲伤/dill例如,project 可以 pickle 任意函数,而且通常更容易使用它。但我相信以上应该可以解决问题。

关于python - PyTorch 不能 pickle lambda,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70608810/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com