gpt4 book ai didi

python - ValueError : Graph disconnected: cannot obtain value for tensor Tensor. ..可以毫无问题地访问以下前几层:[]

转载 作者:行者123 更新时间:2023-12-04 01:43:02 25 4
gpt4 key购买 nike

我一直在尝试使用 Keras 创建多输入模型,但遇到了错误。这个想法是结合文本和相应的主题来预测情绪。这是代码:

import numpy as np
text = np.random.randint(5000, size=(442702, 200), dtype='int32')
topic = np.random.randint(2, size=(442702, 227), dtype='int32')
sentiment = to_categorical(np.random.randint(5, size=442702), dtype='int32')

from keras.models import Sequential
from keras.layers import Dense, Activation, Embedding, Flatten, GlobalMaxPool1D, Dropout, Conv1D
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.losses import binary_crossentropy
from keras.optimizers import Adam


text_input = Input(shape=(200,), dtype='int32', name='text')
text_encoded = Embedding(input_dim=5000, output_dim=20, input_length=200)(text_input)
text_encoded = Dropout(0.1)(text_encoded)
text_encoded = Conv1D(300, 3, padding='valid', activation='relu', strides=1)(text_encoded)
text_encoded = GlobalMaxPool1D()(text_encoded)

topic_input = Input(shape=(227,), dtype='int32', name='topic')

concatenated = concatenate([text_encoded, topic_input])
sentiment = Dense(5, activation='softmax')(concatenated)

model = Model(inputs=[text_encoded, topic_input], outputs=sentiment)
# summarize layers
print(model.summary())
# plot graph
plot_model(model)

但是,这给了我以下错误:

TypeError:传递给“ConcatV2”Op 的“值”的列表中的张量具有不完全匹配的类型 [float32,int32]。

现在,如果我将 topic_input 的 dtype 从“int32”更改为“float32”,我会得到一个不同的错误:

ValueError: Graph disconnected: cannot obtain value for tensor Tensor("text_37:0", shape=(?, 200), dtype=int32) at layer "text". The following previous layers were accessed without issue: []

另一方面,部分模型与顺序 API 配合得很好。

model = Sequential()
model.add(Embedding(5000, 20, input_length=200))
model.add(Dropout(0.1))
model.add(Conv1D(300, 3, padding='valid', activation='relu', strides=1))
model.add(GlobalMaxPool1D())
model.add(Dense(227))
model.add(Activation('sigmoid'))

print(model.summary())

非常感谢任何指点。

最佳答案

您的 Keras 函数式 API 实现几乎没有问题,

  1. 您应该将 Concatenate 层用作 Concatenate(axis=-1)([text_encoded, topic_input])

  2. 在连接层中,您试图组合一个 int32 张量和一个 float32 张量,这是不允许的。你应该做的是,from keras.backend import castconcatenated = Concatenate(axis=-1)([text_encoded, cast(topic_input, 'float32')]) .

  3. 你有变量冲突,有两个 sentiment 变量,一个指向 to_categorical 输出,另一个指向最终 Dense< 的输出层。

  4. 您的模型输入不能是像 text_encoded 这样的中间张量。它们应该来自 Input 层。

为了帮助您实现,这里提供了 TF 1.13 中代码的工作版本(我不确定这是否正是您想要的)。

from keras.utils import to_categorical
text = np.random.randint(5000, size=(442702, 200), dtype='int32')
topic = np.random.randint(2, size=(442702, 227), dtype='int32')
sentiment1 = to_categorical(np.random.randint(5, size=442702), dtype='int32')

from keras.models import Sequential
from keras.layers import Input, Dense, Activation, Embedding, Flatten, GlobalMaxPool1D, Dropout, Conv1D, Concatenate, Lambda
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.losses import binary_crossentropy
from keras.optimizers import Adam
from keras.backend import cast
from keras.models import Model

text_input = Input(shape=(200,), dtype='int32', name='text')
text_encoded = Embedding(input_dim=5000, output_dim=20, input_length=200)(text_input)
text_encoded = Dropout(0.1)(text_encoded)
text_encoded = Conv1D(300, 3, padding='valid', activation='relu', strides=1)(text_encoded)
text_encoded = GlobalMaxPool1D()(text_encoded)

topic_input = Input(shape=(227,), dtype='int32', name='topic')

topic_float = Lambda(lambda x:cast(x, 'float32'), name='Floatconverter')(topic_input)

concatenated = Concatenate(axis=-1)([text_encoded, topic_float])
sentiment = Dense(5, activation='softmax')(concatenated)

model = Model(inputs=[text_input, topic_input], outputs=sentiment)
# summarize layers
print(model.summary())

希望这些帮助。

关于python - ValueError : Graph disconnected: cannot obtain value for tensor Tensor. ..可以毫无问题地访问以下前几层:[],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56589726/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com