gpt4 book ai didi

python - 在 Keras 中定义自定义 LSTM 单元?

转载 作者:行者123 更新时间:2023-12-04 23:36:14 27 4
gpt4 key购买 nike

我使用 Keras 和 TensorFlow 作为后端。如果我想对 LSTM 单元进行修改,例如“移除”输出门,我该怎么做?它是一个乘法门,所以我必须以某种方式将它设置为固定值,这样无论乘以什么,都无效。

最佳答案

首先,您应该定义您的 own custom layer .如果您需要一些直觉如何实现您自己的单元,请参阅 LSTMCell在 Keras 存储库中。例如。您的自定义单元格将是:

class MinimalRNNCell(keras.layers.Layer):

def __init__(self, units, **kwargs):
self.units = units
self.state_size = units
super(MinimalRNNCell, self).__init__(**kwargs)

def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True

def call(self, inputs, states):
prev_output = states[0]
h = K.dot(inputs, self.kernel)
output = h + K.dot(prev_output, self.recurrent_kernel)
return output, [output]

然后,使用 tf.keras.layers.RNN 使用您的手机:
cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)

# Here's how to use the cell to build a stacked RNN:

cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras.Input((None, 5))
layer = RNN(cells)
y = layer(x)

关于python - 在 Keras 中定义自定义 LSTM 单元?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54231440/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com