gpt4 book ai didi

python - 从出队的一批值、索引、形状创建 SparseTensor

转载 作者:行者123 更新时间:2023-12-01 03:19:06 24 4
gpt4 key购买 nike

我正在尝试将张量(其中一些是稀疏的)从 RAM 提供给模型。我创建了一个 PaddingFIFOQueue,假设稀疏值无法通过其他方法从 RAM 中批量处理(如果情况并非如此,请告诉我),我将稀疏张量的索引、值和形状分别排队到其中。由于序列的长度不同,因此需要对其进行填充。

我正在将以下内容出列...

indices = [batch size, None, 2]
values = [batch size, None]
shapes = [batch size, 2]

我尝试使用这些值来创建 SparseTensor,但收到以下错误。

ValueError: Shape (512, ?, 2) must have rank 2

主要部分代码如下...

indices, values, shapes = self.queue.dequeue_many(batch_size)
sp_tensor = tf.SparseTensor(indices, values, shapes)

我认为这是因为 SparseTensor 需要一个 2 级张量,而不是一批 2 级张量(如错误消息所示),但我不确定如何转换该批处理。

最佳答案

这可以通过一些平铺和 reshape 来实现:

import tensorflow as tf

def sparse_tensor_merge(indices, values, shape):
"""Creates a SparseTensor from batched indices, values, and shapes.

Args:
indices: A [batch_size, N, D] integer Tensor.
values: A [batch_size, N] Tensor of any dtype.
shape: A [batch_size, D] Integer Tensor.
Returns:
A SparseTensor of dimension D + 1 with batch_size as its first dimension.
"""
merged_shape = tf.reduce_max(shape, axis=0)
batch_size, elements, shape_dim = tf.unstack(tf.shape(indices))
index_range_tiled = tf.tile(tf.range(batch_size)[..., None],
tf.stack([1, elements]))[..., None]
merged_indices = tf.reshape(
tf.concat([tf.cast(index_range_tiled, tf.int64), indices], axis=2),
[-1, 1 + tf.size(merged_shape)])
merged_values = tf.reshape(values, [-1])
return tf.SparseTensor(
merged_indices, merged_values,
tf.concat([[tf.cast(batch_size, tf.int64)], merged_shape], axis=0))

例如:

batch_indices = tf.constant(
[[[0, 0], [0, 1]],
[[0, 0], [1, 1]]], dtype=tf.int64)
batch_values = tf.constant(
[[0.1, 0.2],
[0.3, 0.4]])
batch_shapes = tf.constant(
[[2, 2],
[3, 2]], dtype=tf.int64)

merged = sparse_tensor_merge(batch_indices, batch_values, batch_shapes)

with tf.Session():
print(merged.eval())

打印:

SparseTensorValue(indices=array([[0, 0, 0],
[0, 0, 1],
[1, 0, 0],
[1, 1, 1]]),
values=array([ 0.1 , 0.2 , 0.30000001, 0.40000001],
dtype=float32),
dense_shape=array([2, 3, 2]))

请注意,组合后的 SparseTensor 的形状是原始批处理维度,后跟该批处理中每个其他维度的最大值。

关于python - 从出队的一批值、索引、形状创建 SparseTensor,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42147362/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com