gpt4 book ai didi

python - 在 TensorFlow 中,如何沿着参差不齐的维度索引参差不齐的张量?

转载 作者:太空宇宙 更新时间:2023-11-04 04:16:12 28 4
gpt4 key购买 nike

我需要通过沿着参差不齐的维度进行索引来获取参差不齐的张量中的值。一些索引工作([:, :x][:, -x:][:, x:y]),但是不是直接索引([:, x]):

R = tf.RaggedTensor.from_tensor([[1, 2, 3], [4, 5, 6]])
print(R[:, :2]) # RaggedTensor([[1, 2], [4, 5]])
print(R[:, 1:2]) # RaggedTensor([[2], [5]])
print(R[:, 1]) # ValueError: Cannot index into an inner ragged dimension.

documentation解释为什么会失败:

RaggedTensors supports multidimensional indexing and slicing, with one restriction: indexing into a ragged dimension is not allowed. This case is problematic because the indicated value may exist in some rows but not others. In such cases, it's not obvious whether we should (1) raise an IndexError; (2) use a default value; or (3) skip that value and return a tensor with fewer rows than we started with. Following the guiding principles of Python ("In the face of ambiguity, refuse the temptation to guess" ), we currently disallow this operation.

这是有道理的,但我如何实际实现选项 1、2 和 3?我必须将参差不齐的数组转换为张量的 Python 数组,然后手动迭代它们吗?有没有更有效的解决方案?一种无需通过 Python 解释器即可在 TensorFlow 图中 100% 工作的方法?

最佳答案

如果你有一个 2D RaggedTensor,那么你可以通过以下方式获得行为 (3):

def get_column_slice_v3(rt, column):
assert column >= 0 # Negative column index not supported
slice = rt[:, column:column+1]
return slice.flat_values

您可以通过添加 rt.nrows() == tf.size(slice.flat_values) 的断言来获得行为 (1):

def get_column_slice_v1(rt, column):
assert column >= 0 # Negative column index not supported
slice = rt[:, column:column+1]
with tf.assert_equal(rt.nrows(), tf.size(slice.flat_values):
return tf.identity(slice.flat_values)

要获得行为 (2),我认为最简单的方法可能是连接一个默认值向量,然后再次切片:

def get_colum_slice_v2(rt, column, default=None):
assert column >= 0 # Negative column index not supported
slice = rt[:, column:column+1]
if default is None:
defaults = tf.zeros([slice.nrows(), 1], slice.dtype)
ele:
defaults = tf.fill([slice.nrows(), 1], default)
slice_plus_default = tf.concat([rt, defaults], axis=1)
slice2 = slice_plus_defaults[:1]
return slice2.flat_values

可以扩展它们以支持高维参差不齐的张量,但逻辑会变得有点复杂。还应该可以扩展它们以支持负列索引。

关于python - 在 TensorFlow 中,如何沿着参差不齐的维度索引参差不齐的张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55368272/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com