gpt4 book ai didi

python - 将numpy uint8输入tensorflow float32占位符

转载 作者:太空宇宙 更新时间:2023-11-03 14:00:13 27 4
gpt4 key购买 nike

我最近发现了一种概念验证实现,它使用 numpy.zeros 在 one-hot 编码中准备功能:

data = np.zeros((len(raw_data), n_input, vocab_size),dtype=np.uint8)

如上所示,单个的类型为 np.uint8。检查模型后,我意识到tensorflow模型的输入占位符定义为tf.float32:

x = tf.placeholder(tf.float32, [None, n_input, vocab_size], name="onehotin")

我的具体问题: tensorflow 如何处理这种输入类型的“不匹配”。这些值 (0/1) 是否由 tensorflow 正确解释或转换。如果是这样,文档中是否提到了这一点。谷歌搜索后我找不到答案。应该提到的是,该模型的运行和值似乎是合理的。但是,将输入 numpy 特征输入为 np.float32 会导致需要大量内存。

相关性:正在运行但经过错误训练的模型在采用输入管道/将模型投入生产后会有不同的行为。

最佳答案

Tensorflow 支持这样的数据类型转换。

x + 1等运算中,值1会经过tf.convert_to_tensor负责验证和转换的函数。有时会在后台手动调用该函数,并且当设置 dtype 参数时,该值会自动转换为此类型。

当您将数组输入到这样的占位符中时:

session.run(..., feed_dict={x: data})

...数据通过 np.asarray 调用显式转换为正确类型的 numpy 数组。源代码见python/client/session.py 。请注意,当数据类型不同时,此方法可能会重新分配缓冲区,而这正是您的情况所发生的情况。因此,您的内存优化并不完全按照您的预期工作:临时 32 位数据是在内部分配的。

关于python - 将numpy uint8输入tensorflow float32占位符,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49309679/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com