gpt4 book ai didi

python - Tensorflow - 如何为 tf.Estimator() CNN 使用 GPU 而不是 CPU

转载 作者:太空狗 更新时间:2023-10-29 21:11:46 25 4
gpt4 key购买 nike

我认为它应该与 with tf.device("/gpu:0") 一起使用,但我应该把它放在哪里?我不认为它是:

with tf.device("/gpu:0"):
tf.app.run()

那么我应该把它放在 tf.appmain() 函数中,还是放在我用于估算器的模型函数中?

编辑:如果这有帮助,这是我的 main() 函数:

def main(unused_argv):
"""Code to load training folds data pickle or generate one if not present"""

# Create the Estimator
mnist_classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn2, model_dir="F:/python_machine_learning_codes/tmp/custom_age_adience_1")

# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)

# Train the model
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": train_data},
y=train_labels,
batch_size=64,
num_epochs=None,
shuffle=True)
mnist_classifier.train(
input_fn=train_input_fn,
steps=500,
hooks=[logging_hook])

# Evaluate the model and print results
"""Code to load eval fold data pickle or generate one if not present"""

eval_logs = {"probabilities": "softmax_tensor"}
eval_hook = tf.train.LoggingTensorHook(
tensors=eval_logs, every_n_iter=100)
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": eval_data},
y=eval_labels,
num_epochs=1,
shuffle=False)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn, hooks=[eval_hook])

如您所见,我在这里的任何地方都没有显式声明 session ,所以我到底应该把 with tf.device("/gpu:0") 放在哪里?

最佳答案

你可以把它放在你的模型函数的开头,即当你定义你的模型时,你应该写:

def cnn_model_fn2(...):
with tf.device('/gpu:0'):
...

但是,我希望 tensorflow 自动为您的模型使用 gpu。您可能想检查它是否被正确检测到:

from tensorflow.python.client import device_lib
device_lib.list_local_devices()

关于python - Tensorflow - 如何为 tf.Estimator() CNN 使用 GPU 而不是 CPU,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47277165/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com