
python - TensorFlow text generation not returning valid indices

Reposted. Author: 行者123. Updated: 2023-11-30 09:58:22

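I am trying to train a TensorFlow model to generate text. I am mostly using the code from the TensorFlow website, but when I try to generate text, the model returns indices that are not in word_index.

Text generation function: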

model = create_model(vocab_size=vocab_size,
                     embed_dim=embed_dim,
                     rnn_neurons=rnn_neurons,
                     batch_size=1)

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

model.build(tf.TensorShape([1, None]))

char_2_index = tokenizer.word_index
index_2_char = {ind: char for char, ind in char_2_index.items()}

def generate_text(model, start_string):

    num_generate = 1000

    input_eval = [char_2_index[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)

    text_generated = []

    temperature = 1.0

    model.reset_states()
    for i in range(num_generate):
        print(text_generated)
        predictions = model(input_eval)

        predictions = tf.squeeze(predictions, 0)

        predictions = predictions / temperature
        print(predictions)
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
        print(predicted_id)

        input_eval = tf.expand_dims([predicted_id], 0)

        text_generated.append(index_2_char[predicted_id])

    return (start_string + ''.join(text_generated))

Error:

KeyError                                  Traceback (most recent call last)
<ipython-input-52-9517558352c4> in <module>()
----> 1 print(generate_text(model, start_string=u"Is Baby yoda "))

<ipython-input-47-75973c66de6c> in generate_text(model, start_string)
37
38
---> 39 text_generated.append(index_2_char[predicted_id])
40
41 return (start_string + ''.join(text_generated))

KeyError: 133

The word index and the training text contain only uppercase and lowercase letters.
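
One way to confirm this kind of mismatch is to list the ids the model can emit that have no entry in the reverse lookup. The snippet below is only a minimal sketch in plain Python: `missing_ids` is a hypothetical helper, and the three-character lookup and the output width of 6 are made-up stand-ins, not values from the question.

```python
def missing_ids(num_logits, index_2_char):
    """Return the ids the model can emit that have no entry in the lookup."""
    # Sampling from num_logits classes can yield any id in [0, num_logits).
    return [i for i in range(num_logits) if i not in index_2_char]


# Hypothetical stand-in: a 1-based lookup for three characters,
# paired with a model whose output layer has 6 units.
index_2_char = {1: "a", 2: "b", 3: "C"}
gaps = missing_ids(6, index_2_char)
print(gaps)  # [0, 4, 5]
```

In the setting of the question, the output width would presumably be `predictions.shape[-1]`. A non-empty result means sampling can eventually hit an id that raises KeyError.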

Edit: for more context, here are my data structure and data preparation.

Structure: [['SENTENCE'], ['SENTENCE2'], ...]

Data preparation:

tokenizer = keras.preprocessing.text.Tokenizer(num_words=209, lower=False, char_level=True, filters='#$%&()*+-<=>@[\\]^_`{|}~\t\n')
tokenizer.fit_on_texts(df['title'].values)
df['encoded_with_keras'] = tokenizer.texts_to_sequences(df['title'].values)

dataset = df['encoded_with_keras'].values
dataset = tf.keras.preprocessing.sequence.pad_sequences(dataset, padding='post')

dataset = dataset.flatten()

dataset = tf.data.Dataset.from_tensor_slices(dataset)

sequences = dataset.batch(seq_len+1, drop_remainder=True)

def create_seq_targets(seq):
    input_txt = seq[:-1]
    target_txt = seq[1:]
    return input_txt, target_txt

dataset = sequences.map(create_seq_targets)

dataset = dataset.shuffle(buffer_size).batch(batch_size, drop_remainder=True)
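
One detail of this preparation may matter for the error. `pad_sequences` pads with 0 by default, while a Keras `Tokenizer` starts `word_index` at 1, so the flattened training stream contains an id that the reverse lookup does not know. The sketch below imitates this in plain Python. `pad_post` is a hypothetical stand-in for `pad_sequences(..., padding='post')`, and the vocabulary and sequences are invented for illustration.

```python
def pad_post(sequences, value=0):
    """Right-pad every sequence to the longest length (mimics padding='post')."""
    width = max(len(s) for s in sequences)
    return [s + [value] * (width - len(s)) for s in sequences]


# Hypothetical 1-based vocabulary and two encoded titles of unequal length.
word_index = {"a": 1, "b": 2, "C": 3}
encoded = [[1, 2, 3, 1], [3, 2]]

# Flatten the padded matrix into one stream, as dataset.flatten() does.
flat = [i for row in pad_post(encoded) for i in row]

# Ids that occur in the training stream but have no character attached.
unknown = sorted(set(flat) - set(word_index.values()))
print(flat)     # [1, 2, 3, 1, 3, 2, 0, 0]
print(unknown)  # [0]
```

If the model learns to predict the padding id, decoding it needs an explicit entry for 0 in the lookup (an empty string, for example).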

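Best answer

The vocab_size passed to create_model(...) does not appear to equal the length of index_2_char. The output layer can therefore produce ids, such as 133, that have no entry in the lookup.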
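
A minimal sketch of keeping the two in step, assuming the Keras convention that `word_index` is 1-based with id 0 left for padding. `build_vocab` and the tiny `word_index` are hypothetical and written in plain Python, so this is an illustration of the idea rather than the asker's actual fix.

```python
def build_vocab(word_index, pad_char=""):
    """Derive vocab_size and a reverse lookup that agree with each other."""
    index_2_char = {idx: ch for ch, idx in word_index.items()}
    # Give the padding id an explicit entry so it can be decoded.
    index_2_char[0] = pad_char
    vocab_size = len(index_2_char)
    # Every id in [0, vocab_size) must be decodable, otherwise fail early.
    assert set(index_2_char) == set(range(vocab_size))
    return vocab_size, index_2_char


# Hypothetical stand-in for tokenizer.word_index.
word_index = {"a": 1, "b": 2, "C": 3}
vocab_size, index_2_char = build_vocab(word_index)
print(vocab_size)  # 4
```

Passing this `vocab_size` to create_model(...) at training time and again when rebuilding the model for generation would keep every sampled id inside the lookup. The weights would need to be retrained if the output width changes.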

Regarding python - TensorFlow text generation not returning valid indices, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/60096831/
