gpt4 book ai didi

tensorflow - Keras 自定义损失函数打印张量值

转载 作者:行者123 更新时间:2023-11-30 09:15:06 29 4
gpt4 key购买 nike

我尝试实现像 yolo 这样的对象检测器。它使用复杂的自定义损失函数。所以我需要打印/调试它的张量。据我了解,python 代码仅构建计算图,因此标准打印在非急切模式下无法工作。 tensorflow 1.12.0喀拉斯2.2.4

我尝试了这些帖子中的所有方法 Keras custom loss function not printing value of tensor , Debugging keras tensor values没有任何效果。我尝试了 tf.Print、tf.print、callback、K.tensor_print - 相同的结果。在控制台中我只看到标准输出消息。我什至不确定是否调用了损失函数。这篇文章的答案Keras - printing intermediate tensors in loss function (tf.Print and K.print_tensor do not work...)说损失函数有时甚至不被调用!好的,但是如何使用 tf.contrib.eager.defun 装饰器呢?该示例是纯 tensorflow 的,不明白如何在keras中使用它。

import tensorflow as tf
from keras.datasets import fashion_mnist
import matplotlib.pyplot as plt
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers import Flatten, Dense, Dropout
from keras.models import Sequential
from keras import optimizers
import numpy as np
from random import randrange
from keras.callbacks import LambdaCallback
import keras.backend as K
import keras

print(tf.__version__)
print(keras.__version__)


num_filters = 64

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

#reshape
x_train = x_train.reshape(60000,28,28,1)[:1000,...]
x_test = x_test.reshape(10000,28,28,1)[:100,...]

# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, 10)[:1000,...]
y_test = tf.keras.utils.to_categorical(y_test, 10)[:100,]


labels = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


model = Sequential()
model.add(Conv2D(input_shape=(28,28,1), filters=num_filters,kernel_size=3,strides=(1, 1),padding="valid", activation='relu', use_bias=True))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))

model.add(Conv2D(filters=num_filters,kernel_size=3,strides=(1, 1),padding="valid", activation='relu', use_bias=True))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))

model.add(Flatten())
model.add(Dense(256, activation = 'relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation = 'softmax'))


#model.summary()


#loss 1
def customLoss(yTrue,yPred):
d = yPred-yTrue
d = K.print_tensor(d)
return K.mean(K.square(d), axis=-1)

#loss 2
def cat_loss(y_true, y_pred):
d = y_true - y_pred
d = tf.Print(d, [d], "Inside loss function")
return tf.reduce_mean(tf.square(d))


model.compile(loss=customLoss,
optimizer='adam')



import keras.callbacks as cbks

# 3 try to print with callback
class CustomMetrics(cbks.Callback):
def on_epoch_end(self, epoch, logs=None):
for k in logs:
if k.endswith('cat_loss'):
print(logs[k])


#checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose = 1, save_best_only=True)
model.fit(x_train,
y_train,
#verbose=1,
batch_size=16,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[CustomMetrics()])


# Evaluate the model on test set
score = model.evaluate(x_test, y_test, verbose=0)

# Print test accuracy
print('\n', 'Test accuracy:', score)


rand_img = randrange(100)

result = np.argmax(model.predict(x_test[rand_img].reshape(1,28,28,1)))
plt.imshow(x_test[rand_img].reshape(28,28), cmap='gray')
plt.title(labels[result])
==========>......] - ETA: 0s - loss: 0.0243
832/1000 [=======================>......] - ETA: 0s - loss: 0.0242
Warning (from warnings module):
File "C:\Python36\lib\site-packages\keras\callbacks.py", line 122
% delta_t_median)
UserWarning: Method on_batch_end() is slow compared to the batch update (0.101474). Check your callbacks.

976/1000 [============================>.] - ETA: 0s - loss: 0.0238
992/1000 [============================>.] - ETA: 0s - loss: 0.0236
1000/1000 [==============================] - 3s 3ms/step - loss: 0.0239 - val_loss: 0.0352

Test accuracy: 0.035189545452594756```

最佳答案

真相就在附近。 Idle 不会将 tf.Print 输出,因此不会将 K.print_tensor() 输出到它的 shell,因此当我使用 cmd.exe python train.py 时,我看到了张量输出。

关于tensorflow - Keras 自定义损失函数打印张量值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57812402/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com