gpt4 book ai didi

python - 将代码从 Keras 1 转换为 Keras 2 : TypeError: __call__() missing 1 required positional argument: 'shape'

转载 作者:太空宇宙 更新时间:2023-11-03 14:10:32 24 4
gpt4 key购买 nike

我正在尝试将用 Keras 1 编写的 V-net 代码转换为 Keras 2。我似乎对以下类有问题:

class Deconv3D(Layer):
    """3-D transposed-convolution layer as originally written for Keras 1.

    This is the version from the question that raises
    ``TypeError: __call__() missing 1 required positional argument: 'shape'``
    under Keras 2; see the BUG note in ``build``.
    """
    def __init__(self, nb_filter, kernel_dims, output_shape, strides):
        # Only the TensorFlow backend is supported: call() uses tf.nn directly.
        assert K.backend() == 'tensorflow'
        self.nb_filter = nb_filter
        self.kernel_dims = kernel_dims
        # Pad the spatial strides with stride 1 on the first (batch) and last
        # (channel) axes, as tf.nn.conv3d_transpose expects 5 strides.
        # Requires `strides` to be a tuple; a list would raise TypeError here.
        self.strides = (1,) + strides + (1,)
        # Full output shape including the batch dimension, e.g. (1, 16, 16, 8, 128).
        self.output_shape_ = output_shape
        # NOTE(review): no **kwargs, so Layer options such as name= cannot be
        # forwarded to Layer.__init__.
        super(Deconv3D, self).__init__()

    def build(self, input_shape):
        # Expects a rank-5 input; input_shape[4] is used as the input channel count.
        assert len(input_shape) == 5
        self.input_shape_ = input_shape
        # Kernel layout: kernel_dims + (nb_filter, input_channels).
        W_shape = self.kernel_dims + (self.nb_filter, input_shape[4], )
        # BUG (Keras 2): this is the line that fails. add_weight evaluates
        # initializer(shape), i.e. initializers.glorot_uniform(shape), which
        # yields an initializer object rather than a tensor; TensorFlow then
        # calls that object with no arguments, hence "__call__() missing ...
        # 'shape'". Pass shape=W_shape and initializer='glorot_uniform' as
        # keyword arguments instead.
        self.W = self.add_weight(W_shape, initializer=functools.partial(initializers.glorot_uniform), name='{}_W'.format(self.name))
        # Rank-4 bias of shape (1, 1, 1, nb_filter), added to the conv output in call().
        self.b = self.add_weight((1,1,1,self.nb_filter,), initializer='zero', name='{}_b'.format(self.name))
        self.built = True

    # Keras 1 name of the output-shape hook; Keras 2 calls compute_output_shape()
    # instead, so this method is not used there.
    def get_output_shape_for(self, input_shape):
        return (None, ) + self.output_shape_[1:]

    def call(self, x, mask=None):
        # NOTE(review): tf.nn documents padding as 'SAME' / 'VALID' (upper case);
        # confirm lowercase 'same' is accepted by the TensorFlow version in use.
        return tf.nn.conv3d_transpose(x, self.W, output_shape=self.output_shape_, strides=self.strides, padding='same', name=self.name) + self.b

当我尝试使用 Deconv3D(128, (2, 2, 2), (1, 16, 16, 8, 128), (2, 2, 2))(prelu_5_1) 调用它时,我收到以下我不明白的错误:

Traceback (most recent call last):
File "V-net.py", line 118, in <module>
downsample_5 = Deconv3D(128, (2, 2, 2), (1, 16, 16, 8, 128), (2, 2, 2))(prelu_5_1) # Check the 8
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/keras/engine/topology.py", line 569, in __call__
self.build(input_shapes[0])
File "V-net.py", line 35, in build
self.W = self.add_weight(W_shape, initializer=functools.partial(initializers.glorot_uniform), name='{}_W'.format(self.name))
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
return func(*args, **kwargs)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/keras/engine/topology.py", line 391, in add_weight
weight = K.variable(initializer(shape), dtype=dtype, name=name)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 321, in variable
v = tf.Variable(value, dtype=_convert_string_dtype(dtype), name=name)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 200, in __init__
expected_shape=expected_shape)
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 278, in _init_from_args
initial_value(), name="initial_value", dtype=dtype)
TypeError: __call__() missing 1 required positional argument: 'shape'

我做错了什么?

最佳答案

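Deconv3D 类必须按照 Keras 2 的自定义层接口改写。出错的原因在于:`add_weight` 内部会执行 `initializer(shape)`,而这里的 `functools.partial(initializers.glorot_uniform)(shape)` 在 Keras 2 中返回的是一个初始化器对象而不是张量(`shape` 被当成了 `glorot_uniform` 的 `seed` 参数);TensorFlow 随后把这个对象当作 `initial_value()` 无参调用,于是报出 `__call__() missing 1 required positional argument: 'shape'`。解决方法是:用关键字参数 `shape=` 调用 `add_weight`,并直接传入 `initializer='glorot_uniform'`。此外,Keras 2 把 `get_output_shape_for` 改名为 `compute_output_shape`;构造函数应接受 `**kwargs` 并传给父类,`tf.nn` 的 padding 参数应写成 `'SAME'`。改写后的类如下: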

class Deconvolution3D(Layer):
    """3-D transposed convolution ("deconvolution") layer for Keras 2.

    Wraps ``tf.nn.conv3d_transpose``, so only the TensorFlow backend is
    supported. Inputs are expected to be rank 5 with channels on the last axis.

    # Arguments
        nb_filter: number of output channels (4th axis of the kernel).
        kernel_dims: spatial kernel size, e.g. ``(2, 2, 2)``; list or tuple.
        output_shape: full output shape passed to ``tf.nn.conv3d_transpose``,
            *including* the batch dimension, e.g. ``(1, 16, 16, 8, 128)``.
            NOTE(review): the batch entry is fixed in the TF op, so it
            presumably has to match the runtime batch size -- confirm.
        subsample: spatial strides, e.g. ``(2, 2, 2)``; list or tuple.
        **kwargs: standard ``Layer`` keyword arguments (``name``, ``trainable``, ...).
    """

    def __init__(self, nb_filter, kernel_dims, output_shape, subsample, **kwargs):
        # call() uses tf.nn directly, so no other backend can work.
        assert K.backend() == 'tensorflow'
        self.nb_filter = nb_filter
        # Normalize to tuples: the shapes below are built by tuple
        # concatenation, which raises TypeError for lists -- and a config
        # reloaded from JSON (see get_config) supplies lists.
        self.kernel_dims = tuple(kernel_dims)
        self.subsample = tuple(subsample)
        # tf.nn.conv3d_transpose takes 5 strides; use 1 on the batch and
        # channel axes.
        self.strides = (1, ) + self.subsample + (1, )
        # Trailing underscore presumably avoids clashing with Layer.output_shape.
        self.output_shape_ = tuple(output_shape)
        super(Deconvolution3D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the kernel ``W`` and bias ``b`` for a rank-5 input."""
        assert len(input_shape) == 5
        # conv3d_transpose filters are laid out as
        # (depth, height, width, output_channels, input_channels).
        self.W = self.add_weight(shape=self.kernel_dims + (self.nb_filter, input_shape[4], ),
                                 initializer='glorot_uniform',
                                 name='{}_W'.format(self.name),
                                 trainable=True)
        # Rank-4 bias; broadcasts against the rank-5 conv output in call().
        self.b = self.add_weight(shape=(1, 1, 1, self.nb_filter,),
                                 initializer='zero',
                                 name='{}_b'.format(self.name),
                                 trainable=True)
        super(Deconvolution3D, self).build(input_shape)

    def call(self, x, mask=None):
        """Apply the transposed convolution and add the bias."""
        return tf.nn.conv3d_transpose(x, self.W, output_shape=self.output_shape_,
                                      strides=self.strides, padding='SAME', name=self.name) + self.b

    def compute_output_shape(self, input_shape):
        """Keep the incoming batch dimension; the rest comes from ``output_shape``."""
        return (input_shape[0], ) + self.output_shape_[1:]

    def get_config(self):
        """Return the constructor arguments merged with the base Layer config.

        Without this, a model containing the layer cannot be re-created from
        its saved config, because the base config lacks nb_filter,
        kernel_dims, output_shape and subsample.
        """
        config = {'nb_filter': self.nb_filter,
                  'kernel_dims': self.kernel_dims,
                  'output_shape': self.output_shape_,
                  'subsample': self.subsample}
        base_config = super(Deconvolution3D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

关于python - 将代码从 Keras 1 转换为 Keras 2 : TypeError: __call__() missing 1 required positional argument: 'shape' ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48533216/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com