gpt4 book ai didi

python - tf.data.Dataset.map() 用于由多个切片组成的数据集

转载 作者:太空宇宙 更新时间:2023-11-03 20:18:47 25 4
gpt4 key购买 nike

从单个切片创建的数据集的 tf.data.Dataset.map() 看起来像 dataset.map(lambda x: x/2)。如果数据集是由两个切片创建的,会是什么样子?例如,请参见以下代码。代码最后一行中的 map() 函数适用于从单个切片创建的数据集,但对于我的两切片情况会导致错误。

import tensorflow as tf, numpy as np     # tensorflow 2.0
from tensorflow import keras as kr

dataset = tf.data.Dataset.from_tensor_slices((features_int8, labels_int8)) # features, labels are numpy arrays

model = kr.Sequential()
model.add(kr.layers.InputLayer(6)
model.add(kr.layers.Dense( 8, activation=tf.nn.tanh))
model.add(kr.layers.Dense( 3, activation=tf.nn.tanh))

model.compile(optimizer = kr.optimizers.RMSprop(), loss = kr.losses.MeanSquaredError())

model.fit(dataset.batch(64).map(lambda x: x/9), epochs = 10)

最佳答案

将 lambda 函数传递到单独的函数中,如图所示

def map_fn(x, y):
return x / 9, y

model.fit(dataset.batch(64).map(map_fn), epochs = 10)

关于python - tf.data.Dataset.map() 用于由多个切片组成的数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58283724/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com