gpt4 book ai didi

python - 谷歌 JAX 一维卷积神经网络

转载 作者:行者123 更新时间:2023-12-05 06:15:43 27 4
gpt4 key购买 nike

我正在尝试使用 stax.GeneralConv() ( https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv ) 在 Google Jax 中实现一维卷积神经网络。我有一个包含 18 个条目的一维输入数组和包含 6 个条目的输出数组。我想实现一个内核宽度为 3 的 CNN,如下所示:

init_random_params, conv_net = stax.serial(
GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
LogSoftmax,
Dense(6),
)

初始网络参数:

rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))

但是我得到以下错误:

stax.py", line 75, in <listcomp>
next(filter_shape_iter) for c in rhs_spec]

IndexError: tuple index out of range

stax 要求维数 rhs_spec 至少有 2 个字符长,但我使用的是一维过滤器。有人知道如何解决这个问题吗?

最佳答案

我自己还没有尝试过,但我希望一维卷积仍然需要一个方向来进行卷积,例如

Conv2d = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1d = functools.partial(GeneralConv, ('NHC', 'HIO', 'NHC'))

换句话说,放下 W 轴以从 2d 到 1d 卷积。

NHC对应的输入shape为(batch_size, sequence_length, num_channels)

请注意,即使 channel 数可能为 1,您仍然需要包括该轴,因为 GeneralConv 会沿着 num_channels = input_shape['NHC'.index] 行进行索引查找('C')].

关于python - 谷歌 JAX 一维卷积神经网络,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62357214/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com