gpt4 book ai didi

python - `uniq` 用于 2D Theano 张量

转载 作者:行者123 更新时间:2023-11-28 19:17:47 24 4
gpt4 key购买 nike

我有这个 Numpy 代码:

def uniq(seq):
"""
Like Unix tool uniq. Removes repeated entries.
:param seq: numpy.array. (time,) -> label
:return: seq
"""
diffs = np.ones_like(seq)
diffs[1:] = seq[1:] - seq[:-1]
idx = diffs.nonzero()
return seq[idx]

现在,我想扩展它以支持二维数组并使其使用 Theano。它在 GPU 上应该很快。

我将获得一个包含多个序列的数组,格式为 (time,batch) 的多个批处理,以及一个间接指定每个序列长度的 time_mask

我目前的尝试:

def uniq_with_lengths(seq, time_mask):
# seq is (time,batch) -> label
# time_mask is (time,batch) -> 0 or 1
num_batches = seq.shape[1]
diffs = T.ones_like(seq)
diffs = T.set_subtensor(diffs[1:], seq[1:] - seq[:-1])
time_range = T.arange(seq.shape[0]).dimshuffle([0] + ['x'] * (seq.ndim - 1))
idx = T.switch(T.neq(diffs, 0) * time_mask, time_range, -1)
seq_lens = T.sum(T.ge(idx, 0), axis=0) # (batch,) -> len
max_seq_len = T.max(seq_lens)

# I don't know any better way without scan.
def step(batch_idx, out_seq_b1):
out_seq = seq[T.ge(idx[:, batch_idx], 0).nonzero(), batch_idx][0]
return T.concatenate((out_seq, T.zeros((max_seq_len - out_seq.shape[0],), dtype=seq.dtype)))

out_seqs, _ = theano.scan(
step,
sequences=[T.arange(num_batches)],
outputs_info=[T.zeros((max_seq_len,), dtype=seq.dtype)]
)
# out_seqs is (batch,max_seq_len)
return out_seqs.T, seq_lens

如何直接构造out_seqs

我会做类似 out_seqs = seq[idx] 的事情,但我不确定如何表达。

最佳答案

这是一个仅解决部分任务的快速答案:

def compile_theano_uniq(x):
diffs = x[1:] - x[:-1]
diffs = tt.concatenate([tt.ones_like([x[0]], dtype=diffs.dtype), diffs])
y = diffs.nonzero_values()
return theano.function(inputs=[x], outputs=y)

theano_uniq = compile_theano_uniq(tt.vector(dtype='int32'))

关键是 nonzero_values()

更新:我无法想象不使用 theano.scan 有什么方法可以做到这一点。明确地说,使用 0 作为填充,我假设给定输入

1 1 2 3 3 4 0
1 2 2 2 3 3 4
1 2 3 4 5 0 0

你希望输出是

1 2 3 4 0 0 0
1 2 3 4 0 0 0
1 2 3 4 5 0 0

甚至

1 2 3 4 0
1 2 3 4 0
1 2 3 4 5

您可以在不使用扫描的情况下识别要保留的项目的索引。然后要么需要从头开始构造一个新的张量,要么要保留一些值以移动以使序列连续。如果没有 theano.scan,这两种方法似乎都不可行。

关于python - `uniq` 用于 2D Theano 张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31379971/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com