
python - TorchScript requires source access in order to compile collections.deque

Reposted. Author: 行者123. Updated: 2023-12-05 06:49:11

I'm trying to convert the PyTorch FOMM (First Order Motion Model) model to TorchScript. As soon as I started annotating some classes with @torch.jit.script, I got an error:

OSError: Can't get source for <class 'collections.deque'>. TorchScript requires source access in order to carry out compilation, make sure original .py files are available.

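As far as I understand, classes that CPython implements in C, such as collections.deque, have no Python source, so the TorchScript compiler cannot read them. I haven't found a pure-Python implementation of deque. How can I get around this?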
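You can check this premise with the standard inspect module. TorchScript appears to look up class source in a similar way, although that is an assumption about its internals based on the wording of the OSError. The sketch below uses `has_source`, a hypothetical helper written only for this illustration. It compares a class written in Python (`queue.Queue`) with one implemented in C (`collections.deque`).

```python
import collections
import inspect
import queue


def has_source(obj) -> bool:
    """Hypothetical helper: True if inspect can retrieve Python source for obj."""
    try:
        inspect.getsource(obj)
    except (OSError, TypeError):  # raised when no Python source can be located
        return False
    return True


print(has_source(queue.Queue))        # queue.Queue is defined in Lib/queue.py
print(has_source(collections.deque))  # deque is implemented in the C module _collections
```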

Here is the class I am trying to annotate:

```python
import queue
import collections
import threading
import torch

# FutureResult, _MasterRegistry and SlavePipe are defined elsewhere in the
# same module and are not shown here.


@torch.jit.script
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message passed
      back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
```
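SyncMaster never uses collections.deque directly, yet the error names it. A plausible explanation assumes that TorchScript recursively tries to compile the Python classes a scripted method instantiates. `__init__` calls `queue.Queue()`, which is pure Python. `Queue.__init__` then calls `Queue._init`, and that method creates a `collections.deque`, which has no Python source. The standard-library sketch below inspects that chain.

```python
import collections
import inspect
import queue

# queue.Queue is pure Python, so the source of its _init hook is available...
print(inspect.getsource(queue.Queue._init))

# ...but the container it builds is the C-implemented collections.deque.
q = queue.Queue()
print(type(q.queue))
print(type(q.queue) is collections.deque)
```

`queue.Queue` also relies on threading locks, and SyncMaster stores an arbitrary Python callback. A pure-Python deque would therefore most likely just move the failure to the next unsupported type.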

Best answer

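Switch the TorchScript generation method from torch.jit.script to torch.jit.trace, and it works without annotating anything. torch.jit.trace runs the model once on example inputs and records the tensor operations that execute. It does not compile the Python source of every class involved. As a result, once the @torch.jit.script annotations are removed, helper classes such as SyncMaster are never compiled. Alternatively, torch.onnx.export sometimes works.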
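Below is a minimal sketch of the tracing route, assuming a PyTorch version where torch.jit is available. `TinyNet`, its layer sizes, and the 1x3x32x32 input are hypothetical stand-ins, not FOMM code. The second half illustrates the main trade-off of tracing: only the operations executed for the example input are recorded, so any Python branch on tensor values is baked into the traced graph.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a FOMM sub-network (not the real model)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.bn(self.conv(x)))


model = TinyNet().eval()             # eval(): BatchNorm uses its running statistics
example = torch.randn(1, 3, 32, 32)  # assumed input shape, for illustration only

# trace() runs forward() once on `example` and records the tensor ops it executes;
# no @torch.jit.script annotations are needed on the model or its helpers.
traced = torch.jit.trace(model, example)

# Serialize and reload the TorchScript module (an in-memory buffer instead of a file).
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

with torch.no_grad():
    print(torch.allclose(reloaded(example), model(example)))


# Caveat: tracing freezes data-dependent control flow at the branch taken for the example.
def sign_dependent(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:  # evaluated only once, while tracing
        return x * 2
    return -x


traced_fn = torch.jit.trace(sign_dependent, torch.ones(3))  # emits a TracerWarning
neg = -torch.ones(3)
print(sign_dependent(neg))  # eager: takes the `-x` branch
print(traced_fn(neg))       # traced: still computes x * 2
```

Because of this trade-off, compare the traced model with the eager model on a few representative inputs before relying on it.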

A similar question, python - TorchScript requires source access in order to compile collections.deque, can be found on Stack Overflow: https://stackoverflow.com/questions/66628965/
