
python - TF Keras API with TF Dataset problem: steps_per_epoch argument issue

Reposted. Author: 行者123. Updated: 2023-11-28 16:57:57

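When I fit a Keras model written with the tensorflow.keras API and fed by an iterator created from a tf.data.Dataset, the model complains about the steps_per_epoch argument, even though I have already set it to a concrete value.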

Here is my model class:

```python
import tensorflow as tf
import numpy as np
from typing import Union, List
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras import layers
from tftools import TFTools  # the author's own helper module


class TestServe():
    # tfrecords holds TFRecord file paths (see the usage at the bottom), not tf.train.Example objects
    def __init__(self, tfrecords: Union[List[str], str], batch_size: int = 10, input_shape: tuple = (64, 23)) -> None:
        self.tfrecords = tfrecords
        self.batch_size = batch_size
        self.input_shape = input_shape

    def get_model(self):
        ins = layers.Input(shape=self.input_shape)

        l = layers.Reshape((*self.input_shape, 1))(ins)
        l = layers.Conv2D(8, (30, 23), padding='same', activation='relu')(l)
        l = layers.MaxPool2D((4, 5), strides=(4, 5))(l)
        l = layers.Conv2D(16, (3, 3), padding='same', activation='relu')(l)
        l = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(l)
        l = layers.MaxPool2D((2, 2), strides=(2, 2))(l)
        l = layers.Flatten()(l)

        # sigmoid, not softmax: softmax over a single unit always outputs 1.0
        out = layers.Dense(1, activation='sigmoid')(l)
        return tf.keras.models.Model(ins, out)

    def train(self):
        # Create Dataset
        dataset = TFTools.create_dataset(self.tfrecords)
        dataset = dataset.repeat(6).batch(self.batch_size)

        # take/skip run after batch(), so they count batches, not samples
        val_iterator = dataset.take(300).make_one_shot_iterator()
        train_iterator = dataset.skip(300).make_one_shot_iterator()

        model = self.get_model()
        model.summary()
        model.compile(optimizer='rmsprop',
                      loss='binary_crossentropy', metrics=['accuracy'])
        model.fit(train_iterator, validation_data=val_iterator,
                  epochs=10, verbose=1, steps_per_epoch=20)

    def predict(self, X: np.array) -> np.array:
        pass


ts = TestServe(['./ok.tfrecord', './nok.tfrecord'])
ts.train()
```
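Because `batch()` is applied before `take(300)` and `skip(300)`, both operations count whole batches rather than samples. The sketch below does that arithmetic for the pipeline above, assuming the roughly 1500 samples the question mentions and no `drop_remainder`; `split_batch_counts` is a hypothetical helper that mirrors the arithmetic only and is not part of TensorFlow. With those numbers the 300 validation batches are the first two passes over the data, so (absent shuffling that changes this) they contain the same samples that the training batches see afterwards; splitting before `repeat()` would avoid the overlap.

```python
import math


def split_batch_counts(num_samples: int, repeats: int, batch_size: int, take: int):
    """Return (validation_batches, training_batches) for
    dataset.repeat(repeats).batch(batch_size) split with take(take) / skip(take).

    Hypothetical helper: it reproduces the counting only, not TensorFlow itself.
    """
    # repeat() concatenates the data `repeats` times; batch() then groups it and,
    # with drop_remainder=False, keeps a smaller final batch (hence ceil).
    total_batches = math.ceil(num_samples * repeats / batch_size)
    # take()/skip() after batch() operate on whole batches.
    val_batches = min(take, total_batches)
    train_batches = total_batches - val_batches
    return val_batches, train_batches


val_b, train_b = split_batch_counts(1500, repeats=6, batch_size=10, take=300)
print(val_b, train_b)  # 300 600
```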

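But as soon as training starts, TensorFlow raises an exception just before the first epoch finishes (at step 19/20, when the validation pass begins):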

```
2019-06-13 14:22:25.393398: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1995445000 Hz
2019-06-13 14:22:25.393681: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x2f7d120 executing computations on platform Host. Devices:
2019-06-13 14:22:25.393708: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
Epoch 1/2
19/20 [===========================>..] - ETA: 0s - loss: 1.1921e-07 - acc: 1.0000
Traceback (most recent call last):
  File "TestServe.py", line 62, in <module>
    ts.train()
  File "TestServe.py", line 56, in train
    epochs=2, verbose=1, callbacks=callbacks, steps_per_epoch=20) #The steps_per_epoch is typically samples_per_epoch / batch_size
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 880, in fit
    validation_steps=validation_steps)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 364, in model_iteration
    validation_in_fit=True)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 202, in model_iteration
    steps_per_epoch)
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 76, in _get_num_samples_or_steps
    'steps_per_epoch')
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 230, in check_num_samples
    if check_steps_argument(ins, steps, steps_name):
  File "/home/josef/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 960, in check_steps_argument
    input_type=input_type_str, steps_name=steps_name))
ValueError: When using data tensors as input to a model, you should specify the `steps_per_epoch` argument.
```

The original dataset contains about 1500 samples, but I want to concatenate several tfrecord files into one TFRecordDataset, so I have no information about its length.
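If the length is needed to choose step counts, it can be recovered by walking the files once. The sketch below avoids TensorFlow entirely and relies on the documented TFRecord framing: each record is a little-endian uint64 length, a uint32 masked CRC of that length, the payload bytes, and a uint32 masked CRC of the payload. It assumes uncompressed files and skips the CRCs instead of verifying them; `count_tfrecords` is a hypothetical name, not a TensorFlow API.

```python
import struct
from typing import Iterable


def count_tfrecords(paths: Iterable[str]) -> int:
    """Count records across uncompressed TFRecord files (CRCs are not checked)."""
    total = 0
    for path in paths:
        with open(path, "rb") as f:
            while True:
                header = f.read(8)  # uint64 payload length, little-endian
                if not header:
                    break  # clean end of file
                if len(header) < 8:
                    raise ValueError(f"truncated record header in {path}")
                (length,) = struct.unpack("<Q", header)
                # Length CRC (4 bytes) + payload + payload CRC (4 bytes).
                rest = f.read(4 + length + 4)
                if len(rest) != 4 + length + 4:
                    raise ValueError(f"truncated record in {path}")
                total += 1
    return total
```

For the files in the question this would be called as `count_tfrecords(['./ok.tfrecord', './nok.tfrecord'])`. It reads every record once, so it costs one pass over the data.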

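Has anyone seen something like this? I don't know where to look for help, since the tf.keras API is relatively new. The create_dataset function simply returns the dataset mapped with the correct parse function.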

Best answer

Found the solution.

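Besides steps_per_epoch, fit() also has a validation_steps argument, and you must specify it as well when validation_data is an iterator. The error message only names steps_per_epoch because the validation pass appears to reuse the same check (note validation_in_fit=True in the traceback).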
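A sketch of choosing both values, assuming the question's pipeline yields 300 validation batches from `take(300)` and 600 training batches from `skip(300)` (about 1500 samples, `repeat(6)`, batch size 10), and assuming that one-shot iterators are never re-initialized between epochs, so every epoch draws fresh batches from the same iterator. `fit_step_kwargs` is a hypothetical helper; only the keys it returns are real `Model.fit` argument names.

```python
def fit_step_kwargs(train_batches: int, val_batches: int, epochs: int) -> dict:
    """Largest steps_per_epoch / validation_steps that keep iterators which are
    never re-initialized from running dry over `epochs` epochs (hypothetical helper).
    """
    if epochs <= 0:
        raise ValueError("epochs must be positive")
    steps_per_epoch = train_batches // epochs
    validation_steps = val_batches // epochs
    if steps_per_epoch == 0 or validation_steps == 0:
        raise ValueError("not enough batches for this many epochs")
    return {"epochs": epochs,
            "steps_per_epoch": steps_per_epoch,
            "validation_steps": validation_steps}


kwargs = fit_step_kwargs(train_batches=600, val_batches=300, epochs=10)
print(kwargs)  # {'epochs': 10, 'steps_per_epoch': 60, 'validation_steps': 30}
# The fix from the answer is passing validation_steps next to steps_per_epoch, e.g.:
# model.fit(train_iterator, validation_data=val_iterator, verbose=1, **kwargs)
```

Smaller values, such as the question's steps_per_epoch=20, also work; the division only gives an upper bound under the stated assumption.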

The original question, "python - TF keras API with TF dataset problem: steps_per_epoch argument issue", is on Stack Overflow: https://stackoverflow.com/questions/56580538/
