gpt4 book ai didi

tensorflow - 如何在 tensorflow 中为张量的子集赋值?

转载 作者:行者123 更新时间:2023-12-03 00:41:46 24 4
gpt4 key购买 nike

这个问题分为两个部分:

(1) 在 tensorflow 中更新张量子集的最佳方法是什么?我看到了几个相关的问题:

Adjust Single Value within Tensor -- TensorFlowHow to update a subset of 2D tensor in Tensorflow?

而且我知道可以使用 Variable.assign() (和/或 scatter_update 等)分配 Variable 对象,但对我来说,tensorflow 没有更直观的方式来更新部分似乎很奇怪张量对象的。我已经搜索了tensorflow api 文档和stackoverflow很长一段时间了,似乎找不到比上面链接中提供的解决方案更简单的解决方案。这看起来特别奇怪,特别是考虑到 Theano 有与 Tensor.set_subtensor() 等效的版本。我是否遗漏了一些东西,或者目前没有简单的方法可以通过tensorflow api来做到这一点?

(2)如果有更简单的方法,它是否可微?

谢谢!

最佳答案

我认为张量的不变性是构建计算图所必需的;你不能让一个张量更新它的一些值而不成为另一个张量,否则在它之前的图中将没有任何内容。 The same issue comes up in Autograd .

可以使用 bool 掩码来做到这一点(但很难看)(使它们成为变量并使用 assign ,或者甚至在 numpy 中预先定义它们)。这将是可微分的,但实际上我会避免更新子张量。

如果你真的必须这样做,我真的希望有更好的方法来做到这一点,但这里有一种使用 tf.dynamic_stitch 以一维方式做到这一点的方法。和tf.setdiff1d :

def set_subtensor1d(a, b, slice_a, slice_b):
# a[slice_a] = b[slice_b]
a_range = tf.range(a.shape[0])
_, a_from = tf.setdiff1d(a_range, a_range[slice_a])
a_to = a_from
b_from, b_to = tf.range(b.shape[0])[slice_b], a_range[slice_a]
return tf.dynamic_stitch([a_to, b_to],
[tf.gather(a, a_from),tf.gather(b, b_from)])

对于更高的维度,这可以通过滥用 reshape 来推广。 (其中 nd_slice could be implemented like this 但可能有更好的方法):

def set_subtensornd(a, b, slice_tuple_a, slice_tuple_b):
# a[*slice_tuple_a] = b[*slice_tuple_b]
a_range = tf.range(tf.reduce_prod(tf.shape(a)))
a_idxed = tf.reshape(a_range, tf.shape(a))
a_dropped = tf.reshape(nd_slice(a_idxed, slice_tuple_a), [-1])
_, a_from = tf.setdiff1d(a_range, a_dropped)
a_to = a_from
b_range = tf.range(tf.reduce_prod(tf.shape(b)))
b_idxed = tf.reshape(b_range, tf.shape(b))
b_from = tf.reshape(nd_slice(b_idxed, slice_tuple_b), [-1])
b_to = a_dropped
a_flat, b_flat = tf.reshape(a, [-1]), tf.reshape(b, [-1])
stitched = tf.dynamic_stitch([a_to, b_to],
[tf.gather(a_flat, a_from),tf.gather(b_flat, b_from)])
return tf.reshape(stitched, tf.shape(a))

我不知道这会有多慢。我猜会很慢。而且,除了在几个张量上运行它之外,我还没有对它进行太多测试。

关于tensorflow - 如何在 tensorflow 中为张量的子集赋值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40295700/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com