PyTorch 源码阅读笔记(3):算子调用
算子调用
算子注册
参考 原生算子注册
算子调用过程
找到 OperatorHandle
// cmake-build-debug-wsl-gcc/aten/src/ATen/core/TensorBody.h
// Member-function form of add: forwards to the generated operator entry point
// at::_ops::add_Tensor::call, passing *this as the first operand followed by
// `other` and `alpha`.
inline at::Tensor Tensor::add(const at::Tensor & other, const at::Scalar & alpha) const {
// *this is const inside a const member function; const_cast yields the Tensor& passed on here.
return at::_ops::add_Tensor::call(const_cast<Tensor&>(*this), other, alpha);
}
// cmake-build-debug-wsl-gcc/aten/src/ATen/ops/add_ops.h
// Descriptor for the operator overload "aten::add" / "Tensor": bundles the C++
// function type of the operator, its name/overload/schema strings, and the two
// static entry points (call and redispatch).
struct TORCH_API add_Tensor {
  // C++ function type of the operator: (self, other, alpha) -> at::Tensor.
  using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &);
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::add")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "Tensor")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
  // Entry point invoked by Tensor::add; takes the operands only.
  static at::Tensor call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
  // Variant of call that additionally takes a caller-supplied DispatchKeySet.
  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
};  // the terminating ';' is required after a struct definition
// cmake-build-debug-wsl-gcc/aten/src/ATen/Operators_2.cpp
// Out-of-line counterparts of the name / overload_name / schema_str members of
// add_Tensor; the string values repeat the ones given inside the struct.
// NOTE(review): judging by the macro names, these presumably only emit
// definitions for Windows CUDA builds — confirm against the macro definitions.
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(add_Tensor, name, "aten::add")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(add_Tensor, overload_name, "Tensor")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(add_Tensor, schema_str, "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
// Asks the Dispatcher singleton for the schema identified by add_Tensor::name
// and add_Tensor::overload_name, then converts the result into a handle typed
// with add_Tensor::schema.
// NOTE(review): per its name, findSchemaOrThrow presumably throws when no such
// schema is known to the dispatcher — confirm against the Dispatcher source.
static C10_NOINLINE c10::TypedOperatorHandle<add_Tensor::schema> create_add_Tensor_typed_handle() {
return c10::Dispatcher::singleton()
.findSchemaOrThrow(add_Tensor::name, add_Tensor::overload_name)
.typed<add_Tensor::schema>();
}
// Entry point for the add.Tensor overload: obtains the typed operator handle
// and invokes the operator through it with (self, other, alpha).
at::Tensor add_Tensor::call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
// Function-local static: the handle is created on the first call and reused afterwards.
static auto op = create_add_Tensor_typed_handle();
return op.call(self, other, alpha);
}
// Like add_Tensor::call, but passes the caller-supplied dispatchKeySet on to
// the handle's redispatch together with the operands.
at::Tensor add_Tensor::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
// Function-local static: the handle is created on the first call and reused afterwards.
static auto op = create_add_Tensor_typed_handle();
return op.redispatch(dispatchKeySet, self, other, alpha);
}
最后调用的 typed<add_Tensor::schema>() 返回一个 TypedOperatorHandle 对象;调用它的 call 函数时,调用会被转发给 Dispatcher:
// Partial specialization for function types. When instantiated with
// add_Tensor::schema, `Return (Args...)` matches
//   at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &)
// i.e. Return = at::Tensor and Args = the three parameter types.
template<class Return, class... Args>
class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
public:
  // Forwards the call to the Dispatcher singleton, handing over this handle
  // together with the arguments. Must be public: add_Tensor::call invokes it
  // from outside the class, and class members are private by default.
  C10_ALWAYS_INLINE Return call(Args... args) const {
    return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
  }
};
调用算子函数
// Typed dispatch entry point: derives a DispatchKeySet from the call arguments,
// looks up the kernel of `op` for that key set, and invokes it. When per-op
// profiling is compiled in and active for this operator, the call goes through
// callWithDispatchKeySlowPath instead.
template<class Return, class... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
// Compute the key set for this call from the actual arguments, using the
// operator's own dispatch key extractor.
auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
.template getDispatchKeySetUnboxed<Args...>(args...);
// Compiled only when NDEBUG is not defined: optional dispatch trace on stderr.
#ifndef NDEBUG
DispatchTraceNestingGuard debug_guard;
if (show_dispatch_trace()) {
auto nesting_value = dispatch_trace_nesting_value();
// Indent the trace line according to the current nesting value.
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
std::cerr << "[call] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
}
#endif
// Select the kernel this operator provides for the computed key set.
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
// Compiled out when PYTORCH_DISABLE_PER_OP_PROFILING is defined.
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
// Slow path, hinted as unlikely: taken only when step callbacks are present
// and the operator reports itself as observed.
if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
// Fast path: invoke the selected kernel directly with the forwarded arguments.
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
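上面的代码先通过该算子的 dispatchKeyExtractor 从实参中提取出 dispatchKeySet,再用它在该算子的 kernel 中查找(lookup)对应的 KernelFunction,最后调用这个 kernel 完成计算;如果编译时没有禁用 per-op profiling,并且当前存在 step callbacks 且该算子被观测(isObserved),则改走 callWithDispatchKeySlowPath 这条慢路径。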
自动梯度
This post is licensed under CC BY 4.0 by the author.