pytorch - How do I use the add_graph method of torch.utils.tensorboard's SummaryWriter with dictionary output?

Reposted. Author: 行者123. Updated: 2023-12-05 07:15:12

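I have a PyTorch model that inherits from nn.Module and has a forward method that returns a dictionary containing multiple tensors (stored as the values). When I try to log this model to TensorBoard with the SummaryWriter.add_graph() method from torch.utils.tensorboard, I get the following error: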

RuntimeError("Tracer cannot infer type of <output dictionary>

How can I use add_graph() when the model's forward() method returns a dictionary?

Edit 1: Full error stack trace:

WARNING:root:This caffe2 python run does not have GPU support. Will run in CPU only mode.
Only tensors or tuples of tensors can be output from traced functions (getOutput at /pytorch/torch/csrc/jit/tracer.cpp:208)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7fc661bd4273 in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10.so)
frame #1: torch::jit::tracer::TracingState::getOutput(c10::IValue const&) + 0x3e8 (0x7fc664cf77b8 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch.so)
frame #2: torch::jit::tracer::exit(std::vector<c10::IValue, std::allocator<c10::IValue> > const&) + 0x3c (0x7fc664cf7aec in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x47b0f2 (0x7fc66e4820f2 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
frame #4: <unknown function> + 0x48f061 (0x7fc66e496061 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x19da64 (0x7fc66e1a4a64 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
frame #6: _PyCFunction_FastCallDict + 0x35c (0x5674fc in /usr/bin/python3.6)
frame #7: /usr/bin/python3.6() [0x50abb3]
frame #8: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #9: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #10: /usr/bin/python3.6() [0x508245]
frame #11: /usr/bin/python3.6() [0x50a080]
frame #12: /usr/bin/python3.6() [0x50aa7d]
frame #13: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #14: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #15: /usr/bin/python3.6() [0x508245]
frame #16: /usr/bin/python3.6() [0x50a080]
frame #17: /usr/bin/python3.6() [0x50aa7d]
frame #18: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #19: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #20: /usr/bin/python3.6() [0x508245]
frame #21: /usr/bin/python3.6() [0x50a080]
frame #22: /usr/bin/python3.6() [0x50aa7d]
frame #23: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #24: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #25: /usr/bin/python3.6() [0x508245]
frame #26: /usr/bin/python3.6() [0x50a080]
frame #27: /usr/bin/python3.6() [0x50aa7d]
frame #28: _PyEval_EvalFrameDefault + 0x1220 (0x50d390 in /usr/bin/python3.6)
frame #29: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #30: /usr/bin/python3.6() [0x508245]
frame #31: _PyFunction_FastCallDict + 0x2e2 (0x509642 in /usr/bin/python3.6)
frame #32: /usr/bin/python3.6() [0x595311]
frame #33: /usr/bin/python3.6() [0x54a6ff]
frame #34: /usr/bin/python3.6() [0x551b81]
frame #35: _PyObject_FastCallKeywords + 0x19c (0x5aa6ec in /usr/bin/python3.6)
frame #36: /usr/bin/python3.6() [0x50abb3]
frame #37: _PyEval_EvalFrameDefault + 0x1220 (0x50d390 in /usr/bin/python3.6)
frame #38: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #39: /usr/bin/python3.6() [0x509d48]
frame #40: /usr/bin/python3.6() [0x50aa7d]
frame #41: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #42: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #43: /usr/bin/python3.6() [0x508245]
frame #44: /usr/bin/python3.6() [0x516915]
frame #45: /usr/bin/python3.6() [0x50a8af]
frame #46: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #47: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #48: /usr/bin/python3.6() [0x508245]
frame #49: /usr/bin/python3.6() [0x50a080]
frame #50: /usr/bin/python3.6() [0x50aa7d]
frame #51: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #52: <unknown function> + 0x18d10 (0x7fc6965ded10 in /home/rylan/.PyCharmCE2019.3/system/cythonExtensions/_pydevd_frame_eval_ext/pydevd_frame_evaluator.cpython-36m-x86_64-linux-gnu.so)
frame #53: /usr/bin/python3.6() [0x509d48]
frame #54: /usr/bin/python3.6() [0x50aa7d]
frame #55: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #56: /usr/bin/python3.6() [0x508245]
frame #57: /usr/bin/python3.6() [0x50a080]
frame #58: /usr/bin/python3.6() [0x50aa7d]
frame #59: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #60: /usr/bin/python3.6() [0x509d48]
frame #61: /usr/bin/python3.6() [0x50aa7d]
frame #62: _PyEval_EvalFrameDefault + 0x449 (0x50c5b9 in /usr/bin/python3.6)
frame #63: /usr/bin/python3.6() [0x508245]

Error occurs, No graph saved
Traceback (most recent call last):
  File "/usr/lib/python3.6/contextlib.py", line 99, in __exit__
    self.gen.throw(type, value, traceback)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 42, in set_training
    yield
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/_pytorch_graph.py", line 234, in graph
    raise e
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/_pytorch_graph.py", line 229, in graph
    trace = torch.jit.trace(model, args)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 772, in trace
    check_tolerance, _force_outplace, _module_class)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 904, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
RuntimeError: Only tensors or tuples of tensors can be output from traced functions (getOutput at /pytorch/torch/csrc/jit/tracer.cpp:208)

Best answer

You can use this model wrapper:

import torch
from collections import namedtuple
from typing import Any


class ModelWrapper(torch.nn.Module):
    """Wrapper class for models with dict/list return values."""

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Init call.
        """
        super().__init__()
        self.model = model

    def forward(self, input_x: torch.Tensor) -> Any:
        """
        Wrap forward call.
        """
        data = self.model(input_x)

        if isinstance(data, dict):
            # Convert the dict to a namedtuple; sorting the keys gives a deterministic field order.
            data_named_tuple = namedtuple("ModelEndpoints", sorted(data.keys()))  # type: ignore
            data = data_named_tuple(**data)  # type: ignore

        elif isinstance(data, list):
            # Traced functions may only output tensors or tuples of tensors, so convert lists too.
            data = tuple(data)

        return data
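To see what the wrapper's conversion step produces without needing PyTorch, here is a minimal pure-Python sketch. `to_traceable` is a hypothetical helper that mirrors the dict/list branches of `ModelWrapper.forward`, and plain strings stand in for tensors (an assumption made only so the snippet runs on its own).

```python
from collections import namedtuple
from typing import Any


def to_traceable(data: Any) -> Any:
    """Hypothetical helper mirroring the conversion in ModelWrapper.forward."""
    if isinstance(data, dict):
        # Sorted keys become the namedtuple's fields, in a deterministic order.
        endpoints = namedtuple("ModelEndpoints", sorted(data.keys()))
        return endpoints(**data)
    if isinstance(data, list):
        # Lists become plain tuples.
        return tuple(data)
    # Anything else (e.g. a single tensor) is returned unchanged.
    return data


# Strings stand in for tensors so this runs without PyTorch.
out = to_traceable({"logits": "t_logits", "embedding": "t_embedding"})
print(isinstance(out, tuple))  # True
print(out._fields)             # ('embedding', 'logits')
print(out.logits)              # t_logits
print(tuple(out))              # ('t_embedding', 't_logits')
print(to_traceable(["a", "b"]))  # ('a', 'b')
```

Because a namedtuple is a `tuple` subclass, the wrapped output presumably satisfies the tracer's "tuples of tensors" requirement while still allowing access by name; you would then pass the wrapped model to TensorBoard, e.g. `writer.add_graph(ModelWrapper(model), example_input)`. Two caveats follow from this design: the outputs appear in sorted-key order rather than the dict's insertion order, and every key must be a valid Python identifier that does not start with an underscore, otherwise `namedtuple` raises `ValueError` (a key such as `"out-1"` would fail).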

For the original question "pytorch - How do I use the add_graph method of torch.utils.tensorboard's SummaryWriter with dictionary output?", see Stack Overflow: https://stackoverflow.com/questions/59656306/
