gpt4 book ai didi

tensorflow - 我可以在tensorflow中导出单词的嵌入矩阵吗?

转载 作者:行者123 更新时间:2023-12-05 04:08:06 24 4
gpt4 key购买 nike

def word_embedding(shape, dtype=tf.float32, name='word_embedding'):
with tf.device('/cpu:0'), tf.variable_scope(name):
return tf.get_variable('embedding', shape, dtype=dtype, initializer=tf.random_normal_initializer(stddev=0.1), trainable=True,partitioner=tf.fixed_size_partitioner(20))
embedding = word_embedding([vocab_size, embed_size])
inputs_embedding = tf.contrib.layers.embedding_lookup_unique(embedding, inputs)

这是我的代码,embedding 是 word 的变量,用于查找自己的嵌入向量。

我已经训练了嵌入矩阵,我想从保存的模型中提取它。该模型还包含其他参数,例如嵌入之上的神经网络。我可以实现吗?

最佳答案

参见 my answer类似的问题。

最简单的方法是将嵌入矩阵评估为一个 numpy 数组,并将其与已解析的单词一起写入文件。

with tf.Session() as sess:
embedding_val = sess.run(embedding)
with open('embedding.txt', 'w') as file_:
for i in range(vocabulary_size):
embed = embedding_val[i, :]
word = word_to_idx[i]
file_.write('%s %s\n' % (word, ' '.join(map(str, embed))))

如果您只想为这个图保存嵌入,您可以创建tf.train.Saver 并传递要保存的变量列表:

saver = tf.train.Saver([embedding])
with tf.Session() as sess:
saver.save(sess, 'path/to/checkpoint')

关于tensorflow - 我可以在tensorflow中导出单词的嵌入矩阵吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48019799/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com