gpt4 book ai didi

jax - 在 jax 中进行嵌入的推荐方法是什么?

转载 作者:行者123 更新时间:2023-12-02 18:08:38 32 4
gpt4 key购买 nike

所以我的意思是你有一个分类特征 $X$(假设你已经把它变成了整数)并说你想使用特征 $A$ 将它嵌入到某个维度中,其中 $A$ 是 arity x n_embed .

通常的做法是什么?使用 for 循环和 vmap 是否正确?我不想要像 jax.nn 这样的东西,像

这样更高效的东西

https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding

例如考虑高 arity 和低嵌入 dim。

这里的 flax.linen 实现中是 jnp.take 吗? https://github.com/google/flax/blob/main/flax/linen/linear.py#L624

最佳答案

事实上,在纯 jax 中执行此操作的典型方法是使用 jnp.take。给定形状为 (num_embeddings, num_features) 的嵌入数组 A 和形状为 (n,) 的整数的分类特征 x > 然后下面为您提供嵌入查找。

jnp.take(A, x, axis=0)  # shape: (n, num_features)

如果使用 Flax,那么推荐的方法是使用 flax.linen.Embed模块,并会达到同样的效果:

import flax.linen as nn

class Model(nn.Module):
@nn.compact
def __call__(self, x):
emb = nn.Embed(num_embeddings, num_features)(x) # shape

关于jax - 在 jax 中进行嵌入的推荐方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72817730/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com