gpt4 book ai didi

python - 将非符号张量传递给 Keras Lambda 层

转载 作者:太空宇宙 更新时间:2023-11-03 20:23:30 26 4
gpt4 key购买 nike

我正在尝试将 RNNCell 对象传递给 Keras lambda 层,以便我可以在 Keras 模型中使用 Tensorflow 层,如下所示。

conv_cell = ConvGRUCell(shape = [14, 14],
filters = 32,
kernel = [3,3],
padding = 'SAME')

def convGRU(inputs, cell, length):
output, final = tf.nn.bidirectional_dynamic_rnn(
cell, cell, x, length, dtype=tf.float32)
output = tf.concat(output, -1)
final = tf.concat(final, -1)
return [output, final]

lm = Lambda(lambda x: convGRU(x[0], x[1], x[2])([input, conv_cell, length])

但是,我收到一个错误,指出 conv_cell 不是符号张量(它是基于 Tensorflow 的 GRUCell 的自定义层)。

有什么方法可以将单元格传递到 lambda 层吗?我让它与 functools.partial 一起使用,但它无法保存/加载模型,因为它无法访问模型内​​的函数。

最佳答案

def convGRU(cell, length): # if length is produced by the model, use it with the inputs    
def inner_func(inputs):
code...
return inner_func

lm = Lambda(convGRU(cell, length))(input)

对于保存/加载,您需要使用 custom_objects = {'convGRU': convGRU, 'cell':cell, 'length': length} 等。无论 Keras 不知道自动需要什么位于 custom_objects 中以加载已保存的模型。

关于python - 将非符号张量传递给 Keras Lambda 层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58001215/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com