gpt4 book ai didi

tensorflow - 以 double 运行 Keras 失败

转载 作者:行者123 更新时间:2023-12-03 16:47:24 28 4
gpt4 key购买 nike

我试图在 Keras 上以 double 运行 LeNet,但失败并显示错误:TypeError: Input 'filter' of 'Conv2D' Op has type float64 that does not match type float32 of argument 'input'. .我使用的代码如下:

import numpy as np
from sklearn.utils import shuffle
import keras
from keras.models import Sequential
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Dropout,Flatten
from keras import backend as K
from keras.models import Model
from keras.utils import np_utils
import time
import tensorflow as tf
K.set_floatx('float64') # Note: the code works if we comment this line, i.e., with single precision
from mlxtend.data import mnist_data
X, y = mnist_data()
X = X.astype(np.float64)
X, y = shuffle(X, y)
keras_model = Sequential()
keras_model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28,28,1), padding='same'))
keras_model.add(MaxPooling2D(pool_size=(2, 2)))
keras_model.add(Conv2D(64, (5, 5), activation='relu', padding='same'))
keras_model.add(MaxPooling2D(pool_size=(2, 2)))
keras_model.add(Flatten())
keras_model.add(Dense(512, activation='relu'))
keras_model.add(Dropout(0.5))
keras_model.add(Dense(10, activation='softmax'))
keras_model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.95, decay=5e-4, nesterov=True))
keras_model.fit(X.reshape((-1, 28,28, 1)), np_utils.to_categorical(y, 10), epochs=1, batch_size=64)

任何建议都非常感谢:)

最佳答案

你有一个游戏 NVIDIA GPU。
您只能使用 float32int32 , 就这样。
这是 TensorFlow 的默认设置。
由于 Nvidia 具有 CUDA 功能的 GPU 的限制,Tensorflow 引入了此默认值。 Best explanation I found here.因此,优质 Tesla GPU 在 float16 上运行良好和 float64同样,但游戏 GPU 仅适用于 float32并且对 float16 的表现非常糟糕或 float64 .
我认为我们都在关注价格更高的 AMD GPU 支持的 OpenCL。不幸的是,目前 TensorFlow 不支持 OpenCL。

建议#1:
你被float32困住了。拥有该硬件时忘记更改它。

建议#2:
一旦你使用 float16 获得了良好的 GPU,就改变它。机器学习不需要高于此的高精度。

关于tensorflow - 以 double 运行 Keras 失败,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48552508/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com