gpt4 book ai didi

python - 与 pyTorch 相比,Jax/Flax(非常)慢的 RNN 前向传递?

转载 作者:行者123 更新时间:2023-12-05 05:54:00 26 4
gpt4 key购买 nike

我最近在 Jax 中实现了一个双层 GRU 网络,对其性能感到失望(无法使用)。

因此,我尝试与 Pytorch 进行了一些速度比较。

最小工作示例

这是我的最小工作示例,输出是使用 GPU 运行时在 Google Colab 上创建的。 notebook in colab

import flax.linen as jnn 
import jax
import torch
import torch.nn as tnn
import numpy as np
import jax.numpy as jnp

def keyGen(seed):
key1 = jax.random.PRNGKey(seed)
while True:
key1, key2 = jax.random.split(key1)
yield key2
key = keyGen(1)

hidden_size=200
seq_length = 1000
in_features = 6
out_features = 4
batch_size = 8

class RNN_jax(jnn.Module):

@jnn.compact
def __call__(self, x, carry_gru1, carry_gru2):
carry_gru1, x = jnn.GRUCell()(carry_gru1, x)
carry_gru2, x = jnn.GRUCell()(carry_gru2, x)
x = jnn.Dense(4)(x)
x = x/jnp.linalg.norm(x)
return x, carry_gru1, carry_gru2

class RNN_torch(tnn.Module):
def __init__(self, batch_size, hidden_size, in_features, out_features):
super().__init__()

self.gru = tnn.GRU(
input_size=in_features,
hidden_size=hidden_size,
num_layers=2
)

self.dense = tnn.Linear(hidden_size, out_features)

self.init_carry = torch.zeros((2, batch_size, hidden_size))

def forward(self, X):
X, final_carry = self.gru(X, self.init_carry)
X = self.dense(X)
return X/X.norm(dim=-1).unsqueeze(-1).repeat((1, 1, 4))

rnn_jax = RNN_jax()
rnn_torch = RNN_torch(batch_size, hidden_size, in_features, out_features)

Xj = jax.random.normal(next(key), (seq_length, batch_size, in_features))
Yj = jax.random.normal(next(key), (seq_length, batch_size, out_features))
Xt = torch.from_numpy(np.array(Xj))
Yt = torch.from_numpy(np.array(Yj))

initial_carry_gru1 = jnp.zeros((batch_size, hidden_size))
initial_carry_gru2 = jnp.zeros((batch_size, hidden_size))

params = rnn_jax.init(next(key), Xj[0], initial_carry_gru1, initial_carry_gru2)

def forward(params, X):

carry_gru1, carry_gru2 = initial_carry_gru1, initial_carry_gru2

Yhat = []
for x in X: # x.shape = (batch_size, in_features)
yhat, carry_gru1, carry_gru2 = rnn_jax.apply(params, x, carry_gru1, carry_gru2)
Yhat.append(yhat) # y.shape = (batch_size, out_features)

#return jnp.concatenate(Y, axis=0)

jitted_forward = jax.jit(forward)

结果
# uncompiled jax version
%time forward(params, Xj)

CPU 时间:用户 7 分钟 17 秒,系统:8.18 秒,总计:7 分钟 25 秒 墙时间:7 分钟 17 秒

# time for compiling
%time jitted_forward(params, Xj)

CPU 时间:用户 8 分钟 9 秒,系统:4.46 秒,总计:8 分钟 13 秒 墙时间:8 分钟 12 秒

# compiled jax version
%timeit jitted_forward(params, Xj)

最慢的运行时间比最快的运行时间长 204.20 倍。这可能意味着正在缓存中间结果。 10000 次循环,5 次循环中的最佳循环:每次循环 115 微秒

# torch version
%timeit lambda: rnn_torch(Xt)

10000000 次循环,5 次最佳:每次循环 65.7 ns

问题

为什么我的 Jax 实现如此缓慢?我做错了什么?

另外,为什么编译要花这么长时间?序列没那么长..

谢谢你:)

最佳答案

JAX 代码编译缓慢的原因是在 JIT 编译期间 JAX 展开了循环。因此,就 XLA 编译而言,您的函数实际上非常大:您调用 rnn_jax.apply() 1000 次,并且编译时间往往大致是语句数量的二次方。

相比之下,您的 pytorch 函数不使用 Python 循环,因此在幕后它依赖于运行速度更快的矢量化操作。

任何时候你在 Python 中使用 for 循环数据,一个很好的赌注是你的代码会很慢:无论你使用的是 JAX、torch、numpy、pandas 等,都是如此. 我建议在 JAX 中找到一种解决问题的方法,该方法依赖于矢量化操作而不是依赖于缓慢的 Python 循环。

关于python - 与 pyTorch 相比,Jax/Flax(非常)慢的 RNN 前向传递?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69767707/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com