
tensorflow - Is the validation dataset initialized/created in every epoch during training?


Setup:

  • A U-Net network is trained to work on small patches (e.g. 64x64 pixels).
  • The network is fed with training and validation datasets through the TensorFlow Dataset API.
  • The small patches are created by (randomly) sampling larger images.
  • The sampling of the image patches takes place during training (both training and validation patches are cropped on the fly).
  • TensorFlow 2.1 (eager execution mode)

  • The training and validation datasets are created in exactly the same way:
    dataset = tf.data.Dataset.from_tensor_slices((large_images, large_targets))
    dataset = dataset.shuffle(buffer_size=num_large_samples)
    dataset = dataset.map(get_patches_from_large_images, num_parallel_calls=num_parallel_calls)
    dataset = dataset.unbatch()
    dataset = dataset.shuffle(buffer_size=num_small_patches)
    dataset = dataset.batch(patches_batch_size)
    dataset = dataset.prefetch(1)
    dataset = dataset.repeat()

    The function get_patches_from_large_images samples a predefined number of small patches from a single large image using tf.image.random_crop. There are two nested loops, for and while. The outer for loop is responsible for generating the predefined number of small patches, and the while loop checks whether a patch randomly generated with tf.image.random_crop satisfies some predefined criteria (for example, patches containing only background should be discarded). The inner while loop gives up if it cannot generate a suitable patch within some predefined number of iterations, so we do not get stuck in it. The approach is based on the solution provided here (a simplified sketch of the whole mapping function follows the pseudocode below).
    patches = []  # collects all patches sampled from this large image
    for i in range(number_of_patches_from_one_large_image):
        num_tries = 0
        while num_tries < max_num_tries_before_giving_up:
            patch = tf.image.random_crop(large_input_and_target_image, [patch_size, patch_size, 2])
            if patch_meets_some_criterions:
                break
            num_tries = num_tries + 1
        patches.append(patch)
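
    For completeness, here is a minimal sketch of how such a mapping function could be written so that it plugs into the map/unbatch pipeline above. It is an illustration only: the retry/criterion logic is omitted, single-channel images of shape [H, W, 1] are assumed, and number_of_patches_from_one_large_image and patch_size are assumed to be plain Python constants (names taken from the pseudocode, not from the original implementation):

    def get_patches_from_large_images(large_image, large_target):
        # Crop input and target together so the same region is taken from both.
        combined = tf.concat([large_image, large_target], axis=-1)  # [H, W, 2]
        patches = []
        for _ in range(number_of_patches_from_one_large_image):  # Python int, unrolled when the map function is traced
            patches.append(tf.image.random_crop(combined, [patch_size, patch_size, 2]))
        stacked = tf.stack(patches)  # [num_patches, patch_size, patch_size, 2]
        # Returning a batched pair lets dataset.unbatch() split it into single (input, target) patches.
        return stacked[..., :1], stacked[..., 1:]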

    Experiment:
  • The training and validation datasets used to feed the model are identical (5 large input-target image pairs), and both datasets produce exactly the same number of small patches from a single large image.
  • The batch_size used for training and validation is the same and equals 50 image patches.
  • steps_per_epoch and validation_steps are equal (20 batches).

  • Running the training with validation_freq=5:
    unet_model.fit(dataset_train, epochs=10, steps_per_epoch=20, validation_data = dataset_val, validation_steps=20, validation_freq=5)


    Train for 20 steps, validate for 20 steps
    Epoch 1/10
    20/20 [==============================] - 44s 2s/step - loss: 0.6771 - accuracy: 0.9038
    Epoch 2/10
    20/20 [==============================] - 4s 176ms/step - loss: 0.4952 - accuracy: 0.9820
    Epoch 3/10
    20/20 [==============================] - 4s 196ms/step - loss: 0.0532 - accuracy: 0.9916
    Epoch 4/10
    20/20 [==============================] - 4s 194ms/step - loss: 0.0162 - accuracy: 0.9942
    Epoch 5/10
    20/20 [==============================] - 42s 2s/step - loss: 0.0108 - accuracy: 0.9966 - val_loss: 0.0081 - val_accuracy: 0.9975
    Epoch 6/10
    20/20 [==============================] - 1s 36ms/step - loss: 0.0074 - accuracy: 0.9978
    Epoch 7/10
    20/20 [==============================] - 4s 175ms/step - loss: 0.0053 - accuracy: 0.9985
    Epoch 8/10
    20/20 [==============================] - 3s 169ms/step - loss: 0.0034 - accuracy: 0.9992
    Epoch 9/10
    20/20 [==============================] - 3s 171ms/step - loss: 0.0023 - accuracy: 0.9995
    Epoch 10/10
    20/20 [==============================] - 43s 2s/step - loss: 0.0016 - accuracy: 0.9997 - val_loss: 0.0013 - val_accuracy: 0.9998

    We can see that the first epoch and the epochs with validation (every 5th epoch) take much more time than the epochs without validation. The same experiment, but this time with validation run in every epoch, gives the following results:
    history = unet_model.fit(dataset_train, epochs=10, steps_per_epoch=20, validation_data = dataset_val, validation_steps=20)
    Train for 20 steps, validate for 20 steps
    Epoch 1/10
    20/20 [==============================] - 84s 4s/step - loss: 0.6775 - accuracy: 0.8971 - val_loss: 0.6552 - val_accuracy: 0.9542
    Epoch 2/10
    20/20 [==============================] - 41s 2s/step - loss: 0.5985 - accuracy: 0.9833 - val_loss: 0.4677 - val_accuracy: 0.9951
    Epoch 3/10
    20/20 [==============================] - 43s 2s/step - loss: 0.1884 - accuracy: 0.9950 - val_loss: 0.0173 - val_accuracy: 0.9948
    Epoch 4/10
    20/20 [==============================] - 44s 2s/step - loss: 0.0116 - accuracy: 0.9962 - val_loss: 0.0087 - val_accuracy: 0.9969
    Epoch 5/10
    20/20 [==============================] - 44s 2s/step - loss: 0.0062 - accuracy: 0.9979 - val_loss: 0.0051 - val_accuracy: 0.9983
    Epoch 6/10
    20/20 [==============================] - 45s 2s/step - loss: 0.0039 - accuracy: 0.9989 - val_loss: 0.0033 - val_accuracy: 0.9991
    Epoch 7/10
    20/20 [==============================] - 44s 2s/step - loss: 0.0025 - accuracy: 0.9994 - val_loss: 0.0023 - val_accuracy: 0.9995
    Epoch 8/10
    20/20 [==============================] - 44s 2s/step - loss: 0.0019 - accuracy: 0.9996 - val_loss: 0.0017 - val_accuracy: 0.9996
    Epoch 9/10
    20/20 [==============================] - 44s 2s/step - loss: 0.0014 - accuracy: 0.9997 - val_loss: 0.0013 - val_accuracy: 0.9997
    Epoch 10/10
    20/20 [==============================] - 45s 2s/step - loss: 0.0012 - accuracy: 0.9998 - val_loss: 0.0011 - val_accuracy: 0.9998

    Question:
    In the first example, we can see that the initialization/creation of the training dataset (dataset_train) took about 40 seconds. The subsequent epochs (without validation) were much shorter, however, taking about 4 seconds each. Nevertheless, for the epochs with a validation step, the duration went up again to about 40 seconds. The validation dataset (dataset_val) is exactly the same as the training dataset (dataset_train), so its creation/initialization also takes about 40 seconds. What surprises me is that every validation step is this time-consuming. I expected the first validation to take 40 seconds, but the following validations to take only about 4 seconds. I thought the validation dataset would behave like the training dataset, so the first fetch would take long but subsequent fetches would be much shorter. Am I right, or am I missing something?

    Update:
    I have checked that creating an iterator from the dataset takes about 40 seconds:
    dataset_val_it = iter(dataset_val) #40s
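
    The measurement itself is plain wall-clock timing around the iterator creation and the first batch (a sketch of how it can be reproduced, not the exact code used):

    import time

    start = time.time()
    dataset_val_it = iter(dataset_val)      # building the iterator sets up the whole input pipeline
    print('iterator creation: %.1f s' % (time.time() - start))

    start = time.time()
    first_batch = next(dataset_val_it)      # subsequent batches arrive much faster
    print('first batch: %.1f s' % (time.time() - start))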

    If we look inside the fit function, we can see that the data_handler object is created once for the whole training and that it returns the data iterator used in the main loop of the training process. The iterator is created by calling the function enumerate_epochs. When the fit function wants to perform the validation process, it calls the evaluate function. Whenever the evaluate function is called, it creates a new data_handler object. It then calls the enumerate_epochs function, which in turn creates the iterator from the dataset. Unfortunately, in the case of complicated datasets this process is very time-consuming.
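
    If that reading is correct, every call to evaluate pays the iterator-creation cost again, which can be checked directly by timing repeated evaluate calls (a sketch under that assumption, using the unet_model and dataset_val objects defined above):

    import time

    for run in range(2):
        start = time.time()
        unet_model.evaluate(dataset_val, steps=20)
        # If a new data_handler/iterator is built on every call, both runs include the ~40 s setup.
        print('evaluate run %d took %.1f s' % (run, time.time() - start))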

    Best Answer

    If you just want a quick fix to speed up your input pipeline, you can try caching the elements of the validation dataset.
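
    For example, the validation pipeline could be cached right after the expensive cropping step (a sketch based on the pipeline from the question; the shuffle steps are dropped because shuffling is not needed for validation, and note that cached patches are then fixed instead of being re-sampled every epoch):

    dataset_val = tf.data.Dataset.from_tensor_slices((large_images, large_targets))
    dataset_val = dataset_val.map(get_patches_from_large_images, num_parallel_calls=num_parallel_calls)
    dataset_val = dataset_val.unbatch()
    dataset_val = dataset_val.cache()      # keep the cropped patches in memory after the first pass
    dataset_val = dataset_val.batch(patches_batch_size)
    dataset_val = dataset_val.prefetch(1)
    dataset_val = dataset_val.repeat()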




    I have never dug very deep into the tf.data code, but you seem to have a point here. I think it would be interesting to open an issue about this on GitHub.

    Regarding "tensorflow - Is the validation dataset initialized/created in every epoch during training?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/61835616/
