gpt4 book ai didi

python - 嵌套 tf.map_fn 性能缓慢

转载 作者:行者123 更新时间:2023-12-01 06:42:13 31 4
gpt4 key购买 nike

我运行下面的代码,以便从给定的索引矩阵获取填充矩阵(words_chars_ids 形状为 (6,200,20))。结果的形状为 (6,200,20,emb_size),其中对于输出中的每个条目,它都保存一个由 1 或 0 组成的张量(大小为 emb_size)。

我有两个问题:

  1. 有没有更优雅的方式来实现这个方法(没有嵌套的map_fn)

  2. 性能似乎相当慢 - 是否有更有效的方法来实现结果?

def get_padding_mask(words_chars_ids, emb_size):

padding_mask = tf.map_fn(
lambda x: tf.map_fn(
lambda y: tf.map_fn(
lambda z: tf.cond(tf.less(z, 1),
lambda: tf.zeros([emb_size, ], dtype=tf.int32),
lambda: tf.ones([emb_size, ], dtype=tf.int32)
),
y),
x),
words_chars_ids)
return padding_mask

最佳答案

您可以简单地执行相同的操作:

def get_padding_mask(words_chars_ids, emb_size):
mask = tf.dtypes.cast(words_chars_ids >= 1, tf.int32)
return tf.tile(tf.expand_dims(mask, -1), [1, 1, 1, emb_size])

关于python - 嵌套 tf.map_fn 性能缓慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59395593/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com