gpt4 book ai didi

python - PyTorch 模块如何做 back prop

转载 作者:太空狗 更新时间:2023-10-30 00:58:42 24 4
gpt4 key购买 nike

同时按照 extending PyTorch - adding a module 上的说明进行操作,我在扩展 Module 时注意到,我们真的不必实现向后功能。我们唯一需要做的就是在 forward 函数中应用 Function 实例,PyTorch 可以在执行 back prop 时自动调用 Function 实例中的 backward。这对我来说似乎很神奇,因为我们甚至没有注册我们使用的 Function 实例。我查看了源代码,但没有找到任何相关内容。任何人都可以指出所有这些实际发生的地方吗?

最佳答案

不必实现 backward() 是 PyTorch 或任何其他 DL 框架如此有值(value)的原因。事实上,实现 backward() 只应在需要扰乱网络梯度的非常特殊的情况下进行(或者当您创建无法使用 PyTorch 的内置函数表达的自定义函数时)函数)。

PyTorch 使用计算图计算向后梯度,该计算图跟踪在前向传递期间完成的操作。对 Variable 执行的任何操作都隐式地在此处注册。然后就是从调用它的变量向后遍历图,并应用导数链规则来计算梯度。

PyTorch 的 About页面有一个很好的图表可视化以及它通常是如何工作的。如果您需要更多详细信息,我还建议您在 Google 上查找计算图和自动分级机制。

编辑:所有这一切发生的源代码将在 PyTorch 代码库的 C 部分中,实际图形在此处实现。经过一番挖掘,我发现了 this :

/// Evaluates the function on the given inputs and returns the result of the
/// function call.
variable_list operator()(const variable_list& inputs) {
profiler::RecordFunction rec(this);
if (jit::tracer::isTracingVar(inputs)) {
return traced_apply(inputs);
}
return apply(inputs);
}

因此,在每个函数中,PyTorch 首先检查其输入是否需要跟踪,然后执行 trace_apply() 作为实现 here .您可以看到正在创建并附加到图中的节点:

// Insert a CppOp in the trace.
auto& graph = state->graph;
std::vector<VariableFlags> var_flags;
for(auto & input: inputs) {
var_flags.push_back(VariableFlags::of(input));
}
auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
// ...
for (auto& input: inputs) {
this_node->addInput(tracer::getValueTrace(state, input));
}
graph->appendNode(this_node);

我最好的猜测是,每个 Function 对象在执行时都会注册自己及其输入(如果需要)。每个非功能性调用(例如 variable.dot())都只是服从相应的函数,所以这仍然适用。

注意:我没有参与 PyTorch 的开发,也绝不是其架构方面的专家。欢迎任何更正或补充。

关于python - PyTorch 模块如何做 back prop,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49594858/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com