
python - TensorFlow's map_fn only runs on the CPU

Reposted. Author: 太空宇宙. Last updated: 2023-11-04 02:24:19

I ran into a strange problem while trying to get TensorFlow's map_fn to run on my GPU. Here is a minimal example that reproduces the error:

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    with tf.device("/gpu:0"):
        def test_func(i):
            return i
        test_range = tf.constant(np.arange(5))
        test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)

This produces the following error:

InvalidArgumentError: Cannot assign a device for operation 'map/TensorArray_1': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/device:GPU:0'
Colocation Debug Info:
Colocation group had the following types and devices:
TensorArrayScatterV3: CPU
TensorArrayGatherV3: GPU CPU
Range: GPU CPU
TensorArrayWriteV3: CPU
TensorArraySizeV3: GPU CPU
TensorArrayReadV3: CPU
Enter: GPU CPU
TensorArrayV3: CPU
Const: GPU CPU

Colocation members and user-requested devices:
map/TensorArrayStack/range/delta (Const)
map/TensorArrayStack/range/start (Const)
map/TensorArray_1 (TensorArrayV3)
map/while/TensorArrayWrite/TensorArrayWriteV3/Enter (Enter) /device:GPU:0
map/TensorArrayStack/TensorArraySizeV3 (TensorArraySizeV3)
map/TensorArrayStack/range (Range)
map/TensorArrayStack/TensorArrayGatherV3 (TensorArrayGatherV3)
map/TensorArray (TensorArrayV3)
map/while/TensorArrayReadV3/Enter (Enter) /device:GPU:0
Const (Const) /device:GPU:0
map/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3 (TensorArrayScatterV3) /device:GPU:0
map/while/TensorArrayReadV3 (TensorArrayReadV3) /device:GPU:0
map/while/TensorArrayWrite/TensorArrayWriteV3 (TensorArrayWriteV3) /device:GPU:0

[[Node: map/TensorArray_1 = TensorArrayV3[clear_after_read=true, dtype=DT_FLOAT, dynamic_size=false, element_shape=, identical_element_shapes=true, tensor_array_name=""]]]

The code behaves as expected when run on my CPU, and a simple operation like this one:

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    with tf.device("/gpu:0"):
        def test_func(i):
            return i
        test_range = tf.constant(np.arange(5))
        test = sess.run(tf.add(test_range, test_range))
print(test)

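runs fine on my GPU. This post seems to describe a similar problem. Does anyone have any hints? The answers to that post suggest that map_fn should work fine on the GPU. I am running TensorFlow 1.8.0 with Python 3.6.4 on Arch Linux, with CUDA 9.0 and cuDNN 7.0 on a GeForce GTX 1050.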

Thanks!

Best answer

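The error actually comes from a dtype mismatch: np.arange(5) produces an integer array by default (int32 or int64, depending on the platform), and test_func returns its input unchanged, but you declared float32 as the output dtype of map_fn. The error goes away with: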

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    with tf.device("/gpu:0"):
        def test_func(i):
            return i
        test_range = tf.constant(np.arange(5, dtype=np.float32))
        test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)
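To check the dtype side of this without TensorFlow or a GPU, the NumPy-only sketch below inspects what np.arange returns by default and shows two ways to get float32 input: passing dtype= directly, as the fixed example does, or casting an existing array with astype. Because the default integer width is platform dependent, the sketch only assumes it is some integer type.

```python
import numpy as np

# Without an explicit dtype, np.arange returns an integer array.
# The exact width (int32 vs. int64) depends on the platform and NumPy version.
default_range = np.arange(5)
print(default_range.dtype)

# Option 1: request float32 up front, as the fixed map_fn example does.
float_range = np.arange(5, dtype=np.float32)

# Option 2: cast an existing integer array to float32.
cast_range = default_range.astype(np.float32)

print(float_range.dtype, cast_range.dtype)  # float32 float32
```

Since tf.constant keeps the dtype of the NumPy array it is given, either option yields a float32 tensor. Another route, if integer input is what you want, is to keep np.arange(5) and omit dtype from tf.map_fn (in TF 1.x it then defaults to the dtype of elems) or pass the matching integer dtype.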

I agree that the error message you got is rather confusing. You get the "real" error message by removing the device placement:

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    def test_func(i):
        return i
    test_range = tf.constant(np.arange(5))
    test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)
# InvalidArgumentError (see above for traceback): TensorArray dtype is float but Op is trying to write dtype int32.
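The "real" message points at the mechanism: map_fn collects each result of fn into a TensorArray created with the declared dtype (the DT_FLOAT TensorArray in the original traceback), so writing test_func's integer output into it fails. The toy below is a hypothetical, NumPy-only model of that contract, not TensorFlow's implementation; toy_map_fn, identity, and the error text are made up for illustration. It preallocates an output buffer with the declared dtype (falling back to the input's dtype when dtype is omitted, as tf.map_fn does) and rejects results of any other dtype.

```python
import numpy as np


def toy_map_fn(fn, elems, dtype=None):
    """Hypothetical stand-in for tf.map_fn's dtype contract (not TF's code).

    Results are written into a buffer preallocated with `dtype`, loosely
    mirroring how map_fn accumulates outputs in a TensorArray. When dtype
    is None, the input's dtype is used instead.
    """
    elems = np.asarray(elems)
    out_dtype = np.dtype(dtype) if dtype is not None else elems.dtype
    out = np.empty(len(elems), dtype=out_dtype)
    for i, x in enumerate(elems):
        y = np.asarray(fn(x))
        # Refuse implicit casts, like the TensorArray write in the traceback.
        if y.dtype != out_dtype:
            raise TypeError(
                f"buffer dtype is {out_dtype} but fn returned {y.dtype}")
        out[i] = y
    return out


def identity(i):
    return i


ints = np.arange(5)

# Declared float32 output with integer results: fails like the question's code.
try:
    toy_map_fn(identity, ints, dtype=np.float32)
except TypeError as err:
    print("error:", err)

# Fix 1: make the input float32 so fn's output matches the declared dtype.
print(toy_map_fn(identity, np.arange(5, dtype=np.float32), dtype=np.float32))

# Fix 2: omit dtype so the output dtype follows the input's dtype.
print(toy_map_fn(identity, ints))
```

The toy raises eagerly on the first bad write, whereas TensorFlow 1.x builds a graph first and reports the mismatch only when sess.run executes the write op; with an explicit GPU placement, the device-colocation check fails even earlier and masks it.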

Regarding "python - TensorFlow's map_fn only runs on the CPU", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/50845372/

Copyright 2021 - 2024 cfsdn All Rights Reserved, Sichuan ICP filing No. 2022000587