
tensorflow - Google Colab: Why is CPU faster than TPU?


I'm training a simple Keras model using a Google Colab TPU. If I remove the distributed strategy and run the same program on the CPU, it is much faster than on the TPU. How is that possible?

import timeit
import os
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

# Load Iris dataset
x = load_iris().data
y = load_iris().target

# Split data to train and validation set
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.30, shuffle=False)

# Convert train data type to use TPU
x_train = x_train.astype('float32')
x_val = x_val.astype('float32')

# Specify a distributed strategy to use TPU
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)

# Use the strategy to create and compile a Keras model
with strategy.scope():
    model = Sequential()
    model.add(Dense(32, input_shape=(4,), activation=tf.nn.relu, name="relu"))
    model.add(Dense(3, activation=tf.nn.softmax, name="softmax"))
    model.compile(optimizer=Adam(learning_rate=0.1), loss='logcosh')

start = timeit.default_timer()

# Fit the Keras model on the dataset
model.fit(x_train, y_train, batch_size=20, epochs=20, validation_data=[x_val, y_val], verbose=0, steps_per_epoch=2)

print('\nTime: ', timeit.default_timer() - start)

Best Answer

Thanks for your question.

I think what's happening here is overhead. Because the TPU runs in a separate VM (reachable at grpc://$COLAB_TPU_ADDR), every call that runs the model on the TPU incurs some overhead while the client (the Colab notebook in this case) sends the graph to the TPU, where it is compiled and run. This overhead is small compared with the time it takes to run, say, one epoch of ResNet50, but large compared with running a simple model like the one in your example.
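One way to see this directly (a minimal sketch of my own, not part of the original answer) is a small Keras callback that prints the wall-clock time of each epoch; passing it to model.fit in either version of the script shows that, for a model this small, most of each TPU epoch is dispatch and compilation overhead rather than compute. The `EpochTimer` name is mine.

import time
import tensorflow as tf

class EpochTimer(tf.keras.callbacks.Callback):
    """Prints wall-clock time per epoch; the first TPU epoch also includes XLA compilation."""
    def on_epoch_begin(self, epoch, logs=None):
        self._t0 = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        print(f'epoch {epoch}: {time.perf_counter() - self._t0:.3f}s')

# Usage: model.fit(..., callbacks=[EpochTimer()])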

To get the best results on the TPU, we recommend using tf.data.Dataset. I updated your example for TensorFlow 2.2:

%tensorflow_version 2.x
import timeit
import os
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

# Load Iris dataset
x = load_iris().data
y = load_iris().target

# Split data to train and validation set
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.30, shuffle=False)

# Convert train data type to use TPU
x_train = x_train.astype('float32')
x_val = x_val.astype('float32')

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(20)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(20)

# Use the strategy to create and compile a Keras model
with strategy.scope():
    model = Sequential()
    model.add(Dense(32, input_shape=(4,), activation=tf.nn.relu, name="relu"))
    model.add(Dense(3, activation=tf.nn.softmax, name="softmax"))
    model.compile(optimizer=Adam(learning_rate=0.1), loss='logcosh')

start = timeit.default_timer()

# Fit the Keras model on the dataset
model.fit(train_dataset, epochs=20, validation_data=val_dataset)

print('\nTime: ', timeit.default_timer() - start)

This takes about 30 seconds to run, versus about 1.3 seconds on CPU. We can substantially reduce the overhead here by repeating the dataset and running one long epoch instead of several small ones. I replaced the dataset setup with this:

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).repeat(20).batch(20)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(20)

and replaced the fit call with this:

model.fit(train_dataset, validation_data=val_dataset)

This brings my runtime down to about 6 seconds. That is still slower than the CPU, but that's not surprising for a model this small that can easily be run locally. In general, you will see more benefit from TPUs with larger models. I recommend working through TensorFlow's official TPU guide, which presents a larger image-classification model on the MNIST dataset.
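Not part of the original answer, but as a pointer for that larger-model case: a common pattern is to scale the global batch size by the number of TPU replicas and keep the input pipeline fed, so each core gets enough work per step to amortize the dispatch overhead. The sketch below reuses the `strategy` object from the code above; the stand-in dataset and the batch sizes are hypothetical choices, not values from the answer.

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a larger dataset (Iris is far too small for this).
x_big = np.random.rand(50000, 4).astype('float32')
y_big = np.random.randint(0, 3, size=50000)

per_replica_batch = 128                                            # hypothetical choice
global_batch = per_replica_batch * strategy.num_replicas_in_sync   # e.g. 8 cores -> 1024

train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_big, y_big))
    .shuffle(buffer_size=10000)
    .batch(global_batch, drop_remainder=True)   # static shapes suit the TPU compiler
    .prefetch(tf.data.experimental.AUTOTUNE)    # overlap the input pipeline with compute
)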

Regarding "tensorflow - Google Colab: Why is CPU faster than TPU?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/59264851/
