gpt4 book ai didi

python - 在 tensorflow.data.Dataset.map() 函数中完成的操作梯度

转载 作者:行者123 更新时间:2023-12-01 05:54:13 26 4
gpt4 key购买 nike

我有一个包含 X 和 Y 部分的数据集。 X 在输入到神经网络之前需要变成 D。

我使用 tf.data.Dataset类来做到这一点:

# Making the place holders
X = tf.placeholder(shape=[n_samples, n_atoms, 3], dtype=tf.float32)
Y = tf.placeholder(shape=[n_samples, 1], dtype=tf.float32)

# Creating the data set
dataset = tf.data.Dataset.from_tensor_slices((X, Y))

# Transforming X to D using the map function
dataset = dataset.map(X_to_D)
dataset = dataset.batch(200)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
batch_D, batch_Y = iterator.get_next()

哪里函数 X_to_D是一个采用 X 的 tensorflow 函数和 Y张量作为输入并返回 DY张量。
D然后分批拆分并用作神经网络的输入。神经网络的输出是 Y_prediction .

我需要获得 Y_prediction 的梯度关于 X .但是,在尝试时:
gradients = tf.gradients(Y_prediction, X)

出现错误:

LookupError: gradient registry has no entry for: IteratorGetNext LookupError: No gradient defined for operation 'IteratorGetNext_1' (op type: IteratorGetNext)



题:
看来很容易得到 Y_prediction的梯度关于 D .但是,我将如何计算 Y_prediction 的梯度关于 X ?

笔记: X_to_D函数是非常内存密集的,只能在非常小批量的数据上完成。所以我无法创建数据集,分批拆分并从 X 进行转换。至 D就在每个批次用于训练之前。这是因为用于训练的批量大小对于制作 X 来说太大了。至 D转型。

最佳答案

使用 tensorflow 2.0,您可以编写自定义模型,这允许您计算导数 w.r.t.输入。例如,

class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel,self).__init__(name = 'my_model')
self.dense_1 = layers.Dense(32,activation = 'relu', input_dim=2)
self.dense_2 = layers.Dense(64,activation=tf.sin)
self.dense_3 = layers.Dense(1)
def call(self, inputs):
# Define your forward pass here
x = self.dense_1(inputs)
x = self.dense_2(x)
return self.dense_3(x)

model = MyModel()
optimizer = tf.keras.optimizers.RMSprop(0.001)
model.compile(loss='mse',
optimizer=optimizer,
metrics=['mae', 'mse'])
history = model.fit(X_train, y_train, epochs=10, batch_size = 1,
validation_split = 0.2, verbose=0)

计算导数:
x = tf.constant(X_train[:1,:])
with tf.GradientTape() as g:
g.watch(x)
y = model.call(x)
dy_dx = g.gradient(y, x)
print(y)
dy_dx

关于python - 在 tensorflow.data.Dataset.map() 函数中完成的操作梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50784337/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com