
python - tensorflow.python.framework.errors.InvalidArgumentError

Reposted · Author: 行者123 · Updated: 2023-11-28 17:18:00

I implemented a language model with TensorFlow. The training data is just a large number of sentences, fed in through feed_dict like this:

feed_dict = {
    model.inputs: x,
    model.seqlen: seqlen
}

x looks like this:

[[127713, 68665, 211766, 2698, 138657, 36122, 138963, 198149, 4975, 104939, 205505], [81512, 161790, 4191, 131922, 38206, 123973, 102593, 147631, 117256, 153046, 190414], [213013, 159996, 9461, 131922, 175230, 191825, 102593, 201242, 6535, 160687, 15960], [39155, 160687, 2236, 117259, 200449, 120265, 214117, 102593, 117198, 138657, 159996], [9959, 136465, 121296, 96619, 184509, 10843, 117256, 102593, 187463, 213648, 102593], [11370, 189417, 127691, 43487, 109775, 19315, 102593, 130793, 36122, 160023, 138657], [221903, 102593, 76854, 215208, 146459, 172190, 99562, 54144, 141328, 134798, 176905], [102593, 29004, 189417, 77559, 11370, 102593, 201121, 436, 127713, 85797, 71369], [67515, 90422, 141328, 102593, 222023, 107914, 155883, 102593, 148221, 169199, 36122], [205336, 11191, 127713, 115425, 147700, 152270, 80276, 143317, 4190, 2373, 24519], [61049, 144035, 219863, 54144, 111851, 104926, 117256, 182871, 10033, 188890, 102593], [97804, 95468, 72416, 178512, 56040, 190225, 169304, 214785, 127713, 106900, 32960], [220409, 11370, 117249, 213607, 89611, 34385, 117256, 198815, 49674, 94546, 37171], [179753, 176347, 160687, 32912, 72416, 189281, 203515, 44526, 190225, 160687, 189417], [49035, 165055, 100531, 102593, 187465, 6535, 174629, 175940, 208552, 124145, 42418], [136713, 67575, 193443, 24519, 0, 67515, 71905, 36122, 78050, 36122, 117492], [67575, 201558, 169304, 25531, 102593, 152308, 124145, 129101, 75544, 117256, 102593], [127713, 58045, 7814, 90422, 36130, 26354, 11370, 169304, 71048, 196602, 133966], [223954, 127713, 135835, 111851, 36122, 102593, 16398, 24622, 11370, 102593, 90879], [34539, 46136, 72416, 79125, 214125, 31507, 117256, 127713, 21687, 150290, 102593], [172081, 117256, 127713, 148704, 193249, 189417, 57754, 204591, 117256, 127713, 217441], [156885, 213648, 102593, 137549, 24519, 102593, 81722, 159996, 92404, 102593, 158063], [117256, 102593, 1481, 36122, 102593, 188983, 117249, 189417, 2698, 4190, 198149], [146627, 188890, 102593, 220327, 36122, 26266, 11370, 32603, 67575, 136465, 102593], [117249, 189417, 179882, 190414, 115744, 138657, 117249, 189417, 190225, 215006, 51726], [70710, 152185, 129802, 137980, 95640, 119899, 102593, 203527, 4191, 131922, 57303], [138657, 189417, 75401, 117256, 102593, 39587, 131922, 110117, 138657, 138963, 42664], [35145, 15678, 65575, 11370, 131922, 202552, 190414, 102593, 195413, 209716, 61049], [213218, 158064, 190414, 72416, 99562, 145256, 68055, 190414, 112808, 102593, 94655], [36117, 45024, 170008, 158664, 201179, 162247, 36117, 72039, 436, 63876, 210529], [121778, 11370, 169304, 51713, 72416, 160980, 100531, 102593, 187465, 127691, 160687], [196602, 190414, 115744, 152185, 117249, 211349, 190414, 198056, 152386, 219761, 212195], [106606, 127713, 34109, 154924, 119235, 36122, 127713, 133841, 114413, 102593, 195413], [161791, 163058, 49084, 99562, 98981, 160687, 11191, 127713, 116409, 117256, 102593], [49674, 144174, 189417, 127689, 222397, 36122, 161717, 436, 107573, 11370, 186602], [102593, 76854, 14223, 180403, 150708, 196787, 36117, 186602, 8374, 102593, 148453], [189417, 53675, 58648, 11370, 102593, 130984, 141328, 157511, 190414, 102593, 137453], [190786, 213013, 99562, 54144, 25531, 101525, 127222, 11370, 144108, 11370, 149922], [76179, 107914, 43486, 174088, 161609, 38367, 166913, 160687, 4188, 40566, 190414], [111186, 176905, 188890, 182871, 100952, 11370, 221875, 182871, 199204, 36117, 127713], [216479, 11370, 196787, 123973, 58648, 138657, 164316, 117256, 102593, 214093, 118878], [127689, 190225, 141334, 67575, 89207, 
189281, 36166, 36122, 35179, 102593, 173841], [73827, 45780, 140996, 61049, 35145, 134798, 190414, 102593, 210662, 36122, 102593], [220833, 181338, 138657, 102593, 131688, 36122, 22599, 11370, 102593, 203636, 28886], [77513, 189417, 190414, 72416, 189281, 146384, 190414, 83835, 102593, 141940, 36122], [159996, 43486, 72416, 190414, 177756, 159391, 213648, 102593, 123641, 36122, 82016], [145098, 117249, 117247, 87334, 11370, 126458, 37923, 140495, 102593, 113303, 11370], [102593, 69762, 70104, 67575, 180545, 214125, 53255, 190414, 102593, 198785, 117249], [116408, 138657, 138963, 36122, 102593, 20362, 76179, 35145, 136290, 214125, 102593], [35406, 160687, 121032, 136465, 102593, 181712, 169923, 58974, 36117, 92968, 102593]]
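
The question does not show how x and seqlen are produced. As a rough sketch only, assuming the raw data is a list of token-id lists named sentences (a name not taken from the post), a padded batch and its length vector could be prepared like this:

import numpy as np

# sentences: list of lists of token ids (hypothetical input, not in the post)
seqlen = np.array([len(s) for s in sentences], dtype=np.float32)
max_len = int(seqlen.max())

# zero-pad every sentence to the length of the longest one
x = np.zeros((len(sentences), max_len), dtype=np.int32)
for i, s in enumerate(sentences):
    x[i, :len(s)] = s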

My model code:

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import array_ops


class Model(object):
    def __init__(
            self,
            batch_size,
            vocab_size,
            hidden_size,
            learning_rate):

        self.inputs = tf.placeholder(tf.int32, [batch_size, None])

        self.seqlen = tf.placeholder(tf.float32)

        with tf.device('/cpu:0'), tf.name_scope("embedding"):
            # embed = tf.get_variable(name="Embedding", shape=[vocab_size, hidden_size])
            embed = tf.Variable(
                tf.random_uniform([vocab_size, hidden_size], -1.0, 1.0))
            self.embedded_chars = tf.nn.embedding_lookup(embed, self.inputs)
            self.rev_input = tf.reverse(self.inputs, [False, True])
            self.embedded_chars_rev = tf.nn.embedding_lookup(embed, self.rev_input)

        with tf.variable_scope('forward'):
            forward_lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
            forward_outputs, _ = tf.nn.dynamic_rnn(forward_lstm_cell, self.embedded_chars,
                                                   sequence_length=self.seqlen,
                                                   dtype=tf.float32)

        with tf.variable_scope('backward'):
            backward_lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
            backward_outputs, _ = tf.nn.dynamic_rnn(backward_lstm_cell,
                                                    self.embedded_chars_rev,
                                                    sequence_length=self.seqlen,
                                                    dtype=tf.float32)

        lstm_outputs = tf.add(forward_outputs, backward_outputs, name="lstm_outputs")
        self.outputs = tf.nn.relu(lstm_outputs)
        # W = tf.Variable(tf.truncated_normal([hidden_size, vocab_size, 1], -0.1, 0.1))
        W = tf.get_variable('Weights', shape=[hidden_size, 1])
        b = tf.get_variable('Bias', shape=[1])
        outputs = self.outputs[:, 1, :]
        y_pred = tf.squeeze(tf.matmul(outputs, W)) + b
        inputs_0 = tf.cast(self.inputs[:, 0], tf.float32)
        self.loss = tf.nn.sigmoid_cross_entropy_with_logits(y_pred, inputs_0)

        self.train_op = tf.train.AdamOptimizer(0.0002).minimize(self.loss)
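
The training script train_model.py is not included in the question; the traceback below shows the model being created through a create_model helper driven by FLAGS. As a minimal sketch of how the class above might be wired up (all concrete values here are illustrative assumptions, not taken from the post):

import tensorflow as tf

model = Model(batch_size=50,        # sentences per batch (the sample above has 50 rows)
              vocab_size=50000,     # size of the embedding table
              hidden_size=128,
              learning_rate=0.0002)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())   # TF 0.x-style initialization
    feed_dict = {model.inputs: x, model.seqlen: seqlen}
    _, loss = sess.run([model.train_op, model.loss], feed_dict)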

When I feed in the data and run it, the following error occurs:

Traceback (most recent call last):
  File "h1_mscc/train_model.py", line 156, in <module>
    tf.app.run()
  File "/usr/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 30, in run
    sys.exit(main(sys.argv))
  File "h1_mscc/train_model.py", line 153, in main
    train()
  File "h1_mscc/train_model.py", line 143, in train
    train_step(batch,seqlen)
  File "h1_mscc/train_model.py", line 134, in train_step
    feed_dict)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 372, in run
    run_metadata_ptr)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 636, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 708, in _do_run
    target_list, options, run_metadata)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 728, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors.InvalidArgumentError: indices[0,0] = 205505 is not in [0, 50000)
     [[Node: embedding/embedding_lookup_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, _class=["loc:@embedding/Variable"], validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](embedding/Variable/read, embedding/Reverse)]]
Caused by op u'embedding/embedding_lookup_1', defined at:
  File "h1_mscc/train_model.py", line 156, in <module>
    tf.app.run()
  File "/usr/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 30, in run
    sys.exit(main(sys.argv))
  File "h1_mscc/train_model.py", line 153, in main
    train()
  File "h1_mscc/train_model.py", line 81, in train
    model = create_model(sess)
  File "h1_mscc/train_model.py", line 59, in create_model
    learning_rate=FLAGS.learning_rate)
  File "/home/liac/code/Project3-preprocess-master/h1_mscc/model1.py", line 24, in __init__
    self.embedded_chars_rev = tf.nn.embedding_lookup(embed, self.rev_input)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/embedding_ops.py", line 86, in embedding_lookup
    validate_indices=validate_indices)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 780, in gather
    validate_indices=validate_indices, name=name)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 704, in apply_op
    op_def=op_def)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2260, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1230, in __init__
    self._traceback = _extract_stack()

This has me quite confused, and I would appreciate any response!

Best Answer

Hey, could it be that you have defined vocab_size incorrectly? https://github.com/tensorflow/tensorflow/issues/2734
looks like it may be the same kind of problem: the error indices[0,0] = 205505 is not in [0, 50000) means an input token id falls outside the embedding table, whose first dimension is vocab_size (50000 here).

To say a bit more, it would also help to know how you run the model and which parameters you pass to it.
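
As a concrete check (the variable names below are illustrative, not from the original code), compare the vocab_size passed to the model against the largest token id in the data, since the embedding table only covers ids in [0, vocab_size):

# x is the nested list of token ids shown in the question
max_id = max(max(row) for row in x)
print(max_id)      # well over 200000 for the sample above, far beyond 50000

# the lookup only accepts ids in [0, vocab_size), so the embedding table
# must be built with at least max_id + 1 rows
model = Model(batch_size=len(x),
              vocab_size=max_id + 1,
              hidden_size=128,          # illustrative value
              learning_rate=0.0002)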

Regarding python - tensorflow.python.framework.errors.InvalidArgumentError, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/43089598/
