gpt4 book ai didi

python - Tensorflow 数据集是否会在历元之间进行洗牌,并且数据集在洗牌后会进行转换?

转载 作者:太空宇宙 更新时间:2023-11-03 20:39:14 24 4
gpt4 key购买 nike

我正在开发一个 TensorFlow 管道,将一堆信号加载到数据集中,对这些信号进行洗牌,然后对信号进行加窗,然后进行批处理并重复。该数据集用于通过 model.fit 函数调用来训练 tf.keras 模型。信号窗口不被打乱非常重要,这就是为什么这是数据集变换的顺序。

我想知道信号的顺序是否会在纪元之间打乱?我发现 dataset.shuffle().batch().repeat() 会在纪元之间对数据集进行随机播放,但这对我的应用程序不起作用,因为我需要进行窗口和其他转换洗牌后。

我使用的是 TensorFlow 版本 1.13.1。

#... some pre-processing on the signals 
signalList = [...] # a list of tuples (data, label)
dataset = tf.data.Dataset.from_generator(lambda: signalList)
dataset = dataset.shuffle(buffer_size=self.buffer_size) ## will this shuffle be repeated??
dataset = dataset.map(...) # windowing and other transforms
dataset = dataset.batch()
dataset = dataset.repeat()

model.fit(dataset, ...)

编辑:我感兴趣的行为是我希望每个时期重新调整信号的顺序。所以,如果我有 3 个信号

signal0=[window0_0,window0_1]
signal1=[window1_0,window1_1,window1_2]
signal2=[window2_0]

那么输出将如下所示:

tf.Tensor([signal0,signal2,signal1],...) # equivalent to tf.Tensor([window0_0,window0_1,window2_0,window1_0,window1_1,window1_2])
tf.Tensor([signal1,signal0,signal2],...) # equivalent to tf.Tensor([window1_0,window1_1,window1_2,window0_0,window0_1,window2_0])

其中转换 datset.map(windowing).shuffle().batch().repeat() 会产生类似这样的东西(我对此不感兴趣)

tf.Tensor([window0_1,window1_1,window2_0,window1_0,window0_0,window1_2])
tf.Tensor([window0_0,window1_2,window0_1,window2_0,window1_1,window1_0])

最佳答案

经过一番调查,我意识到,是的,shuffle 会在每个 epoch 之后调用,即使在 shuffle 之后和批处理之前还有其他转换。我不确定这对管道意味着什么(例如,我不确定是否在每个时期都会调用窗口并减慢处理速度),但我创建了一个 jupyter 笔记本,在其中创建了一个小版本的管道

signalList = [...] # a list of tuples (data, label)
dataset = tf.data.Dataset.from_generator(lambda: signalList)
dataset = dataset.shuffle(buffer_size=self.buffer_size)
dataset = dataset.map(...) # windowing and other transforms
dataset = dataset.batch()
dataset = dataset.repeat()

创建了一个迭代器

iterator = dataset.make_one_shot_iterator()

并绘制了几个时期的信号

next_ = iterator.get_next()
for i in range(10): # 10 epochs
full_signal = []
for j in range(29): # 29 events for this epoch
next_ = iterator.get_next()
full_signal = np.concatenate((full_signal, next_[0][0]), axis=None)

fig = plt.figure(figsize=(18, 5))
plt.plot(full_signal)

并且发现信号看起来总是处于不同的顺序,这意味着它们在每个纪元之后都会重新洗牌。

如果有人有更详细的答案,他们可以解释它如何与 DatasetAPI 编译一起工作,或者如果他们可以澄清这些转换的顺序是否会减慢管道速度,我将不胜感激!

关于python - Tensorflow 数据集是否会在历元之间进行洗牌,并且数据集在洗牌后会进行转换?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56960577/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com