PyTorch 源码阅读笔记(5):TorchScript
TorchScript
TorchScript 的使用
python api:
1
2
3
4
5
6
7
8
9
10
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
scripted_module = torch.jit.script(MyCell().eval())
C++ api:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
std::cout << "ok\n";
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
}
关于 TorchScript
细节参考:TorchScript
TorchScript 实现的功能,从使用角度看是用 Python 编写模型,然后在 C++ 内运行,大致有如下步骤:
- 解析 Python 代码为抽象语法树
- 转化语法树为模型中间表示 IR
- 根据 IR 生成模型
- 执行模型(根据运行时信息优化模型-JIT)
从流程上看,PyTorch 在 C++ 端(LibTorch)实现了一个编译器,编译运行了一个 Python 的子集语言,即为 TorchScript:
1 ~ 3为编译器的前端(语法分析、类型检查、中间代码生成),4为编译器后端(代码优化、执行代码生成与优化)
more
- 使用角度看,TorchScript 适用于生产部署 PyTorch 模型,不过实际工作中没有直接使用过,一般训练完成以后会选择导出 onnx,openvino等格式(导出过程其实使用了相关模块),单独部署为推理服务
- 原理涉及较多编译原理相关内容,学习后再补充
This post is licensed under CC BY 4.0 by the author.