gpt4 book ai didi

python - 您可以反转 PyTorch 神经网络并激活输出的输入吗?

转载 作者:行者123 更新时间:2023-12-02 04:39:30 25 4
gpt4 key购买 nike

我们可以激活神经网​​络的输出来深入了解神经元如何连接到输入特征吗?

如果我从 PyTorch 教程中获取一个基本的神经网络示例。以下是 f(x,y) 训练示例。

import torch

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.ReLU(),
torch.nn.Linear(H, D_out),
)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
y_pred = model(x)
loss = loss_fn(y_pred, y)
model.zero_grad()
loss.backward()
with torch.no_grad():
for param in model.parameters():
param -= learning_rate * param.grad

在我完成网络训练以根据 x 输入预测 y 后。是否可以反转经过训练的神经网络,以便它现在可以根据 y 输入预测 x

我不希望 y 与训练 y 输出的原始输入相匹配。因此,我希望看到模型激活哪些功能来匹配 xy

如果可能,那么如何在不破坏所有权重和连接的情况下重新排列Sequential模型?

最佳答案

这是可能的,但仅适用于非常特殊的情况。对于前馈网络(顺序),每一层都需要是可逆的;这意味着以下参数分别适用于每一层。与一层相关的变换为 y = activate(W*x + b) ,其中 W 是权重矩阵,b 是偏差向量。为了求解x,我们需要执行以下步骤:

  1. 反向激活;但并非所有激活函数都有反函数。例如,ReLU 函数在 (-inf, 0) 上没有逆函数。另一方面,如果我们使用 tanh,我们可以使用它的倒数,即 0.5 * log((1 + x)/(1 - x))
  2. 求解 W*x = inverse_activation(y) - b 得到 x;为了存在唯一的解决方案,W 必须具有相似的行和列秩,并且 det(W) 必须非零。前者我们可以通过选择特定的网络架构来控制,而后者则取决于训练过程。

因此,要使神经网络可逆,它必须具有非常具体的架构:所有层必须具有相同数量的输入和输出神经元(即平方权重矩阵),并且激活函数都需要可逆。

代码:使用 PyTorch,我们必须手动进行网络反演,无论是求解线性方程组还是找到逆激活函数。考虑以下 1 层神经网络的示例(因为这些步骤分别适用于每一层,因此将其扩展到超过 1 层是微不足道的):

import torch

N = 10 # number of samples
n = 3 # number of neurons per layer

x = torch.randn(N, n)

model = torch.nn.Sequential(
torch.nn.Linear(n, n), torch.nn.Tanh()
)

y = model(x)

z = y # use 'z' for the reverse result, start with the model's output 'y'.
for step in list(model.children())[::-1]:
if isinstance(step, torch.nn.Linear):
z = z - step.bias[None, ...]
z = z[..., None] # 'torch.solve' requires N column vectors (i.e. shape (N, n, 1)).
z = torch.solve(z, step.weight)[0]
z = torch.squeeze(z) # remove the extra dimension that we've added for 'torch.solve'.
elif isinstance(step, torch.nn.Tanh):
z = 0.5 * torch.log((1 + z) / (1 - z))

print('Agreement between x and z: ', torch.dist(x, z))

关于python - 您可以反转 PyTorch 神经网络并激活输出的输入吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59878319/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com