gpt4 book ai didi

input: dynamic input is missing dimensions in profile

转载 作者:知者 更新时间:2024-03-12 08:54:55 30 4
gpt4 key购买 nike

input: dynamic input is missing dimensions in profile

onnx2trt代码报错:

import numpy as np
import tensorrt as trt
import os
import pycuda.driver as cuda
import argparse

def GiB(val):
    return val * 1 << 30

def ONNX_build_engine(onnx_file_path, write_engine=True):
    # :return: engine

    G_LOGGER = trt.Logger(trt.Logger.WARNING)
    # 1、动态输入第一点必须要写的
    explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    batch_size = 8  # trt推理时最大支持的batchsize
    with trt.Builder(G_LOGGER) as builder, builder.create_network(explicit_batch) as network, trt.OnnxParser(network,
                                                                                                             G_LOGGER) as parser:
        builder.max_batch_size = batch_size
        config = builder.create_builder_config()
        config.max_workspace_size = GiB(2)
        config.set_flag(trt.BuilderFlag.FP16)
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parser.parse(model.read())
        print('Completed parsing of ONNX file')
        print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
        # 重点
        profile = builder.create_optimization_profile()  # 动态输入时候需要 分别为最小输入、常规输入、最大输入
        # 有几个输入就要写几个profile.set_shape 名字和转onnx的时候要对应
        # tensorrt6以后的版本是支持动态输入的,需要给每个动态输入绑定一个profile,用于指定最小值,常规值和最大值,如果超出这个范围会报异常。
        profile.set_shape("input", (1, 3, 128, 128), (4, 3, 128, 128), (16, 3, 128, 128))
        config.add_optimization_profile(profile)

        engine = builder.build_engine(network, config)
        print("Completed creating Engine")
        # 保存engine文件
        if write_engine:
            engine_file_path = 'efficientnet_b1.trt'
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
        return engine

onnx_file_path = r'skipnet_0712.onnx'
onnx_file_path = r'model2.onnx'
onnx_file_path = r'skip_simp2.onnx'
# onnx_file_path = r'mobileone_0713.onnx'
write_engine = True
engine = ONNX_build_engine(onnx_file_path, write_engine)

原错误代码:

profile.set_shape("inputs", (1, 3, 240, 240), (8, 3, 240, 240), (16, 3, 480, 480))

改之后代码:

profile.set_shape("inputs", (1, 3, 128, 128), (8, 3, 128, 128), (16, 3, 128, 128))

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com