
PyTorch Source Code Reading Notes (4): A Tensor Library with Automatic Differentiation

A Tensor Library with Automatic Differentiation

Tensor Library

The tensor interface is defined in aten/src/ATen/core/TensorBody.h. The Tensor class contains a large amount of auto-generated code through which operators are invoked.
Tensor inherits from TensorBase, and most tensor-related calls are forwarded to this parent class. TensorBase has one key member variable:

protected:
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

TensorImpl is the low-level representation of a tensor: it holds the actual data pointer plus the metadata that describes the tensor. It inherits from c10::intrusive_ptr_target, the intrusive-pointer facility in the c10 module.
PyTorch implements an intrusive pointer as a replacement for C++'s std::shared_ptr. A shared_ptr must allocate a separate control object for reference counting, whereas an intrusive pointer keeps the count inside the managed class itself, which gives it better performance.
Every class managed by an intrusive pointer must supply the reference-counting machinery, which here means inheriting from c10::intrusive_ptr_target. That class has the following two member variables: refcount_ holds the strong reference count and weakcount_ holds the weak reference count, which makes it possible to deal with reference cycles:

  mutable std::atomic<size_t> refcount_;
  mutable std::atomic<size_t> weakcount_;
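
As a minimal sketch of the pattern (MyNode is a hypothetical class, not part of PyTorch):

#include <c10/util/intrusive_ptr.h>

struct MyNode : c10::intrusive_ptr_target {
  int value;
  explicit MyNode(int v) : value(v) {}
};

void demo() {
  // make_intrusive allocates the object with refcount_ already at 1.
  auto p = c10::make_intrusive<MyNode>(42);
  // Copying bumps refcount_ atomically; unlike std::shared_ptr, no
  // separately allocated control block is involved.
  c10::intrusive_ptr<MyNode> q = p;
}  // p and q go out of scope, refcount_ drops to 0, MyNode is destroyed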

TensorImpl in turn has a member variable of type Storage, and Storage holds the following member:

 protected:
  c10::intrusive_ptr<StorageImpl> storage_impl_;

StorageImpl also inherits from c10::intrusive_ptr_target and is the class that actually owns the underlying data, holding the raw data pointer. An official comment notes that the Storage design was carried over from the original Torch7 project and that the maintainers would prefer to remove this layer, but the change is disruptive and nobody has had time to do it.
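
As a hedged sketch of the resulting ownership chain, using only public accessors on a defined dense tensor:

#include <torch/torch.h>

void inspect(const at::Tensor& t) {
  // Tensor -> TensorImpl, via the intrusive_ptr impl_ shown above.
  c10::TensorImpl* impl = t.unsafeGetTensorImpl();
  // TensorImpl -> Storage -> StorageImpl -> raw allocation.
  const c10::Storage& storage = impl->storage();
  void* raw = storage.data_ptr().get();  // the allocation t.data_ptr() points into
  (void)raw;
}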

Variable and Tensor

In recent versions of PyTorch, Variable and Tensor have been merged through the alias below, although the Variable-related APIs have not been removed entirely:

namespace torch { namespace autograd {
using Variable = at::Tensor;
}} // namespace torch::autograd
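
Since Variable is a plain alias, the two names denote the same type and can be mixed freely:

torch::autograd::Variable v = torch::ones({2, 2});
at::Tensor& t = v;  // no conversion involved: Variable *is* at::Tensor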

Automatic Differentiation

Backpropagation API

Calling the backward function on a tensor triggers backpropagation:

  void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const {
    // NB: Adding this wrapper to _backward here because we'd like our
    // 'backwards' api to accept the 'inputs' argument optionally. Since code gen
    // currently does not support optional of TensorList our approach is to replace
    // backward in native_functions.yaml with _backward and call it here instead.
    if (inputs.has_value()) {
      TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty")
      this->_backward(inputs.value(), gradient, retain_graph, create_graph);
    } else {
      this->_backward({}, gradient, retain_graph, create_graph);
    }
  }

void Tensor::_backward(TensorList inputs,
        const c10::optional<Tensor>& gradient,
        c10::optional<bool> keep_graph,
        bool create_graph) const {
  return impl::GetVariableHooks()->_backward(*this, inputs, gradient, keep_graph, create_graph);
}

namespace at { namespace impl {
namespace {
VariableHooksInterface* hooks = nullptr;
}
void SetVariableHooks(VariableHooksInterface* h) {
  hooks = h;
}
VariableHooksInterface* GetVariableHooks() {
  TORCH_CHECK(hooks, "Support for autograd has not been loaded; have you linked against libtorch.so?")
  return hooks;
}
}} // namespace at::impl

The VariableHooksInterface class declares a set of virtual functions, among them:

 virtual void _backward(const Tensor&, TensorList, const c10::optional<Tensor>&, c10::optional<bool>, bool) const = 0;

The VariableHooks struct inherits from VariableHooksInterface and implements this hook as follows:

void VariableHooks::_backward(
    const Tensor& self,
    at::TensorList inputs,
    const c10::optional<Tensor>& gradient,
    c10::optional<bool> keep_graph,
    bool create_graph) const {
  // TODO torch::autograd::backward should take the c10::optional<Tensor> gradient directly
  // instead of us having to unwrap it to Tensor _gradient here.
  Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
  std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end());
  torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars);
}
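
Tracing the chain end to end (Tensor::backward -> Tensor::_backward -> VariableHooks::_backward -> torch::autograd::backward), here is a minimal LibTorch example that exercises it:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({3}, torch::requires_grad());
  auto loss = (x * x).sum();  // scalar output, so no explicit gradient argument needed
  loss.backward();            // ends up in torch::autograd::backward as shown above
  std::cout << x.grad() << "\n";  // dloss/dx == 2 * x
}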

Computational Graph

Node

A Node is a vertex of the computational graph: it abstracts one computation, accepting zero or more inputs and producing zero or more outputs.

using edge_list = std::vector<Edge>;

struct TORCH_API Node : std::enable_shared_from_this<Node> {
  // The Edges that this Node's outputs point to.
  edge_list next_edges_;
};

Edge

An Edge of the computational graph abstracts the connection between Nodes.

struct Edge {
  Edge() noexcept : function(nullptr), input_nr(0) {}
  Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept
      : function(std::move(function_)), input_nr(input_nr_) {}
  // The Node this edge points to.
  std::shared_ptr<Node> function;
  // Which of that Node's inputs this edge feeds.
  uint32_t input_nr;
};

Dynamic Construction

Backpropagation depends on the computational graph, and PyTorch builds this graph dynamically. When code like the following runs:

// ...
train_y = torch::add(train_y, real_b);

the add operator is invoked. When the Autograd dispatch key is set, the dispatcher routes the call to VariableType::add_Tensor (see the earlier posts in this series):

// torch/csrc/autograd/generated/VariableType_2.cpp
at::Tensor add_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  auto& self_ = unpack(self, "self", 0);
  auto& other_ = unpack(other, "other", 1);
  auto _any_requires_grad = compute_requires_grad( self, other );
  
  (void)_any_requires_grad;
  auto _any_has_forward_grad_result = (isFwGradDefined(self) || isFwGradDefined(other));
  (void)_any_has_forward_grad_result;
  std::shared_ptr<AddBackward0> grad_fn;
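  // If any input requires grad, build the backward node up front: wire its
  // next_edges_ to the inputs' grad functions and save what AddBackward0
  // needs to compute gradients later.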
  if (_any_requires_grad) {
    grad_fn = std::shared_ptr<AddBackward0>(new AddBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self, other ));
    grad_fn->other_scalar_type = other.scalar_type();
    grad_fn->alpha = alpha;
    grad_fn->self_scalar_type = self.scalar_type();
  }
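  // Debug builds snapshot the inputs' storage/impl pointers so the asserts
  // further down can verify the kernel did not swap them out.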
  #ifndef NDEBUG
  c10::optional<Storage> self__storage_saved =
    self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
  c10::intrusive_ptr<TensorImpl> self__impl_saved;
  if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
  c10::optional<Storage> other__storage_saved =
    other_.has_storage() ? c10::optional<Storage>(other_.storage()) : c10::nullopt;
  c10::intrusive_ptr<TensorImpl> other__impl_saved;
  if (other_.defined()) other__impl_saved = other_.getIntrusivePtr();
  #endif
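  // Re-dispatch below the autograd keys to run the actual add kernel.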
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::redispatch::add(ks & c10::after_autograd_keyset, self_, other_, alpha);
  })();
  auto result = std::move(_tmp);
  #ifndef NDEBUG
  if (self__storage_saved.has_value() &&
      !at::impl::dispatch_mode_enabled() &&
      !at::impl::tensor_has_dispatch(self_))
    AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
  if (self__impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(self_))
    AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
  if (other__storage_saved.has_value() &&
      !at::impl::dispatch_mode_enabled() &&
      !at::impl::tensor_has_dispatch(other_))
    AT_ASSERT(other__storage_saved.value().is_alias_of(other_.storage()));
  if (other__impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(other_))
    AT_ASSERT(other__impl_saved == other_.getIntrusivePtr());
  if (result.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(result)) {
    AT_ASSERT(result.storage().use_count() == 1, "function: add_Tensor");
  }
  if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(result))
    AT_ASSERT(result.use_count() <= 1, "function: add_Tensor");
  #endif
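  // Attach grad_fn to the freshly created result so subsequent ops can keep
  // extending the graph from it.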
  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }
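  // Forward-mode AD: if either input carries a forward gradient (tangent),
  // the result's tangent is self_t + alpha * other_t.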
  c10::optional<at::Tensor> result_new_fw_grad_opt = c10::nullopt;
  if (_any_has_forward_grad_result && (result.defined())) {
      auto self_t_raw = toNonOptFwGrad(self);
      auto self_tensor = toNonOptTensor(self);
      auto self_t = (self_t_raw.defined() || !self_tensor.defined())
        ? self_t_raw : at::_efficientzerotensor(self_tensor.sizes(), self_tensor.options());
      auto other_t_raw = toNonOptFwGrad(other);
      auto other_tensor = toNonOptTensor(other);
      auto other_t = (other_t_raw.defined() || !other_tensor.defined())
        ? other_t_raw : at::_efficientzerotensor(other_tensor.sizes(), other_tensor.options());
      result_new_fw_grad_opt = self_t + maybe_multiply(other_t, alpha);
  }
  if (result_new_fw_grad_opt.has_value() && result_new_fw_grad_opt.value().defined() && result.defined()) {
    // The hardcoded 0 here will need to be updated once we support multiple levels.
    result._set_fw_grad(result_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ false);
  }
  return result;
}

The code above creates and stores the backward-function object and then performs the forward computation; "dynamic" refers precisely to the fact that the graph for backpropagation is built while the forward pass runs.
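
To see the dynamically built graph, here is a hedged sketch that inspects the Node and Edges recorded by a single add:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::randn({2}, torch::requires_grad());
  auto b = torch::randn({2}, torch::requires_grad());
  auto c = a + b;  // the forward call just created an AddBackward0 node

  const auto& fn = c.grad_fn();     // the Node recorded for c
  std::cout << fn->name() << "\n";  // prints "AddBackward0"
  for (const auto& edge : fn->next_edges()) {
    // For leaf tensors these edges point at their gradient accumulators.
    std::cout << "  -> " << edge.function->name()
              << " (input_nr=" << edge.input_nr << ")\n";
  }
}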

More

More details will be filled in when time permits.
