gpt4 book ai didi

tensorflow - tf.Variables 的维度在一些时期后发生变化

转载 作者:行者123 更新时间:2023-12-04 04:12:40 27 4
gpt4 key购买 nike

我是 TensorFlow 的新手,正在学习。我定义了一些变量并开始训练。第一个时期一切顺利,但突然抛出以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Matrix size-incompatible: In[0]: [17952,50], In[1]: [0,20]
[[{{node gradients/Embeddings_1/MatMul_grad/MatMul_1}}]]
[[gradients/Embeddings_1/MatMul_grad/tuple/control_dependency/_1867]]
(1) Invalid argument: Matrix size-incompatible: In[0]: [17952,50], In[1]: [0,20]
[[{{node gradients/Embeddings_1/MatMul_grad/MatMul_1}}]]

我的问题是为什么它在一些时代之后而不是一开始就给出了错误。通常,在构建图形时会抛出这些类型的错误。

这是我创建变量和嵌入树的代码:

    def __init__(self, vocab, embedding):

self.add_model_variables()

with tf.variable_scope("Embeddings", reuse=True):
with tf.device('/cpu:0'):
w_embed = tf.get_variable('WE', [self.vocab_embedding_size, self.embed_size])
b_embed = tf.get_variable('bE', [1, self.embed_size])
embeddings = tf.get_variable('embeddings')
self.embeddings = tf.add(tf.matmul(embeddings, w_embed), b_embed)


def add_model_variables(self):
myinitilizer = tf.random_uniform_initializer(-self.calc_wt_init(),self.calc_wt_init())

with tf.variable_scope('Embeddings'):
with tf.device('/cpu:0'):
w_embed = tf.get_variable('WE', [self.vocab_embedding_size, self.embed_size], initializer = myinitilizer)
b_embed = tf.get_variable('bE', [1, self.embed_size], initializer = myinitilizer)
embeddings = tf.get_variable('embeddings',
initializer=tf.convert_to_tensor(self.pretrained_embedding),
dtype=tf.float32)


with tf.variable_scope('Composition'):
self.W1 = tf.get_variable('W1', [2 * self.embed_size, self.embed_size], initializer = myinitilizer)
self.b1 = tf.get_variable('b1', [1, self.embed_size], initializer = myinitilizer)

with tf.variable_scope('Projection'):
self.U = tf.get_variable('U', [self.embed_size, 1], initializer = myinitilizer)
self.bu = tf.get_variable('bu', [self.max_number_nodes, 1], initializer = myinitilizer)


def embed_tree(self, batch_index):
def combine_children( left_tensor, right_tensor):
return tf.nn.relu(tf.matmul(tf.concat([left_tensor, right_tensor], axis=1, name='combine_children'), self.W1) + self.b1)

def embed_word(word_index):
with tf.device('/cpu:0'):
return tf.expand_dims(tf.gather(self.embeddings, word_index), 0)

def loop_body(node_tensors, i):
node_is_leaf = tf.gather(is_leaf, i)
word = tf.gather(words, i)
left_child = tf.gather(left_children, i)
right_child = tf.gather(right_children, i)
node_tensor = tf.cond(
node_is_leaf,
lambda: embed_word(word),
lambda: combine_children(
node_tensors.read(n-right_child),
node_tensors.read(n-left_child)))
node_tensors = node_tensors.write(i, node_tensor)
i = tf.add(i, 1)
return node_tensors, i

is_leaf = tf.gather(self.batch_is_leaf, batch_index)
left_children = tf.gather(self.batch_left_children, batch_index)
right_children = tf.gather(self.batch_right_children, batch_index)
words = tf.gather(self.batch_words, batch_index)
n = tf.reduce_sum(tf.cast(tf.not_equal(left_children, -1), tf.int32))-2
#iself.batch_operation = tf.print(batch_index,'N::::::::',output_stream=sys.stdout)

node_tensors = tf.TensorArray(tf.float32, size=self.max_number_nodes,
dynamic_size=False, clear_after_read=False, element_shape=[1, self.embed_size])
loop_cond = lambda node_tensors, i: tf.less(i, n+2)
#with tf.control_dependencies([self.batch_operation]):
node_tensors, _ = tf.while_loop(loop_cond, loop_body, [node_tensors, 0], parallel_iterations=1)
tree_embedding = tf.convert_to_tensor(node_tensors.stack())
return tree_embedding


另一个问题是我无法复制错误,因为它偶尔会发生。

更新:

当我减小 batch_size 时,出现此错误的几率就会降低。这可能是因为工作接近 GPU 内存限制吗?

最佳答案

tf.gather 在 GPU 上为无效索引生成零(但它在 CPU 上正常工作)。换句话说,Tensorflow 在 GPU 上运行时不会检查索引的范围。

由返回的 0 引起的错误在梯度上累积,最终导致与原始问题无关的困惑错误消息。

供引用:

https://github.com/tensorflow/tensorflow/issues/3638

我将 tf.gather 更改为基于索引的检索(a[i]),问题已解决。我不知道为什么!

关于tensorflow - tf.Variables 的维度在一些时期后发生变化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61447546/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com