在C++中加载TorchScript模型的方法分享!

本教程已更新为可与PyTorch 1.2一起使用

顾名思义,PyTorch的主要接口是Python编程语言。尽管Python是合适于许多需要动态性和易于迭代的场景,并且是首选的语言,但同样的,在许多情况下,Python的这些属性恰恰是不利的。后者通常适用的一种环境是要求生产-低延迟和严格部署。对于生产场景,即使只将C ++绑定到Java,Rust或Go之类的另一种语言中,它也是经常选择的语言。以下各段将概述PyTorch提供的从现有Python模型到可以完全从C ++加载和执行的序列化表示形式的路径,而无需依赖Python。

步骤1:将PyTorch模型转换为Torch脚本

PyTorch模型从Python到C ++的旅程由Torch Script启动,Torch Script是PyTorch模型的一种表示形式,可以由Torch Script编译器理解,编译和序列化。如果您是从使用vanilla“eager” API编写的现有PyTorch模型开始的,则必须首先将模型转换为Torch脚本。在最常见的情况下(如下所述),这只需要花费很少的功夫。如果您已经有了Torch脚本模块,则可以跳到本教程的下一部分。

有两种将PyTorch模型转换为Torch脚本的方法。第一种称为跟踪,一种机制,其中通过使用示例输入对模型的结构进行一次评估,并记录这些输入在模型中的流量,从而捕获模型的结构。这适用于有限使用控制流的模型。第二种方法是在模型中添加显式批注,以告知Torch Script编译器可以根据Torch Script语言施加的约束直接解析和编译模型代码。

提示:您可以在官方 Torch脚本参考 中找到有关这两种方法的完整文档,以及使用方法的进一步指导。

方法1:通过跟踪转换为Torch脚本

要将PyTorch模型通过跟踪转换为Torch脚本,必须将模型的实例以及示例输入传递给 torch.jit.trace 函数。这将产生一个 torch.jit.ScriptModule 对象,该对象的模型评估痕迹将嵌入模块的 forward 方法中:

  import torch  import torchvision  # 你模型的一个实例.  model = torchvision.models.resnet18()  # 您通常会提供给模型的forward()方法的示例输入。  example = torch.rand(1, 3, 224, 224)  # 使用`torch.jit.trace `来通过跟踪生成`torch.jit.ScriptModule`  traced_script_module = torch.jit.trace(model, example)

现在可以对跟踪的 ScriptModule 进行评估,使其与常规PyTorch模块相同:

  In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))  In[2]: output[0, :5]  Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

方法2:通过注释转换为Torch脚本

在某些情况下,例如,如果模型采用特定形式的控制流,则可能需要直接在Torch脚本中编写模型并相应地注释模型。例如,假设您具有以下vanilla Pytorch模型:

  import torch  class MyModule(torch.nn.Module):   def __init__(self, N, M):    super(MyModule, self).__init__()    self.weight = torch.nn.Parameter(torch.rand(N, M))     def forward(self, input):    if input.sum() > 0:     output = self.weight.mv(input)    else:     output = self.weight + input    return output

因为此模块的前向方法使用取决于输入的控制流,所以它不适合跟踪。相反,我们可以将其转换为 ScriptModule 。为了将模块转换为 ScriptModule ,需要使用 torch.jit.script 编译模块,如下所示:

  class MyModule(torch.nn.Module):   def __init__(self, N, M):    super(MyModule, self).__init__()    self.weight = torch.nn.Parameter(torch.rand(N, M))     def forward(self, input):    if input.sum() > 0:     output = self.weight.mv(input)    else:     output = self.weight + input    return output    my_module = MyModule(10,20)  sm = torch.jit.script(my_module)  

如果您需要在 nn.Module 中排除某些方法,因为它们使用了 TorchScript 尚不支持的Python功能,则可以使用 @torch.jit.ignore 对其进行注释

my_module 是 ScriptModule 的实例,可以序列化。

步骤2:将脚本模块序列化为文件

一旦有了ScriptModule(通过跟踪或注释PyTorch模型),您就可以将其序列化为文件了。稍后,您将可以使用C ++从此文件加载模块并执行它,而无需依赖Python。假设我们要序列化先前在跟踪示例中显示的 ResNet18 模型。要执行此序列化,只需在模块上调用 save 并传递一个文件名即可:

traced_script_module.save("traced_resnet_model.pt")

这将在您的工作目录中生成 traced_resnet_model.pt 文件。如果您还想序列化 my_module ,请调用 my_module.save(“my_module_model.pt”) 我们现在已经正式离开Python领域,并准备跨入C ++领域。

步骤3:在C ++中加载脚本模块

本文来自网络收集,不代表计算机技术网立场,如涉及侵权请联系管理员删除。

ctvol管理联系方式QQ:251552304

本文章地址:https://www.ctvol.com/c-cdevelopment/483716.html

(0)
上一篇 2020年11月10日
下一篇 2020年11月10日

精彩推荐