gpt4 book ai didi

python - 如何将 perreplica 转换为张量?

转载 作者:行者123 更新时间:2023-12-03 20:14:53 25 4
gpt4 key购买 nike

在 tensorflow2.0 中使用多 GPU 进行训练时,perreplica 将通过以下代码减少:

strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

但是,如果我只想收集(没有“总和减少”或“平均减少”)所有 gpu 的预测到张量中:

per_replica_losses, per_replica_predicitions = strategy.experimental_run_v2(train_step, args=(dataset_inputs,))
# how to convert per_replica_predicitions to a tensor ?

最佳答案

总之,您可以转换PerReplica结果是这样的张量元组:

tensors_tuple = per_replica_predicitions.values

返回 tensors_tuple将是 predictions 的元组从每个副本/设备:
(predicton_tensor_from_dev0, prediction_tensor_from_dev1,...)

此元组中的元素数量由分布式策略可用的设备决定。特别地,如果策略在单个副本/设备上运行,则 strategy.experimental_run_v2 的返回值将与直接调用 train_step 函数相同(张量或张量列表由您的 train_step 决定)。所以你可能想写这样的代码:
per_replica_losses, per_replica_predicitions = strategy.experimental_run_v2(train_step, args=(dataset_inputs,))

if strategy.num_replicas_in_sync > 1:
predicition_tensors = per_replica_predicitions.values
else:
predicition_tensors = per_replica_predicitions
PerReplica是一个封装了分布式运行结果的类对象。你可以找到它的定义 here ,还有更多的属性/方法供我们操作 PerReplica目的。

关于python - 如何将 perreplica 转换为张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57549448/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com