gpt4 book ai didi

python - Keras/ tensorflow : Get predictions or output of all layers efficiently

转载 作者:行者123 更新时间:2023-12-01 09:10:39 25 4
gpt4 key购买 nike

我能够按照 Keras Docs: how-can-i-obtain-the-output-of-an-intermediate-layer 中的建议获取所有层的输出/预测

def get_output_of_all_layers(model, test_input):
output_of_all_layers = []

for count, layer in enumerate(model.layers):

# skip the input layer
if count == 0:
continue

intermediate_layer_model = Model(inputs=model.input, outputs=model.get_layer(layer.name).output)
intermediate_output = intermediate_layer_model.predict(test_input)[0]

output_of_all_layers.append(intermediate_output)

return np.array(output_of_all_layers)

但这速度慢得令人难以置信,需要一分多钟(在 6700HQGTX1070 中,时钟为 ~65s,这高得离谱,对于大约 50 层的模型,推理发生在不到一秒的时间内......!)。我想这是因为它每次都会构建一个模型,将模型加载到内存中,传递输入并获取输出。显然,如果不从其他层获取结果,就无法获得最后一层的输出,如何像上面那样保存它们,而不必创建冗余模型(或以更快、更有效的方式)?

更新:我还注意到这并没有利用我的 GPU,这意味着所有的转换层都由 CPU 执行?为什么它不使用我的 GPU 来实现这个目的?我认为如果使用我的 GPU,花费的时间会少得多。

如何更有效地做到这一点?

最佳答案

按照 Ben Usman 的建议,您可以首先将模型包装在基本的端到端 Model 中,并将其层作为输出提供给第二个 Model:

import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense

input_layer = Input((10,))

layer_1 = Dense(10)(input_layer)
layer_2 = Dense(20)(layer_1)
layer_3 = Dense(5)(layer_2)

output_layer = Dense(1)(layer_3)

basic_model = Model(inputs=input_layer, outputs=output_layer)

# some random input
import numpy as np
features = np.random.rand(100,10)

# With a second Model
intermediate_model = Model(inputs=basic_model.layers[0].input,
outputs=[l.output for l in basic_model.layers[1:]])
intermediate_model.predict(features) # outputs a list of 4 arrays

或者,您可以以类似的方式使用 Keras 函数:

# With a Keras function
get_all_layer_outputs = K.function([basic_model.layers[0].input],
[l.output for l in basic_model.layers[1:]])

layer_output = get_all_layer_outputs([features]) # return the same thing

关于python - Keras/ tensorflow : Get predictions or output of all layers efficiently,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51677631/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com