PyTorch Source Code Reading Notes (0): Code Structure and Build
PyTorch directory structure

Following the official PyTorch description, the code is roughly organized as follows (a small sketch after the list illustrates how the main layers relate):
- c10 - Core library: the most basic functionality; code from aten/src/ATen/core is gradually being migrated here
- aten - the PyTorch C++ tensor library, without autograd support
  - aten/src
    - aten/src/ATen
      - aten/src/ATen/core - core functionality, gradually being migrated to c10
      - aten/src/ATen/native - native operator library; most CPU operators live directly in this directory, except for operators with special build requirements, which go under cpu
        - aten/src/ATen/native/cpu - CPU operator implementations that must be compiled with particular instruction sets such as AVX
        - aten/src/ATen/native/cuda - CUDA operators
        - aten/src/ATen/native/sparse - sparse tensor operators for CPU and CUDA
        - aten/src/ATen/native/mkl, aten/src/ATen/native/mkldnn, … - operators for the backends the folder names suggest
- torch - the actual PyTorch library; everything except torch/csrc is Python modules
  - torch/csrc - mixed Python/C++ code
    - torch/csrc/jit - TorchScript JIT frontend
    - torch/csrc/autograd - automatic differentiation implementation
    - torch/csrc/api - the PyTorch C++ frontend
    - torch/csrc/distributed - distributed training support
- tools - code generation; a lot of PyTorch code is generated automatically at build time
- test - unit tests for the Python frontend; the C++ frontend tests live in other directories
- caffe2 - the Caffe2 library merged into PyTorch; the official description of exactly what was merged is rather abstract, to be updated here once I run into it
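To see how these layers relate in practice, here is a minimal sketch (my own illustration against libtorch, not code taken from the PyTorch sources): the tensor type lives in ATen, the torch:: namespace in C++ largely re-exports it, and the c10 dispatch keys attached to a tensor decide which kernel under aten/src/ATen/native ends up running.

```cpp
// Minimal illustration of the c10 / ATen / torch layering (a sketch, not PyTorch code).
#include <torch/torch.h>              // C++ frontend (torch/csrc/api); re-exports ATen
#include <c10/core/DispatchKeySet.h>  // c10 core type describing dispatch keys
#include <iostream>

int main() {
  // at::Tensor and torch::Tensor are the same type.
  at::Tensor a = torch::ones({2, 2});

  // The DispatchKeySet (a c10 concept stored on the TensorImpl) determines
  // which kernel under aten/src/ATen/native handles an operator call.
  c10::DispatchKeySet keys = a.unsafeGetTensorImpl()->key_set();
  std::cout << "has CPU dispatch key: " << keys.has(c10::DispatchKey::CPU)
            << std::endl;

  // torch::add dispatches to a native ATen kernel (the CPU kernel in this case).
  std::cout << torch::add(a, a) << std::endl;
  return 0;
}
```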
Building PyTorch (C++)

PyTorch officially ships libtorch, a separately packaged C++ distribution.

First, clone the repository. This pulls in a large number of third-party submodules, so it takes a while, and any submodule that fails to clone can cause the build to fail later:
```bash
git clone -b master --recurse-submodules https://github.com/PyTorch/PyTorch.git
```
Building directly with the official CMake files works. Even if you only build the C++ libraries, Python still has to be installed, because a large amount of code is generated by Python scripts during the build.
I configure with the following additional options (Debug build, shared libraries, and an explicit Python interpreter), passed to cmake when configuring the build:

```
-DBUILD_SHARED_LIBS:BOOL=ON -DCMAKE_BUILD_TYPE:STRING=Debug -DPYTHON_EXECUTABLE:PATH=/opt/miniconda3/bin/python3
```
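The API example below uses an OpsLinear module from ../include/linear.hpp, which is not reproduced in this post. Purely as a hypothetical sketch (the names and layout here are assumptions, not the actual header), a definition consistent with how the module is used below (linear->parameters(), linear->forward(real_x), linear->W) would be a torch::nn::Module holding a raw weight and bias:

```cpp
// Hypothetical sketch of ../include/linear.hpp, consistent with its usage below.
#pragma once
#include <torch/torch.h>

struct OpsLinearImpl : torch::nn::Module {
  OpsLinearImpl(int in_feat, int out_feat) {
    // register_parameter exposes W and b through parameters(), so the SGD
    // optimizer in the test below can update them.
    W = register_parameter(
        "W", torch::randn({in_feat, out_feat}, torch::dtype(torch::kDouble)));
    b = register_parameter(
        "b", torch::zeros({out_feat}, torch::dtype(torch::kDouble)));
  }

  torch::Tensor forward(const torch::Tensor& x) {
    return torch::add(torch::matmul(x, W), b);
  }

  torch::Tensor W, b;
};

// TORCH_MODULE generates the OpsLinear holder type, which is why the test code
// accesses the module through operator-> (linear->forward(...), linear->W).
TORCH_MODULE(OpsLinear);
```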
C++ API example

A small test program that fits a linear model with SGD using the C++ API:
```cpp
#include <ATen/TensorIndexing.h>
#include <ATen/ops/add.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/mse_loss.h>
#include <c10/core/TensorOptions.h>
#include <torch/optim/adam.h>
#include <torch/optim/sgd.h>

#include <cstdio>
#include <iostream>

#include "../include/linear.hpp"

void linear_test() {
  int data_size = 50000, batch_size = 5000, in_feat = 3, out_feat = 4;

  // Ground-truth weight and bias used to synthesize the training data.
  torch::Tensor real_w =
      torch::randn({in_feat, out_feat}, torch::dtype(torch::kDouble));
  // Peek at the tensor's dispatch key set (a c10 concept), just for exploration.
  auto real_w_key_set = (real_w.unsafeGetTensorImpl())->key_set();
  torch::Tensor real_b = torch::randn({out_feat}, torch::dtype(torch::kDouble));
  torch::Tensor train_x =
      torch::randn({data_size, in_feat}, torch::dtype(torch::kDouble));
  auto train_y = torch::matmul(train_x, real_w);
  train_y = torch::add(train_y, real_b);

  OpsLinear linear{in_feat, out_feat};
  torch::optim::SGD sgd(linear->parameters(), torch::optim::SGDOptions(1e-6));

  int num_batch = data_size / batch_size;
  int epochs = 500;
  double min_loss{99999};
  torch::Tensor best_w = torch::rand_like(linear->W);

  for (int epoch = 0; epoch < epochs; epoch++) {
    for (int batch = 0; batch < num_batch; batch++) {
      torch::Tensor real_x = train_x.index(
          {torch::indexing::Slice(batch * batch_size, (batch + 1) * batch_size),
           "..."});
      torch::Tensor real_y = train_y.index(
          {torch::indexing::Slice(batch * batch_size, (batch + 1) * batch_size),
           "..."});
      torch::Tensor pred_y = linear->forward(real_x);
      torch::Tensor loss = torch::mse_loss(pred_y, real_y);
      // Clear old gradients before backward(); otherwise they accumulate
      // across batches.
      sgd.zero_grad();
      loss.backward();
      sgd.step();
      std::printf("\r[%2d/%2d][%3d/%3d] loss: %.4f", epoch, epochs, batch,
                  num_batch, loss.item<double>());
      if (loss.item<double>() < min_loss) {
        min_loss = loss.item<double>();
        best_w = torch::clone(linear->W);
      }
    }
  }
  std::cout << std::endl;
  std::cout << "real_w: " << real_w << "\n"
            << "real_b: " << real_b << std::endl;
  std::cout << "min loss: " << min_loss << "\n"
            << "w: " << best_w << std::endl;
}
```
The project's CMake configuration (CMakeLists.txt):

```cmake
cmake_minimum_required(VERSION 3.15)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_ARCHITECTURES "75")
set(CMAKE_CUDA_COMPILER "/usr/local/cuda/bin/nvcc")
set(CMAKE_PREFIX_PATH "/data/build/libtorch")
project(ops LANGUAGES CXX C CUDA)
set(Python3_ROOT_DIR "/opt/miniconda3")
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
find_package(Torch REQUIRED)
message("torch : " ${TORCH_LIBRARIES})
aux_source_directory(./src src)
aux_source_directory(./include include)
aux_source_directory(./test test)
add_executable(ops ${src} ${include})
target_link_libraries(ops ${TORCH_LIBRARIES} Python3::Python)
```

There may be some remaining library dependency problems when building; installing the corresponding packages resolves them.