gpt4 book ai didi

python - 修剪模型不会提高推理速度或减小模型大小

转载 作者:行者123 更新时间:2023-12-03 16:11:11 25 4
gpt4 key购买 nike

我正在尝试使用 torch.nn.utils.prune 在 PyTorch 中修剪我的模型,它提供了 2 个张量,

  • 一个是原始重量和
  • 另一个是包含 0 和 1 的掩码,可帮助我们关闭网络中的某些连接。

  • 我已经尝试了这两种解决方案,但都没有提高推理速度:
  • 使用剪枝后的网络来推断哪个会先关闭一些与掩码的连接,然后再运行推断。
  • 使用掩码将原始权重归零,然后从 state_dict 中删除掩码以进行推断。

  • 有没有办法通过模型张量和掩码来提高速度?与 0 的非零浮点数相乘不会比将 2 个浮点数相乘更快吗?
    这是我的修剪功能和修剪速度计算过程:
    def prune_net(net):
    """Prune 20% net's weights that have abs(value) approx. 0
    Function that will be use when an iteration is reach
    Args:

    Return:
    newnet (nn.Module): a newnet contain mask that help prune network's weight
    """
    if not isinstance(net,nn.Module):
    print('Invalid input. Must be nn.Module')
    return
    newnet = copy.copy(net)
    modules_list = []

    for name, module in newnet.named_modules():
    if isinstance(module, torch.nn.Conv2d):
    modules_list += [(module,'weight'),(module,'bias')]
    if isinstance(module, torch.nn.Linear):
    modules_list += [(module,'weight'),(module,'bias')]

    prune.global_unstructured(
    modules_list,
    pruning_method=prune.L1Unstructured,
    amount=0.2,)
    return newnet

    测试推理速度第一种情况:
    import torch
    from torch import nn
    import torch.nn.utils.prune as prune
    import torch.nn.functional as F
    import time
    from torch.autograd import Variable


    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    old_net = init_your_net()

    new_net = prune_net(old_net)
    new_net = prune_net(new_net)

    old_net.eval()
    new_net.eval()

    old_net = old_net.cuda()
    new_net = new_net.cuda()
    dataset = load_your_dataset()

    for i in range(100):
    x = dataset[i]
    x = x.cuda()
    y = x.cuda()

    #new infer
    start_time = time.perf_counter()
    detections = new_net(x).data
    time_new += time.perf_counter() - start_time

    #old infer
    start_time = time.perf_counter()
    detections = old_net(y).data
    time_old += time.perf_counter() - start_time
    print('old ',time_old)
    print('new ', time_new)

    测试推理速度第二种情况:
    import torch
    from torch import nn
    import torch.nn.utils.prune as prune
    import torch.nn.functional as F
    import time
    from torch.autograd import Variable


    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    old_net = init_your_net()

    new_net = prune_net(old_net)
    new_net = prune_net(new_net)
    # Apply mask to model tensor and remove mask from state_dict
    for name, module in new_net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
    prune.remove(module,'weight')
    prune.remove(module,'bias')
    if isinstance(module, torch.nn.Linear):
    prune.remove(module,'weight')
    prune.remove(module,'bias')

    old_net.eval()
    new_net.eval()

    old_net = old_net.cuda()
    new_net = new_net.cuda()
    dataset = load_your_dataset()

    for i in range(100):
    x = dataset[i]
    x = x.cuda()
    y = x.cuda()

    #new infer
    start_time = time.perf_counter()
    detections = new_net(x).data
    time_new += time.perf_counter() - start_time

    #old infer
    start_time = time.perf_counter()
    detections = old_net(y).data
    time_old += time.perf_counter() - start_time
    print('old ',time_old)
    print('new ', time_new)

    更新
    我发现 torch 有一个稀疏模块,如果我们修剪足够的参数,它可以减少内存使用,但它还不支持 nn.Module,只支持 Tensor 对象。以下是一些有用的链接:
    https://github.com/pytorch/pytorch/issues/36214#issuecomment-619586452
    https://pytorch.org/docs/stable/sparse.html

    最佳答案

    了解非结构化修剪和结构化修剪之间的区别很重要。

  • 结构化剪枝:通过删除张量的整行/列来减小权重张量的尺寸。这转化为移除所有传入和传出连接(在密集层中)或整个卷积过滤器(在卷积层中)的神经元。
  • 非结构化剪枝:可以“删除”(归零)单个权重,而不受最终张量形状的限制。这转化为删除神经元之间的单个连接(在密集层中)或删除卷积滤波器的单个权重(在卷积层中)。请注意,生成的权重张量可能是稀疏的,但会保持其原始形状。

  • 目前, torch.nn.utils.prune仅支持非结构化剪枝,这几乎无助于降低推理成本,因为 GPU 未针对稀疏矩阵乘法进行优化。虽然您可能希望减小权重张量的维度以减少浮点运算的数量,但非结构化剪枝会产生具有许多零的权重张量,但不会自动减小此类张量的大小。
    只有在移除大量权重时,非结构化剪枝才能帮助提高性能。在这种情况下,您可以依赖 PyTorch sparse operations或尝试查找包含全零的行/列,因此可以将其删除。
    相反,如果您想研究结构化修剪,可以查看 TorchPruner ,一个我自己开发的用于研究目的的库,它提供实用程序来查找最不重要的神经元并相应地对权重张量进行切片。

    关于python - 修剪模型不会提高推理速度或减小模型大小,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62326683/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com