
python - Dynamic __len__ for a Tensorflow Sequence extension?

Reposted. Author: 行者123  Updated: 2023-12-03 20:53:50

Basically, I want a generator whose __len__ magic method is re-read at every epoch, so that the number of batches that epoch will run is recomputed.

Here is a code snippet:

import tensorflow as tf
import numpy as np

class GeneratorFile(tf.keras.utils.Sequence):
    def __init__(self, file_list):
        self.file_list = file_list
        self.desired_file = self.file_list[0]
        print('This should be file 1:', self.desired_file)

    def __len__(self):
        if self.desired_file == 'file1':
            return 2
        else:
            return 3

    def on_epoch_end(self):
        self.desired_file = self.file_list[1]
        print('This should be file 2:', self.desired_file)

    def __getitem__(self, item):
        return np.zeros((16, 1)), np.zeros((16,))

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_dim=1, activation="softmax"))
model.compile(
    optimizer='Adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

model_file_train = ['file1', 'file2']
generator_train = GeneratorFile(model_file_train)
model.fit(generator_train, epochs=2, initial_epoch=0)

In the __len__ magic method I want to switch the number of batches for the epoch whenever the file changes, which happens at the end of the first epoch. Currently, however, __len__ appears to run once at the start of training and never again. Is there any way to change that?
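The desired behavior can be sketched without TensorFlow. The fit function below is an assumption, a plain-Python stand-in for the Keras training loop, not the real implementation; it simply re-reads len() at the start of every epoch, which is the behavior the question is after:

```python
# Sketch only: FileGenerator and fit() are hypothetical stand-ins,
# not Keras code.

class FileGenerator:
    def __init__(self, file_list):
        self.file_list = file_list
        self.desired_file = file_list[0]

    def __len__(self):
        # Recomputed every time len() is called.
        return 2 if self.desired_file == 'file1' else 3

    def __getitem__(self, i):
        # Dummy "batch" keyed by the current file.
        return f'{self.desired_file}-batch{i}'

    def on_epoch_end(self):
        self.desired_file = self.file_list[1]

def fit(generator, epochs):
    """Stand-in training loop: queries len() once per epoch."""
    batches_per_epoch = []
    for _ in range(epochs):
        n = len(generator)      # dynamic length, re-read each epoch
        for i in range(n):
            _ = generator[i]    # "train" on the batch
        generator.on_epoch_end()
        batches_per_epoch.append(n)
    return batches_per_epoch

print(fit(FileGenerator(['file1', 'file2']), epochs=2))  # → [2, 3]
```

Whether Keras itself behaves this way depends on the version, which is what the answer below tests.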

Best Answer

I tried some simple code to see whether the __len__ function is called once per program or once per epoch. It turns out it is called multiple times every epoch. Outputs for different versions of tensorflow are given below.

Code -

import numpy as np
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import layers, models
from tensorflow.keras.utils import Sequence

FEATURE_SIZE = 512 ** 2

class DataGenerator(Sequence):
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def __len__(self):
        print("in __len__")
        return 1

    def __getitem__(self, i):
        # Some dummy data
        return np.ones((self.batch_size, FEATURE_SIZE)), np.ones((self.batch_size, 1))

    def on_epoch_end(self):
        print('on_epoch_end() called')

def train(batch_size):
    print('Training with batch_size =', batch_size)
    training_generator = DataGenerator(batch_size)
    test_generator = DataGenerator(batch_size)
    model = models.Sequential()
    model.add(layers.Dense(4, activation='sigmoid', input_shape=[FEATURE_SIZE]))
    model.add(layers.Dense(1, activation='sigmoid', input_shape=[FEATURE_SIZE]))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.fit(training_generator, validation_data=test_generator, epochs=5, verbose=0)

train(batch_size=1)

Output - for tensorflow version 1.15.2:
1.15.2
Training with batch_size = 1
in __len__
in __len__
on_epoch_end() called
in __len__
in __len__
in __len__
in __len__
on_epoch_end() called
in __len__
in __len__
on_epoch_end() called
in __len__
in __len__
on_epoch_end() called
in __len__
in __len__
on_epoch_end() called

Output - for tensorflow version 2.2.0:
2.2.0
Training with batch_size = 1
in __len__
in __len__
in __len__
in __len__
in __len__
on_epoch_end() called
on_epoch_end() called
in __len__
in __len__
in __len__
in __len__
on_epoch_end() called
on_epoch_end() called
in __len__
in __len__
in __len__
in __len__
on_epoch_end() called
on_epoch_end() called
in __len__
in __len__
in __len__
in __len__
on_epoch_end() called
on_epoch_end() called
in __len__
in __len__
in __len__
in __len__
on_epoch_end() called
on_epoch_end() called

You can try the same and see for yourself.

Regarding "python - Dynamic __len__ for a Tensorflow Sequence extension?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/61733161/
