gpt4 book ai didi

sorting - 如何在 Keras 中使用 argsort

转载 作者:行者123 更新时间:2023-12-03 22:21:21 24 4
gpt4 key购买 nike

我学习 ML,我使用开源(blstm),我想添加一个 argsort 层,但 keras 没有 argsort 层。

我搜索了这个问题,人们说使用 Lambda 函数。
(我使用 Tensorflow 后端)
但我不知道如何定义函数以及如何使用它。

这是原始代码:

tagger = Dropout(self.dropout_ratio)(tagger)
prediction = TimeDistributed(Dense(self.output_vocab_size,
activation='softmax'))(tagger)
self.model = Model(input=raw_current, output=prediction)
self.model.compile(loss='categorical_crossentropy', optimizer=opt_func)

预测形状是 (?,48,224)。

训练后,当我测试这个模型时,我想从预测 [0][-1] 中获得一个 top_10 索引,所以我使用了一些排序代码。
这是我的测试代码。
prediction = self.model.predict(test_batch_data)
my_prediction = prediction[0][-1]
top10_my_prediction_idx = sorted(range(len(rule_prediction)), key=lambda k: rule_prediction[k] , reverse=True)[0:10]

我想在训练模型时在预测层之后添加一个 get_top_10 层。

像这个:
tagger = Dropout(self.dropout_ratio)(tagger)
prediction = TimeDistributed(Dense(self.output_vocab_size, activation='softmax'))(tagger)
**top10_prediction = Lambda(get_top10_prediction)**
self.model = Model(input=raw_current, output=prediction)
self.model.compile(loss='categorical_crossentropy', optimizer=opt_func)

如何定义 Lambda 函数以及如何使用它?

最佳答案

如果您使用的是 tensorflow 后端,则可以通过 tf.nn.top_k 对张量进行 arg 排序。功能与 sorted=True争论。一个 lambda 层看起来像这样:

def top_k(input, k):
# Can also use `.values` to return a sorted tensor
return tf.nn.top_k(input, k=k, sorted=True).indices

...
sorted = Lambda(top_k, arguments={'k': 10})(prediction)

这是一个可运行的测试:

import numpy as np
import tensorflow as tf

from keras import Sequential
from keras.layers import Lambda

def top_k(input, k):
return tf.nn.top_k(input, k=k, sorted=True).indices

model = Sequential()
model.add(Lambda(top_k, input_shape=(10,), arguments={'k': 10}))

data = np.array([
[0, 5, 2, 1, 3, 6, 1, 2, 7, 4],
[2, 4, 3, 1, 2, 0, 1, 5, 2, 4],
[8, 9, 1, 8, 3, 0, 1, 3, 2, 6],
])

print(model.predict(x=data))
# Prints:
# [[8 5 1 9 4 2 7 3 6 0]
# [7 1 9 2 0 4 8 3 6 5]
# [1 0 3 9 4 7 8 2 6 5]]

参数 k可以是小于您要排序的维度的任何值,例如,对于 k=4输出将是:

[[8 5 1 9]
[7 1 9 2]
[1 0 3 9]]

关于sorting - 如何在 Keras 中使用 argsort,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48374905/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com