gpt4 book ai didi

tensorflow - 转换为张量后,参差不齐的张量没有 len()

转载 作者:行者123 更新时间:2023-12-04 17:16:23 27 4
gpt4 key购买 nike

我正在对具有可变维度的图像堆栈训练深度学习模型。 (Shape = [Batch, None, 256, 256, 1]),其中 None 可以是可变的。

我使用 tf.RaggedTensor.merge_dimsions(0,1) 将参差不齐的张量转换为 [None, 256, 256, 1] 的形状以运行预训练的 keras CNN 模型。

但是,使用 KerasLayer API 会导致以下错误:TypeError: the object of type 'RaggedTensor' has no len()

当我在 KerasLayer 外部应用 .merge_dimsions 并将张量传递给相同的预训练模型时,我没有收到此错误。

import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
image = tf.random.normal((varShape, 256, 256, 1))
image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
yield image

ds = tf.data.Dataset.from_generator(synthetic_gen, output_signature=(tf.RaggedTensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, ragged_rank=1)))
ds = ds.repeat().batch(8)
print(next(iter(ds)).shape)

# Build Model
inputs = tf.keras.Input(
type_spec=tf.RaggedTensorSpec(
shape=(8, None, 256, 256, 1),
dtype=tf.float32,
ragged_rank=1))

ResNet50 = tf.keras.applications.ResNet50(
include_top=True,
input_shape=(256, 256, 1),
weights=None)

def merge(x):
x = x.merge_dims(0, 1)
return x
x = tf.keras.layers.Lambda(merge)(inputs)
merged_inputs = x
# x = ResNet50(x) # Uncommenting this will result in `model` producing an error when run for inference.

model = tf.keras.Model(inputs, x)

# Run inference
data = next(iter(ds))
model(data).shape # Will be an error if ResNet50 is used

这是一个演示问题的 colab notebook。 https://colab.research.google.com/drive/1kN78mf4_oNqxWOluV054NlqmakC5msli?usp=sharing

最佳答案

不确定以下答案或解决方法对于复杂的网络设计是否稳定。但这里有一些提示。你得到的原因

Ragged Tensors have no len()

是因为 ResNet 模型,因为它需要 tensor 而不是 ragged_tensor。不过,我不确定 ResNet(weights=None) 是否能够直接采用 ragged_tensor。因此,如果我们可以在 ResNet 被馈送之前转换参差不齐的数据,也许它就不会提示了。下面是完整的工作代码。但请注意,可能存在一些有效的方法。


数据

import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
image = tf.random.normal((varShape, 256, 256, 1))
image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
yield image

ds = tf.data.Dataset.from_generator(synthetic_gen,
output_signature=(tf.RaggedTensorSpec(
shape=(None, 256, 256, 1),
dtype=tf.float32, ragged_rank=1
)
)
)
ds = ds.repeat().batch(8)

基本模型

# Build Model
inputs = tf.keras.Input(
type_spec=tf.RaggedTensorSpec(
shape=(8, None, 256, 256, 1),
dtype=tf.float32,
ragged_rank=1))

ResNet50 = tf.keras.applications.ResNet50(
include_top=True,
input_shape=(256, 256, 1),
weights=None)

def merge(x):
x = x.merge_dims(0, 1)
return x

衣衫褴褛的模特

在这里,我们将 ragged_tensor 转换为 tensor,然后再将数据传递给 ResNet

class RagModel(tf.keras.Model):
def __init__(self):
super(RagModel, self).__init__()
# base models
self.a = tf.keras.layers.Lambda(merge)
# convert: tensor = ragged_tensor.to_tensor()
self.b = tf.keras.layers.Lambda(lambda x: x.to_tensor())
self.c = ResNet50

def call(self, inputs, training=None, plot=False, **kwargs):
x = self.a(inputs)
x = self.b(x) if not plot else x
x = self.c(x)
return x

# a helper function to plot
def build_graph(self):
x = tf.keras.Input(type_spec=tf.RaggedTensorSpec(
shape=(8, None, 256, 256, 1),
dtype=tf.float32, ragged_rank=1)
)
return tf.keras.Model(inputs=[x],
outputs=self.call(x, plot=True))

x_model = RagModel()

运行

data = next(iter(ds)); print(data.shape)
x_model(data).shape
(8, None, 256, 256, 1)
TensorShape([39, 1000])

情节

tf.keras.utils.plot_model(x_model.build_graph(), 
show_shapes=True, show_layer_names=True)

enter image description here

x_model.build_graph().summary()

Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_4 (InputLayer) [(8, None, 256, 256, 1)] 0
_________________________________________________________________
lambda_2 (Lambda) (None, 256, 256, 1) 0
_________________________________________________________________
resnet50 (Functional) (None, 1000) 25630440
=================================================================
Total params: 25,630,440
Trainable params: 25,577,320
Non-trainable params: 53,120
_________________________________________________________________

关于tensorflow - 转换为张量后,参差不齐的张量没有 len(),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68638911/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com