gpt4 book ai didi

python - numpy 数组输入到 tensorflow/keras 神经网络的 dtype 有什么关系吗?

转载 作者:行者123 更新时间:2023-12-05 04:59:40 24 4
gpt4 key购买 nike

如果我采用 tensorflow.keras 模型并调用 model.fit(x, y)(其中 xy 是 numpy 数组) numpy 数组的 dtype 有什么关系吗?我最好只是让 dtype 尽可能小(例如 int8 用于二进制数据)还是这会给 tensorflow/keras 额外的工作来将其转换为 float ?

最佳答案

您应该将输入转换为 np.float32,这是 Keras 的默认数据类型。查一下:

import tensorflow as tf
tf.keras.backend.floatx()
'float32'

如果你在 np.float64 中给 Keras 输入,它会报错:

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn.datasets import load_iris
iris, target = load_iris(return_X_y=True)

X = iris[:, :3]
y = iris[:, 3]

ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)

class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.d0 = Dense(16, activation='relu')
self.d1 = Dense(32, activation='relu')
self.d2 = Dense(1, activation='linear')

def call(self, x):
x = self.d0(x)
x = self.d1(x)
x = self.d2(x)
return x

model = MyModel()

_ = model(X)

WARNING:tensorflow:Layer my_model is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.To change all layers to have dtype float64 by default, call tf.keras.backend.set_floatx('float64'). To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

通过 8bit input 可以使用 Tensorflow 进行训练,这被称为量化。但在大多数情况下,这具有挑战性且没有必要(即,除非您需要在边缘设备上部署模型)。

tl;dr 将您的输入保存在 np.float32 中。另见 this post .

关于python - numpy 数组输入到 tensorflow/keras 神经网络的 dtype 有什么关系吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63424782/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com