gpt4 book ai didi

python - 为什么 Pytorch 需要 DoubleTensor 而不是 FloatTensor?

转载 作者:太空宇宙 更新时间:2023-11-04 08:30:55 24 4
gpt4 key购买 nike

从我在网上看到的所有内容来看,FloatTensors 是 Pytorch 的默认值,当我创建一个张量传递给我的生成器模块时,它是一个 FloatTensor,但是当我尝试通过一个线性层运行它,它提示它需要一个 DoubleTensor

class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False)

def forward(self, zvec):
print(zvec.size())
fc = self.fully_connected(zvec)
return(fc.size())

gen = Generator();

gen(torch.from_numpy(np.random.normal(size=100)))

产生

RuntimeError: Expected object of type torch.DoubleTensor but found type torch.FloatTensor for argument #2 'mat2'

最佳答案

这里的问题是您的 numpy 输入使用 double 作为数据类型,相同的数据类型也应用于生成的张量。

另一方面,self.fully_connected 层的权重float。当通过层馈送数据时,应用矩阵乘法,这种乘法要求两个矩阵具有相同的数据类型。

所以你有两个解决方案:

  • 您可以将输入转换为 float :

通过改变:

gen(torch.from_numpy(np.random.normal(size=100)))

致:

gen(torch.from_numpy(np.random.normal(size=100)).float())

输入到 gen 的输入将被转换为 float

转换输入的完整工作代码:

from torch import nn
import torch
import numpy as np
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False)

def forward(self, zvec):
print(zvec.size())
fc = self.fully_connected(zvec)
return(fc.size())

gen = Generator();
gen(torch.from_numpy(np.random.normal(size=100)).float()) # converting network input to float

  • 或者您可以将图层权重转换为两倍:

如果您需要 double ,您还可以将权重转换为double

更改此行:

self.fully_connected = nn.Linear(100, 1024*4*4, bias=False)

只是为了:

self.fully_connected = nn.Linear(100, 1024*4*4, bias=False).double()

转换权重的完整工作代码:

from torch import nn
import torch
import numpy as np
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False).double() # converting layer weights to double()

def forward(self, zvec):
print(zvec.size())
fc = self.fully_connected(zvec)
return(fc.size())

gen = Generator();
gen(torch.from_numpy(np.random.normal(size=100)))

所以这两种方式都适合你,但如果你不需要 double 的额外精度,你应该使用 float as double 需要更多的计算能力。

关于python - 为什么 Pytorch 需要 DoubleTensor 而不是 FloatTensor?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53055101/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com