作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我正在使用 Tensorflow 1.9 在 GTX1080 Ti 上使用卷积神经网络训练一组约 9000 张图像 (300x500),但每次都会遇到超出内存的问题。我收到一条关于系统内存超出 10% 的警告,几分钟后该进程被终止。我的代码如下。
import tensorflow as tf
from os import listdir
train_path = '/media/NewVolume/colorizer/img/train/'
col_train_path = '/media/NewVolume/colorizer/img/colored/train/'
val_path = '/media/NewVolume/colorizer/img/val/'
col_val_path = '/media/NewVolume/colorizer/img/colored/val/'
def load_image(image_file):
image = tf.read_file(image_file)
image = tf.image.decode_jpeg(image)
return image
train_dataset = []
col_train_dataset = []
val_dataset = []
col_val_dataset = []
for i in listdir(train_path):
train_dataset.append(load_image(train_path + i))
col_train_dataset.append(load_image(col_train_path + i))
for i in listdir(val_path):
val_dataset.append(load_image(val_path + i))
col_val_dataset.append(load_image(col_val_path + i))
train_dataset = tf.stack(train_dataset)
col_train_dataset = tf.stack(col_train_dataset)
val_dataset = tf.stack(val_dataset)
col_val_dataset = tf.stack(col_val_dataset)
input1 = tf.placeholder(tf.float32, [None, 300, 500, 1])
color = tf.placeholder(tf.float32, [None, 300, 500, 3])
#MODEL
conv1 = tf.layers.conv2d(inputs = input1, filters = 8, kernel_size=[5, 5], activation=tf.nn.relu, padding = 'same')
pool1 = tf.layers.max_pooling2d(inputs = conv1, pool_size=[2, 2], strides=2)
conv2 = tf.layers.conv2d(inputs = pool1, filters = 16, kernel_size=[5, 5], activation=tf.nn.relu, padding = 'same')
pool2 = tf.layers.max_pooling2d(inputs = conv2, pool_size=[2, 2], strides=2)
conv3 = tf.layers.conv2d(inputs = pool2, filters = 32, kernel_size=[5, 5], activation=tf.nn.relu, padding = 'same')
pool3 = tf.layers.max_pooling2d(inputs = conv3, pool_size=[2, 2], strides=2)
flat = tf.layers.flatten(inputs = pool3)
dense = tf.layers.dense(flat, 2432, activation = tf.nn.relu)
reshaped = tf.reshape(dense, [tf.shape(dense)[0],38, 64, 1])
conv_trans1 = tf.layers.conv2d_transpose(inputs = reshaped, filters = 32, kernel_size=[5, 5], activation=tf.nn.relu, padding = 'same')
upsample1 = tf.image.resize_nearest_neighbor(conv_trans1, (2*tf.shape(conv_trans1)[1],2*tf.shape(conv_trans1)[2]))
conv_trans2 = tf.layers.conv2d_transpose(inputs = upsample1, filters = 16, kernel_size=[5, 5], activation=tf.nn.relu, padding = 'same')
upsample2 = tf.image.resize_nearest_neighbor(conv_trans2, (2*tf.shape(conv_trans2)[1],2*tf.shape(conv_trans2)[2]))
conv_trans3 = tf.layers.conv2d_transpose(inputs = upsample2, filters = 8, kernel_size=[5, 5], activation=tf.nn.relu, padding = 'same')
upsample3 = tf.image.resize_nearest_neighbor(conv_trans3, (2*tf.shape(conv_trans3)[1],2*tf.shape(conv_trans3)[2]))
conv_trans4 = tf.layers.conv2d_transpose(inputs = upsample3, filters = 3, kernel_size=[5, 5], activation=tf.nn.relu, padding = 'same')
reshaped2 = tf.reshape(dense, [tf.shape(conv_trans4)[0],300,500,3])
#TRAINING
loss = tf.losses.mean_squared_error(color, reshaped2)
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
EPOCHS = 10
BATCH_SIZE = 3
dataset = tf.data.Dataset.from_tensor_slices((train_dataset,col_train_dataset)).repeat().batch(BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(EPOCHS):
x,y=iterator.get_next()
_, loss_value = sess.run([train_step, loss],feed_dict={input1:x.eval(session=sess),color:y.eval(session=sess)})
print("Iter: {}, Loss: {:.4f}".format(i, loss_value))
最佳答案
我认为您的问题出在以下代码中。
def load_image(image_file):
image = tf.read_file(image_file)
image = tf.image.decode_jpeg(image)
return image
...
for i in listdir(train_path):
train_dataset.append(load_image(train_path + i))
col_train_dataset.append(load_image(col_train_path + i))
您正在尝试将 TF 张量运算用作常规代码。但最终得到的是图表上仅在 session 中进行评估的节点。在这种情况下,您尝试将训练和评估数据集中的每个图像加载到 GPU 内存中(因为您的 session 在 GPU 上运行)。我猜你的图像比 GPU 的内存还要多。
这个问题有多种解决方案。您可以将 tf.read_image 操作作为图形的一部分,并将每个批处理的图像名称作为训练循环中的 feed 字典传递。您可以构建一个适当的输入管道,其中文件名的加载、批处理和文件数据将在图中处理,或者您可以使用一些外部库将图像加载到 numpy 数组中,并将 numpy 数组输入到图中。
关于python - 在 Tensorflow 上训练卷积神经网络时 GPU 内存不足,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51934087/
我是一名优秀的程序员,十分优秀!