gpt4 book ai didi

tensorflow - Torch JIT Trace = TracerWarning : Converting a tensor to a Python boolean might cause the trace to be incorrect

转载 作者:行者123 更新时间:2023-12-04 02:27:10 31 4
gpt4 key购买 nike

我正在关注本教程:https://huggingface.co/transformers/torchscript.html
创建我的自定义 BERT 模型的痕迹,但是在运行完全相同的 dummy_input 时我收到一个错误:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. 
We cant record the data flow of Python values, so this value will be treated as a constant in the future.
在我的模型和标记器中加载后,创建跟踪的代码如下:
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

traced_model = torch.jit.trace(model, dummy_input)
dummy_input是张量列表,所以我不确定 Boolean 在哪里类型在这里发挥作用。有谁明白为什么会发生这个错误以及 bool 转换是否正在发生?
非常感谢

最佳答案

这个错误意味着什么
警告 发生,当人们尝试 torch.jit.trace 时具有 的型号数据相关控制流 .
这个简单的例子应该更清楚:

import torch


class Foo(torch.nn.Module):
def forward(self, tensor):
# It is data dependent
# Trace will only work with one path
if tensor.max() > 0.5:
return tensor ** 2
return tensor


model = Foo()
traced = torch.jit.script(model) # No warnings
traced = torch.jit.trace(model, torch.randn(10)) # Warning
本质上,BERT 模型有一些依赖于数据的控制流(如 iffor 循环),因此您会收到警告。
警告本身
你可以看到BERT forward代码 here .
你没问题,如果:
  • 参数不会改变(比如 None 传递给 forward 的值)并且在 script 之后它会保持这种状态(例如在推理调用期间)
  • 如果存在基于内部收集的数据的控制流 __init__ (如配置),因为这不会改变

  • 例如:
    elif input_ids is not None:
    input_shape = input_ids.size()
    batch_size, seq_length = input_shape
    将仅作为 torch.jit.trace 的一个分支运行,因为它只是跟踪张量上的操作并且不知道这样的控制流。
    HuggingFace 团队可能已经意识到这一点,并且此警告不是问题(尽管您可能会仔细检查您的用例或尝试使用 torch.jit.script )
    一起去 torch.jit.script这个很难,因为整个模型必须是 torchscript兼容( torchscript 有一个 Python 子集可用,而且很可能无法与 BERT 一起开箱即用)。
    仅在必要时才这样做(可能不是)。

    关于tensorflow - Torch JIT Trace = TracerWarning : Converting a tensor to a Python boolean might cause the trace to be incorrect,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66746307/

    31 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com