gpt4 book ai didi

tensorflow-federated - 如何用TFF做预测?

转载 作者:行者123 更新时间:2023-12-05 02:46:18 25 4
gpt4 key购买 nike

我的问题是:如何使用 Tensorflow Federated 预测此类图像的标签?

完成模型评估后,我想预测给定图像的标签。就像在 Keras 中一样,我们这样做:

# new instance where we do not know the answer
Xnew = array([[0.89337759, 0.65864154]])
# make a prediction
ynew = model.predict_classes(Xnew)
# show the inputs and predicted outputs
print("X=%s, Predicted=%s" % (Xnew[0], ynew[0]))

输出:

X=[0.89337759 0.65864154], Predicted=[0]

这里是 state 和 model_fn 是如何创建的:


def model_fn():
keras_model = create_compiled_keras_model()
return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn, server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),client_weight_fn=None)
state = iterative_process.initialize()

我发现这个错误:

list(self._name_to_index.keys())[:10]))
AttributeError: The tuple of length 2 does not have named field "assign_weights_to". Fields (up to first 10): ['trainable', 'non_trainable']

谢谢

最佳答案

(需要 TFF 0.16.0 或更新版本)

由于代码是从 tf.keras.Model 构建 tff.learning.Model,您可以使用 assign_weights_to tff.learning.ModelWeights 上的方法对象(state.model 的类型)。此方法用于 Federated Learning for Text Generation教程。

这可能看起来像(靠近底部,早期部分是 FL 训练循环示例):


def create_keras_model() -> tf.keras.Model:
...

def model_fn():
...
return tff.learning.from_keras_model(create_keras_model())

training_process = tff.learning. build_federated_averaging_process(model_fn, ...)

state = training_process.initialize()
for _ in range(NUM_ROUNDS):
state, metrics = training_process.next(state, ...)

model_for_inference = create_keras_model()
state.model.assign_weights_to(model_for_inference)

一旦来自state 的权重被分配回 Keras 模型,代码就可以使用标准的 Keras API,例如 tf.keras.Model.predict_on_batch

predictions = model_for_inference.predict_on_batch(batch)

关于tensorflow-federated - 如何用TFF做预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65578020/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com