gpt4 book ai didi

python - 有效地向模型中的所有可训练权重添加噪声

转载 作者:行者123 更新时间:2023-12-04 03:50:14 26 4
gpt4 key购买 nike

如何在 tf.function 中高效地访问 Keras 模型的所有可训练变量,以便向全部变量添加自定义噪声?
让我们假设这个简单的模型:

# Three fully connected layers: 40 inputs -> 300 -> 200 -> 8 sigmoid outputs.
my_model = Sequential([
    Dense(300, input_dim=40, activation='relu'),
    Dense(200, activation='relu'),
    Dense(8, activation='sigmoid'),
])
在 Eager 模式下运行时,我可以通过以下方式做到这一点:
@tf.function
def weight_perturbation(model, generator):
    """Attempt to add normal noise to every trainable variable of ``model``.

    This is the failing version: it works eagerly, but when traced by
    ``tf.function`` it raises
    ``TypeError: list indices must be integers or slices, not Tensor``.

    Args:
        model: Keras model whose trainable variables should be perturbed.
        generator: object providing ``normal(shape, dtype=...)``
            (presumably a ``tf.random.Generator`` -- confirm with caller).
    """
    n_layers = len(model.layers)

    # iterate over all layers
    # NOTE: inside tf.function, a loop over tf.range is converted into a
    # graph loop, so `i` is a Tensor; indexing the Python list `model.layers`
    # with a Tensor is what raises the TypeError.
    for i in tf.range(n_layers):
        trainable_weights = model.layers[i].trainable_variables

        # iterate over all weight vectors in a layer
        for j in tf.range(len(trainable_weights)):
            # NOTE(review): the noise dtype (float64) must match the variable's
            # dtype; Keras layers are typically float32 -- confirm before reuse.
            trainable_weights[j].assign_add(generator.normal(tf.shape(trainable_weights[j]), dtype=tf.float64))

但是,在非 Eager 模式(即在 tf.function 图模式)下运行时,我收到以下错误:
     trainable_weights = model.layers[i].trainable_variables

TypeError: list indices must be integers or slices, not Tensor
我怎样才能规避这个问题?我看过 tf.gather,但它似乎不适用于 Python 列表。
最理想的是,有一种方法能不经循环就把模型的所有可训练权重作为一个扁平张量获取。遗憾的是,我还没有找到这样的方法。

最佳答案

如果直接遍历层和权重(而不是遍历 tf.range),它对我来说是可行的:

def weight_perturbation(model, minval=1e-5, maxval=1e-4):
    """Add uniform random noise, in place, to every trainable variable of ``model``.

    ``model.trainable_variables`` is already a flat Python list covering all
    layers. Looping over it directly (instead of indexing with ``tf.range``)
    avoids Tensor-valued list indices, so the function can also be traced by
    ``tf.function``.

    Args:
        model: Keras model whose trainable variables are modified in place.
        minval: inclusive lower bound of the uniform noise.
        maxval: exclusive upper bound of the uniform noise; must exceed
            ``minval``. With the defaults the noise is strictly positive;
            pass symmetric bounds (e.g. ``-1e-4, 1e-4``) for zero-mean noise.
    """
    for weight in model.trainable_variables:
        # Draw noise in the variable's own dtype so assign_add never fails on
        # a float32/float64 mismatch.
        noise = tf.random.uniform(tf.shape(weight), minval, maxval,
                                  dtype=weight.dtype)
        weight.assign_add(noise)
我将您的数据类型改为 tf.float32,大多数情况下都应该使用这个类型。下面是加入了权重扰动操作的完整示例:
import tensorflow as tf
from tensorflow import keras as K
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPooling2D, Dropout
from tensorflow import nn as nn
from functools import partial

# Load MNIST as arrays and wrap each split in a tf.data pipeline.
(xtrain, ytrain), (xtest, ytest) = tf.keras.datasets.mnist.load_data()

train, test = (
    tf.data.Dataset.from_tensor_slices(pair)
    for pair in ((xtrain, ytrain), (xtest, ytest))
)

def prepare(inputs, outputs):
    """Turn one (image, label) pair into model-ready tensors.

    The image is cast to float32, divided by 255 and given a trailing
    channel axis; the label becomes a one-hot vector of depth 10.
    """
    scaled = tf.cast(inputs, tf.float32) / 255
    return tf.expand_dims(scaled, -1), tf.one_hot(outputs, depth=10)

# Preprocess every example, then group both splits into mini-batches of 64.
train, test = (split.map(prepare).batch(64) for split in (train, test))

class MyCNN(K.Model):
    """Small CNN: two conv/max-pool stages, then a dense head with dropout.

    The final layer is a 10-unit softmax, so ``call`` returns probabilities
    rather than logits.
    """

    def __init__(self):
        super().__init__()
        Conv = partial(Conv2D, kernel_size=(3, 3), activation=nn.relu)
        MaxPool = partial(MaxPooling2D, pool_size=(2, 2))

        self.conv1 = Conv(filters=8)
        self.maxp1 = MaxPool()
        self.conv2 = Conv(filters=8)
        self.maxp2 = MaxPool()
        self.flatt = Flatten()
        self.dens1 = Dense(8, activation=nn.relu)
        self.drop1 = Dropout(.5)
        self.dens2 = Dense(10, activation=nn.softmax)

    def call(self, x, training=None, **kwargs):
        """Forward pass; ``training`` controls whether dropout is applied."""
        x = self.conv1(x)
        x = self.maxp1(x)
        x = self.conv2(x)
        x = self.maxp2(x)
        x = self.flatt(x)
        x = self.dens1(x)
        # Forward `training` explicitly: implicit call-context propagation
        # does not happen when `call` is invoked directly, which would leave
        # dropout permanently inactive.
        x = self.drop1(x, training=training)
        return self.dens2(x)

# MyCNN ends in a softmax, so the loss receives probabilities, not logits.
loss_object = tf.losses.CategoricalCrossentropy(from_logits=False)

model = MyCNN()

def compute_loss(model, x, y, training):
    """Run ``model`` on ``x`` and score its predictions against ``y``.

    Returns:
        ``(loss, predictions)``, where the loss comes from the module-level
        ``loss_object``.
    """
    predictions = model(inputs=x, training=training)
    return loss_object(y_true=y, y_pred=predictions), predictions

def get_grad(model, x, y, training=False):
    """Compute the loss, its gradients and the model outputs for one batch.

    Args:
        model: Keras model to evaluate.
        x: input batch.
        y: targets compatible with ``loss_object``.
        training: forwarded to the model call. Defaults to ``False`` (the
            previously hard-coded value) for backward compatibility; pass
            ``True`` for an optimisation step so that layers such as
            ``Dropout`` are active.

    Returns:
        ``(loss, gradients, outputs)``; ``gradients`` follows the order of
        ``model.trainable_variables``.
    """
    with tf.GradientTape() as tape:
        loss, out = compute_loss(model, x, y, training=training)
    return loss, tape.gradient(loss, model.trainable_variables), out

def weight_perturbation(model, minval=1e-5, maxval=1e-4):
    """Add uniform random noise, in place, to every trainable variable of ``model``.

    ``model.trainable_variables`` is already a flat Python list covering all
    layers. Looping over it directly (instead of indexing with ``tf.range``)
    avoids Tensor-valued list indices, so the function can also be traced by
    ``tf.function``.

    Args:
        model: Keras model whose trainable variables are modified in place.
        minval: inclusive lower bound of the uniform noise.
        maxval: exclusive upper bound of the uniform noise; must exceed
            ``minval``. With the defaults the noise is strictly positive;
            pass symmetric bounds (e.g. ``-1e-4, 1e-4``) for zero-mean noise.
    """
    for weight in model.trainable_variables:
        # Draw noise in the variable's own dtype so assign_add never fails on
        # a float32/float64 mismatch.
        noise = tf.random.uniform(tf.shape(weight), minval, maxval,
                                  dtype=weight.dtype)
        weight.assign_add(noise)

optimizer = tf.optimizers.Adam()

verbose = "Epoch {:2d} Loss: {:.3f} TLoss: {:.3f} Acc: {:.3%} TAcc: {:.3%}"

for epoch in range(1, 10 + 1):
    # Fresh metric accumulators for every epoch.
    train_loss = tf.metrics.Mean()
    train_acc = tf.metrics.CategoricalAccuracy()
    test_loss = tf.metrics.Mean()
    test_acc = tf.metrics.CategoricalAccuracy()

    # Perturb all trainable weights once at the start of each epoch.
    weight_perturbation(model)

    for x, y in train:
        # training=True so Dropout is active during the optimisation step.
        with tf.GradientTape() as tape:
            loss_value, out = compute_loss(model, x, y, training=True)
        grads = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss.update_state(loss_value)
        train_acc.update_state(y, out)

    for x, y in test:
        # Evaluation needs no gradients, so no tape is recorded here.
        loss_value, out = compute_loss(model, x, y, training=False)
        test_loss.update_state(loss_value)
        test_acc.update_state(y, out)

    print(verbose.format(epoch,
                         train_loss.result(),
                         test_loss.result(),
                         train_acc.result(),
                         test_acc.result()))

关于python - 有效地向模型中的所有可训练权重添加噪声,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64542231/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com