
C++ function call does not match the function's definition

Reposted. Author: 太空宇宙. Updated: 2023-11-03 17:24:28

I'm looking at this block of code: https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/profiler.cpp#L141

pushCallback(
    // Argument 1: `start` callback
    [config](const RecordFunction& fn) {
      auto* msg = (fn.seqNr() >= 0) ? ", seq = " : "";
      if (config.report_input_shapes) {
        std::vector<std::vector<int64_t>> inputSizes;
        inputSizes.reserve(fn.inputs().size());
        for (const c10::IValue& input : fn.inputs()) {
          if (!input.isTensor()) {
            inputSizes.emplace_back();
            continue;
          }
          const at::Tensor& tensor = input.toTensor();
          if (tensor.defined()) {
            inputSizes.push_back(input.toTensor().sizes().vec());
          } else {
            inputSizes.emplace_back();
          }
        }
        pushRangeImpl(fn.name(), msg, fn.seqNr(), std::move(inputSizes));
      } else {
        pushRangeImpl(fn.name(), msg, fn.seqNr(), {});
      }
    },
    // Argument 2: `end` callback
    [](const RecordFunction& fn) {
      if (fn.getThreadId() != 0) {
        // If we've overridden the thread_id on the RecordFunction, then find
        // the eventList that was created for the original thread_id. Then,
        // record the end event on this list so that the block is added to
        // the correct list, instead of to a new list. This should only run
        // when calling RecordFunction::end() in a different thread.
        if (state == ProfilerState::Disabled) {
          return;
        } else {
          std::lock_guard<std::mutex> guard(all_event_lists_map_mutex);
          const auto& eventListIter =
              all_event_lists_map.find(fn.getThreadId());
          TORCH_INTERNAL_ASSERT(
              eventListIter != all_event_lists_map.end(),
              "Did not find thread_id matching ",
              fn.getThreadId());

          auto& eventList = eventListIter->second;
          eventList->record(
              EventKind::PopRange,
              StringView(""),
              fn.getThreadId(),
              state == ProfilerState::CUDA);
        }
      } else {
        popRange();
      }
    },
    // Argument 3: `needs_inputs`; no 4th argument is passed
    config.report_input_shapes);

This call passes only three arguments. But the definition of pushCallback appears to be here, https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/record_function.cpp#L35, and it takes four parameters:

void pushCallback(
    RecordFunctionCallback start,
    RecordFunctionCallback end,
    bool needs_inputs,
    bool sampled) {
  start_callbacks.push_back(std::move(start));
  end_callbacks.push_back(std::move(end));
  if (callback_needs_inputs > 0 || needs_inputs) {
    ++callback_needs_inputs;
  }
  is_callback_sampled.push_back(sampled);
  if (sampled) {
    ++num_sampled_callbacks;
  }
}

I don't understand why that function call works.

Best answer

If you look at the header, you'll see that the function is declared with 4 parameters, the last three of which have default values:

TORCH_API void pushCallback(
    RecordFunctionCallback start,
    RecordFunctionCallback end = [](const RecordFunction&){},
    bool needs_inputs = false,
    bool sampled = false);

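The default arguments appear only in the declaration (in the header), not in the definition. Any caller that includes the header sees those defaults, so the compiler fills in the omitted trailing arguments at the call site, and the definition always receives all four values. Repeating the defaults in the definition would be an error, because a default argument may not be redefined in the same scope.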

For a similar question about a C++ call not matching its definition, see Stack Overflow: https://stackoverflow.com/questions/59722586/
