gpt4 book ai didi

python - theano.scan() 如何处理空序列?

转载 作者:太空宇宙 更新时间:2023-11-04 05:39:37 25 4
gpt4 key购买 nike

代码片段如下:

t = tensor.arange(1, K)
results, updates = theano.scan(fn=updatefunc, sequences=t, ...)

扫描过程将沿着 t 迭代。然而,当 K<=1 时,t 将是一个空范围,然后 theano.scan() 将崩溃。有什么办法可以解决这个问题吗?

最佳答案

您可以使用 theano.ifelse.ifelse 仅当序列中包含某些元素时才计算扫描。例如:

import theano
import theano.tensor as tt
import theano.ifelse


def step(x_t, s_tm1):
return s_tm1 + x_t


def compile():
K = tt.lscalar()
t = tt.arange(1, K)
zero = tt.constant(0, dtype='int64')
outputs, _ = theano.scan(step, sequences=[t], outputs_info=[zero])
output = theano.ifelse.ifelse(tt.gt(K, 1), outputs[-1], zero)
return theano.function([K], outputs=[output])


def main():
f = compile()
print f(3)
print f(2)
print f(1)
print f(0)
print f(-1)


main()

打印

[array(3L, dtype=int64)]
[array(1L, dtype=int64)]
[array(0L, dtype=int64)]
[array(0L, dtype=int64)]
[array(0L, dtype=int64)]

关于python - theano.scan() 如何处理空序列?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34413575/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com