gpt4 book ai didi

python - 传递无限重复的数据集时,必须指定 `steps_per_epoch` 参数

转载 作者:行者123 更新时间:2023-12-04 12:25:05 24 4
gpt4 key购买 nike

我正在尝试使用这个 Google 的示例,但使用我自己的数据集:

https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_customization/demo/text_classification.ipynb

我创建了一个文件夹,类似于在他们的代码中下载的内容,其中包含训练和测试文件夹以及 txt 文件。

在我的情况下 data_path 如下:data_path = '/Users/developer/.keras/datasets/chat'
每当我尝试运行它时 model = text_classifier.create(train_data)抛出错误ValueError: When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument.这甚至意味着什么,我应该在哪里寻找问题?


import numpy as np
import os
import tensorflow as tf
assert tf.__version__.startswith('2')

from tensorflow_examples.lite.model_customization.core.data_util.text_dataloader import TextClassifierDataLoader
from tensorflow_examples.lite.model_customization.core.model_export_format import ModelExportFormat
import tensorflow_examples.lite.model_customization.core.task.text_classifier as text_classifier


# data_path = tf.keras.utils.get_file(
# fname='aclImdb',
# origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
# untar=True)

data_path = '/Users/developer/.keras/datasets/chat'

train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), class_labels=['greeting', 'goodbye'])
test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), shuffle=False)

model = text_classifier.create(train_data)
loss, acc = model.evaluate(test_data)
model.export('movie_review_classifier.tflite', 'text_label.txt', 'vocab.txt')

最佳答案

问题是,当您针对所需的时期数训练模型时,您的训练代码部分可能无法确定特定时期何时开始以及该时期何时结束。
因此,在训练期间,可以添加“steps_per_epoch”参数,以便它知道如何针对单个时期的特定有限步数进行操作和训练。
在验证的情况下,我们添加特定的“validation_steps”来解决相同的问题。
model.fit
我通过将 steps_per_epoch 和 validation_steps 参数添加到我的 tf.Keras model.fit() 代码中解决了这个问题,就像上面一样。
需要总结如何在代码中提供这些参数。
引用:

  • https://keras.io/api/models/model_training_apis/#fit-method
  • https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/
  • 关于python - 传递无限重复的数据集时,必须指定 `steps_per_epoch` 参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58842925/

    24 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com