PyTorch Source Code Reading Notes (2): Native Operator Registration

Native Operator Registration

Operator Definition

According to the official documentation, all native operators (functions) are defined in the file aten/src/ATen/native/native_functions.yaml. Take the add operator as an example:

- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: add.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: add_sparse
    SparseCsrCPU, SparseCsrCUDA: add_sparse_csr
    MkldnnCPU: mkldnn_add
    ZeroTensor: add_zerotensor
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_add_Tensor
  tags: [canonical, pointwise]
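
The fields above map directly onto the generated C++ API. As a quick illustration (a usage sketch, not part of the original post), variants: function, method means both call forms below are generated:

#include <ATen/ATen.h>

void add_variants_example() {
  at::Tensor a = at::ones({2, 2});
  at::Tensor b = at::ones({2, 2});
  // "variants: function, method" yields both a namespace function and a method:
  at::Tensor c = at::add(a, b, /*alpha=*/2); // function variant
  at::Tensor d = a.add(b, /*alpha=*/2);      // method variant
}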

Operator Schema Registration

An operator's schema is registered via the following macro:

// Auto-generated file: cmake-build-debug-wsl-gcc/aten/src/ATen/RegisterSchema.cpp
// TORCH_LIBRARY(aten, m) expands to the following
static void TORCH_LIBRARY_init_aten(torch::Library&);
static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_aten(
    torch::Library::DEF,
    &TORCH_LIBRARY_init_aten,
    "aten",
    c10::nullopt,
    "_file_name_",
    6);
void TORCH_LIBRARY_init_aten(torch::Library& m)
{
  m.def("add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", {at::Tag::core, at::Tag::pointwise});
}
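
The same TORCH_LIBRARY entry point is available to user code for defining custom operators. A minimal sketch, using a hypothetical myops namespace (not in the original post):

#include <torch/library.h>

// Hypothetical custom namespace "myops"; the macro generates the same
// static TorchLibraryInit machinery shown above for "aten".
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}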

Registration happens in m.def(…):

  template <typename Schema>
  Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
    return _def(std::move(s), nullptr, tags, rv);
  }

First, a c10::FunctionSchema instance is generated from "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor":

struct TORCH_API FunctionSchema {
//...
 private:
  OperatorName name_;
  std::vector<Argument> arguments_;
  std::vector<Argument> returns_;
  bool is_vararg_;
  bool is_varret_;
};
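
The schema string is turned into this structure by the JIT schema parser. A standalone sketch of the parse step (assuming direct use of the parser header, which m.def normally hides):

#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <iostream>

void parse_schema_example() {
  c10::FunctionSchema s = torch::jit::parseSchema(
      "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
  // name_ holds {"add", "Tensor"}; arguments_ holds self, other, and alpha.
  std::cout << s.name() << "." << s.overload_name() << " takes "
            << s.arguments().size() << " arguments\n"; // add.Tensor takes 3 arguments
}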

Then _def(…) is called; the key step is:

  switch (rv) {
    case _RegisterOrVerify::REGISTER:
      registrars_.emplace_back(
        c10::Dispatcher::singleton().registerDef(
          std::move(schema),
          debugString(file_, line_),
          tags
        )
      );
      break;
    case _RegisterOrVerify::VERIFY:
      c10::Dispatcher::singleton().waitForDef(schema);
      break;
  }

c10::Dispatcher::singleton() in the code above returns the Dispatcher singleton object, whose registerDef is then called:

RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
  // we need a lock to avoid concurrent writes
  std::lock_guard<std::mutex> lock(mutex_);

  OperatorName op_name = schema.operator_name();
  auto op = findOrRegisterName_(op_name);

  TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
                                                    " Each overload's schema should only be registered with a single call to def().",
                                                    " Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
  op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), std::move(tags));
  listeners_->callOnOperatorRegistered(op);

  // NB: do not increment the counts until AFTER error checking
  ++op.operatorDef_->def_count;
  ++op.operatorDef_->def_and_impl_count;

  cond_var_.notify_all();

  return RegistrationHandleRAII([this, op, op_name] {
    deregisterDef_(op, op_name);
  });
}
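
The returned RegistrationHandleRAII ties the registration's lifetime to the handle: destroying it runs the stored callback, which deregisters the definition. A simplified conceptual sketch (not the actual c10 source):

#include <functional>
#include <utility>

// Conceptual sketch: run the stored deregistration callback on destruction.
class RegistrationHandleRAII {
 public:
  explicit RegistrationHandleRAII(std::function<void()> onDestruction)
      : onDestruction_(std::move(onDestruction)) {}
  ~RegistrationHandleRAII() {
    if (onDestruction_) { onDestruction_(); }
  }
  // Move-only: transferring the handle transfers ownership of the registration.
  RegistrationHandleRAII(RegistrationHandleRAII&& rhs) noexcept
      : onDestruction_(std::move(rhs.onDestruction_)) {
    rhs.onDestruction_ = nullptr;
  }
  RegistrationHandleRAII(const RegistrationHandleRAII&) = delete;

 private:
  std::function<void()> onDestruction_;
};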

Name Registration

The findOrRegisterName_ function:

OperatorHandle Dispatcher::findOrRegisterName_(const OperatorName& op_name) {
  const auto found = findOp(op_name);
  if (found != c10::nullopt) {
    return *found;
  }

  operators_.emplace_back(OperatorName(op_name));
  OperatorHandle handle(--operators_.end());
  operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
    operatorLookupTable.emplace(op_name, handle);
  });

  return handle;
}

Since the operator has not been registered yet, findOp returns c10::nullopt. The operators_.emplace_back(OperatorName(op_name)) step then implicitly constructs an OperatorDef object:

class TORCH_API Dispatcher final {
    // nested struct
    struct OperatorDef final {
      explicit OperatorDef(OperatorName&& op_name)
      : op(std::move(op_name)) {}

      impl::OperatorEntry op;
      size_t def_count = 0;
      size_t def_and_impl_count = 0;
    };
};

Inside OperatorDef, the member op (an impl::OperatorEntry) is in turn implicitly constructed; it holds the operator's information:

OperatorEntry::OperatorEntry(OperatorName&& operator_name)
: name_(std::move(operator_name))
, schema_()
#ifndef C10_MOBILE
, tags_()
#endif
, dispatchTable_()
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
, kernels_()
, cpp_signature_()
, sym_cpp_signature_()
, is_observed_(ObservedOperators::isObserved(name_))
{
  // Pick up any backend fallbacks that were registered prior to this
  // OperatorEntry being created
  updateDispatchTableFull_(c10::Dispatcher::singleton());
}

class TORCH_API OperatorEntry final {
public:
  explicit OperatorEntry(OperatorName&& operator_name);
private:
  OperatorName name_;
  c10::optional<AnnotatedSchema> schema_;
  #ifndef C10_MOBILE
    std::vector<at::Tag> tags_;
  #endif
  std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
  DispatchKeyExtractor dispatchKeyExtractor_;
  // Pointer to the torch.ops.ns.op.overload object for speed
  c10::PyHandleCache py_cache_;

  ska::flat_hash_map<DispatchKey,
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
        // On mobile, we needn't worry about Jupyter notebooks.
        std::array<AnnotatedKernel, 1>
#else
        std::list<AnnotatedKernel>
#endif
        > kernels_;
};

kernels_ is a hash map from each DispatchKey to the kernel functions registered under that key:

// This data structure represents a kernel that was registered to us from a
// user.  Unlike KernelFunction, AnnotatedKernel contains some extra metadata
// about the kernel that isn't necessary for actual dispatching (this is why
// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for
// giving good error messages.
struct AnnotatedKernel final {
  AnnotatedKernel(KernelFunction k, std::unique_ptr<FunctionSchema> s, std::string d)
    : kernel(std::move(k))
    , inferred_function_schema(std::move(s))
    , debug(std::move(d))
    {}
  AnnotatedKernel() = default;
  KernelFunction kernel;
  std::unique_ptr<FunctionSchema> inferred_function_schema;
  // A little debug string to help us identify the kernel in question.
  // Most importantly it records the TORCH_LIBRARY block that did the
  // registration.
  std::string debug;
};

Then OperatorHandle handle(--operators_.end()) constructs an OperatorHandle object:

private:
  explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
  : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator)  {}
  friend class Dispatcher;
  template<class> friend class TypedOperatorHandle;
  // info about the current operator
  Dispatcher::OperatorDef* operatorDef_;
  // iterator into the global operator list
  std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
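
The operatorIterator_ member explains why operators_ is a std::list: list iterators stay valid as elements are appended, so a handle can hold one indefinitely, while the hash map only accelerates name lookup. A simplified sketch of the pattern (without the reader-writer lock):

#include <list>
#include <string>
#include <unordered_map>

struct Def { std::string name; };

std::list<Def> operators_;
std::unordered_map<std::string, std::list<Def>::iterator> lookup_;

std::list<Def>::iterator findOrRegister(const std::string& name) {
  auto it = lookup_.find(name);
  if (it != lookup_.end()) return it->second;
  operators_.emplace_back(Def{name});
  auto handle = --operators_.end(); // iterator to the newly appended element
  lookup_.emplace(name, handle);
  return handle;
}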

Finally, the name-handle pair is written into operatorLookupTable_, a member variable of the global Dispatcher singleton.
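
Once the pair is in the table, the handle can be fetched back by name. A usage sketch against the public Dispatcher API:

#include <ATen/core/dispatch/Dispatcher.h>

void lookup_example() {
  // Reads the same table that findOrRegisterName_ wrote into; throws if absent.
  c10::OperatorHandle op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor");
  TORCH_CHECK(op.hasSchema());
}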

Schema Registration

void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
  TORCH_INTERNAL_ASSERT(!schema_.has_value());
  for (const auto& kernel : kernels_) {
    for (const auto &j : kernel.second) {
      if (j.inferred_function_schema != nullptr) {
        checkSchema(name_, schema, debug, j.kernel, *j.inferred_function_schema, j.debug);
      }
    }
  }
  // NB: don't register schema until after we've checked everything!
  dispatchKeyExtractor_.registerSchema(schema);
  schema_ = AnnotatedSchema(std::move(schema), std::move(debug));
  #ifndef C10_MOBILE
    tags_ = std::move(tags);
  #endif
}

registerSchema first iterates over kernels_ and checks each AnnotatedKernel whose schema could be inferred against the declared schema;
it then calls dispatchKeyExtractor_.registerSchema(schema) so that the dispatcher records which arguments participate in dispatch;
finally, it stores the result in the member variable schema_.

Operator Kernel Registration

After the operator's schema has been registered, the following macro is invoked for each operator's implementation on its corresponding backend:

// Auto-generated code at cmake-build-debug-wsl-gcc/aten/src/ATen/RegisterXXX.cpp
// TORCH_LIBRARY_IMPL(aten, CPU, m) expands to the following
static void TORCH_LIBRARY_IMPL_init_aten_CPU_1(torch::Library&);
static const torch::detail::
    TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_aten_CPU_1(
        torch::Library::IMPL,
        c10::guts::if_constexpr<
            c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::CPU)>(
            []() { return &TORCH_LIBRARY_IMPL_init_aten_CPU_1; },
            []() { return [](torch::Library&) -> void {}; }),
        "aten",
        c10::make_optional(c10::DispatchKey::CPU),
        "_file_name_",
        31034);
void TORCH_LIBRARY_IMPL_init_aten_CPU_1(torch::Library& m)
{
    // ...
    m.impl("add.Tensor", TORCH_FN(wrapper_CPU_add_Tensor));
    // ...
}
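
User code hooks into the same mechanism with TORCH_LIBRARY_IMPL. A minimal sketch, continuing the hypothetical myops namespace from the earlier sketch:

#include <torch/library.h>
#include <ATen/ATen.h>

// Plain CPU kernel for the hypothetical myops::myadd declared earlier.
at::Tensor myadd_cpu(const at::Tensor& self, const at::Tensor& other) {
  return self + other;
}

TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", TORCH_FN(myadd_cpu));
}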

Similar to the first macro, impl takes the operator name as its first argument and a function pointer as its second:

TORCH_FN(wrapper_CPU_add_Tensor);
// expands to
::c10::CompileTimeFunctionPointer<
    std::remove_pointer_t<
        std::remove_reference_t<decltype(wrapper_CPU_add_Tensor)>>,
    wrapper_CPU_add_Tensor>()
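
CompileTimeFunctionPointer carries the pointer in its type rather than as runtime state, which is what allows schema inference to run at compile time. A simplified sketch of the idea (the real c10 version adds a static_assert that FuncType_ is a function type):

// The function pointer is a non-type template parameter, so it is known
// at compile time and the struct itself is empty.
template <class FuncType_, FuncType_* func_ptr_>
struct CompileTimeFunctionPointer final {
  using FuncType = FuncType_;
  static constexpr FuncType* func_ptr() { return func_ptr_; }
};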

The impl function

  template <typename Name, typename Func>
  Library& impl(Name name, Func&& raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
#if defined C10_MOBILE
    CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
#else
    CppFunction f(std::forward<Func>(raw_f));
#endif
    return _impl(name, std::move(f), rv);
  }

Inside, a CppFunction object is instantiated:

class TORCH_API CppFunction final {
  /// This overload accepts compile time function pointers, e.g.,
  /// `CppFunction(TORCH_FN(add_impl))`
  template <typename FuncPtr>
  explicit CppFunction(
      FuncPtr f,
      std::enable_if_t<
          c10::is_compile_time_function_pointer<FuncPtr>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedFunction(f)),
        cpp_signature_(
            c10::impl::CppSignature::make<typename FuncPtr::FuncType>()),
        schema_(c10::detail::inferFunctionSchemaFromFunctor<
                typename FuncPtr::FuncType>()),
        debug_() {}
 private:
  c10::optional<c10::DispatchKey> dispatch_key_;
  c10::KernelFunction func_;
  c10::optional<c10::impl::CppSignature> cpp_signature_;
  std::unique_ptr<c10::FunctionSchema> schema_;
  std::string debug_;
};

Then the kernel function is registered:

Library& Library::_impl(const char* name_str, CppFunction&& f, _RegisterOrVerify rv) & {
  at::OperatorName name = _parseNameForLib(name_str);
  // See Note [Redundancy in registration code is OK]
  TORCH_CHECK(!(f.dispatch_key_.has_value() &&
                dispatch_key_.has_value() &&
                *f.dispatch_key_ != *dispatch_key_),
    IMPL_PRELUDE,
    "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
    "with the dispatch key of the enclosing ", toString(kind_), " block (", *dispatch_key_, ").  "
    "Please declare a separate ", toString(kind_), " block for this dispatch key and "
    "move your impl() there.  "
    ERROR_CONTEXT
  );
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  switch (rv) {
    case _RegisterOrVerify::REGISTER:
      registrars_.emplace_back(
        c10::Dispatcher::singleton().registerImpl(
          std::move(name),
          dispatch_key,
          std::move(f.func_),
          std::move(f.cpp_signature_),
          std::move(f.schema_),
          debugString(std::move(f.debug_), file_, line_)
        )
      );
      break;
    case _RegisterOrVerify::VERIFY:
      c10::Dispatcher::singleton().waitForImpl(name, dispatch_key);
      break;
  }
  return *this;
}

The function once again uses the Dispatcher singleton, calling its registerImpl method; after locating the previously registered OperatorEntry, it calls:

OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
  const c10::Dispatcher& dispatcher,
  c10::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  c10::optional<CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  // register the C++ signature
  // ...

  // register the kernel function pointer
  auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];
  // ...
  #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
    k[0].kernel = std::move(kernel);
    k[0].inferred_function_schema = std::move(inferred_function_schema);
    k[0].debug = std::move(debug);
  #else
    k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
  #endif
    AnnotatedKernelContainerIterator inserted = k.begin();
    // update the dispatch table, i.e. re-establish the invariant
    // that the dispatch table points to the newest kernel
    if (dispatch_key.has_value()) {
      updateDispatchTable_(dispatcher, *dispatch_key);
    } else {
      updateDispatchTableFull_(dispatcher);
    }
    return inserted;
}
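
After updateDispatchTable_ runs, the dispatch table points at the new kernel and calls can be routed through the dispatcher. A usage sketch with the typed-handle API:

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/ATen.h>

at::Tensor call_add_via_dispatcher(const at::Tensor& a, const at::Tensor& b) {
  // Look up add.Tensor once and cache the statically typed handle.
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<at::Tensor(const at::Tensor&, const at::Tensor&, const at::Scalar&)>();
  return op.call(a, b, /*alpha=*/1); // dispatches to e.g. wrapper_CPU_add_Tensor
}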

Operator Kernel Implementation

TORCH_FN strips reference and pointer qualifiers from the kernel function and wraps it uniformly in a CompileTimeFunctionPointer struct, which participates in the subsequent registration. The function definitions inside the macro are auto-generated:

struct structured_ufunc_add_CPU_functional final : public at::native::structured_ufunc_add_CPU {
    void set_output_strided(
        int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
        TensorOptions options, DimnameList names
    ) override {
        outputs_[output_idx] = create_out(sizes, strides, options);
        if (!names.empty()) {
          namedinference::propagate_names(*outputs_[output_idx], names);
        }
        // super must happen after, so that downstream can use maybe_get_output
        // to retrieve the output
        at::native::structured_ufunc_add_CPU::set_output_raw_strided(output_idx, sizes, strides, options, names);
    }
    void set_output_raw_strided(
        int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
        TensorOptions options, DimnameList names
    ) override {
        outputs_[output_idx] = create_out(sizes, strides, options);
        if (!names.empty()) {
          namedinference::propagate_names(*outputs_[output_idx], names);
        }
        // super must happen after, so that downstream can use maybe_get_output
        // to retrieve the output
        at::native::structured_ufunc_add_CPU::set_output_raw_strided(output_idx, sizes, strides, options, names);
    }
    const Tensor& maybe_get_output(int64_t output_idx) override {
      return *outputs_[output_idx];
    }
    std::array<c10::ExclusivelyOwned<Tensor>, 1> outputs_;
};
at::Tensor wrapper_CPU_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  structured_ufunc_add_CPU_functional op;
  op.meta(self, other, alpha);
  op.impl(self, other, alpha, *op.outputs_[0]);
  return std::move(op.outputs_[0]).take();
}
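
The split visible here, with meta() deciding the output and impl() filling it in, is the essence of structured kernels. A toy standalone sketch of the pattern (not the generated code):

#include <cstddef>
#include <utility>
#include <vector>

// Toy structured kernel: meta() performs "shape inference" and allocates;
// impl() only computes into the pre-allocated output.
struct ToyStructuredAdd {
  std::vector<float> out_;
  void meta(const std::vector<float>& a, const std::vector<float>& b) {
    out_.resize(a.size()); // decide and allocate the output
  }
  void impl(const std::vector<float>& a, const std::vector<float>& b) {
    for (std::size_t i = 0; i < a.size(); ++i) out_[i] = a[i] + b[i];
  }
};

std::vector<float> toy_add(const std::vector<float>& a,
                           const std::vector<float>& b) {
  ToyStructuredAdd op;
  op.meta(a, b);
  op.impl(a, b);
  return std::move(op.out_);
}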

The declarations and definitions that the variable op inside wrapper_CPU_add_Tensor depends on are spread across several files:

// cmake-build-debug-wsl-gcc/aten/src/ATen/ops/add_meta.h
namespace at {
namespace meta {
struct TORCH_API structured_add_Tensor : public TensorIteratorBase {
    void meta(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
};}}
// aten/src/ATen/native/BinaryOps.cpp
void structured_add_Tensor::meta // expanded from TORCH_META_FUNC2(add, Tensor)
(
  const Tensor& self, const Tensor& other, const Scalar& alpha
) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
  native::alpha_check(dtype(), alpha);
}
// cmake-build-debug-wsl-gcc/aten/src/ATen/ops/add_native.h
namespace at {
namespace native {
struct TORCH_API structured_ufunc_add_CPU : public at::meta::structured_add_Tensor {
void impl(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out);
};}}
// cmake-build-debug-wsl-gcc/aten/src/ATen/UfuncCPU_add.cpp
void structured_ufunc_add_CPU::impl // expanded from TORCH_IMPL_FUNC(ufunc_add_CPU)
(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out) {
  add_stub(device_type(), *this, alpha);
}

The call finally lands in add_stub, a struct that overloads operator(). Its declaration and definition are as follows:

// aten/src/ATen/native/BinaryOps.h
using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
// DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub) expands to:
struct add_stub : DispatchStub<structured_binary_fn_alpha, add_stub> {
  add_stub() = default;
  add_stub(const add_stub&) = delete;
  add_stub& operator=(const add_stub&) = delete;
};
extern TORCH_API struct add_stub add_stub;
// aten/src/ATen/native/BinaryOps.cpp also defines an add_stub, with different visibility
// DEFINE_DISPATCH(add_stub); expands to:
struct add_stub add_stub;

The add_stub struct inherits from a partial specialization of the DispatchStub template:

template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*) (Args...);

  DispatchStub() = default;
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

private:
  FnPtr get_call_ptr(DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
      impl.get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
      )
    );
  }

public:
  template <typename... ArgTypes>
  rT operator()(DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }

  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  static FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
  static FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  static FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  static FnPtr VSX;
#endif
private:
  DispatchStubImpl impl;
};
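
Stripped of the SIMD variants, the stub is a struct holding one static function pointer per architecture and forwarding operator() through it. A minimal sketch of the same pattern, which the next step fills in:

#include <stdexcept>

using binary_fn = void (*)(float*, const float*, const float*, int);

void add_default(float* out, const float* a, const float* b, int n) {
  for (int i = 0; i < n; ++i) out[i] = a[i] + b[i];
}

// Minimal DispatchStub-like shell (a sketch, not the real DispatchStub):
// one static pointer, filled in by a REGISTER-style definition elsewhere.
struct my_add_stub {
  static binary_fn DEFAULT;
  void operator()(float* out, const float* a, const float* b, int n) {
    if (!DEFAULT) throw std::runtime_error("no kernel registered");
    DEFAULT(out, a, b, n);
  }
};

// Analogue of REGISTER_ARCH_DISPATCH: define the static member with the kernel.
binary_fn my_add_stub::DEFAULT = &add_default;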

The REGISTER_DISPATCH macro ultimately expands to code that assigns the hardware-specific function pointer to the corresponding static member of the fully specialized DispatchStub template:

// aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
void add_kernel(TensorIteratorBase& iter, const Scalar& alpha_scalar) {
  if (iter.dtype() == ScalarType::Bool) {
      using scalar_t = bool;
      auto alpha = alpha_scalar.to<scalar_t>();
      cpu_kernel(iter,
        [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a + alpha * b; });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "add_cpu/sub_cpu", [&]() {
      auto alpha = alpha_scalar.to<scalar_t>();
      auto alpha_vec = Vectorized<scalar_t>(alpha);
      cpu_kernel_vec(iter,
        [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a + alpha * b; },
        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) __ubsan_ignore_undefined__ {
          return vec::fmadd(b, alpha_vec, a);
        });
      });
  }
}
REGISTER_DISPATCH(add_stub, &add_kernel);
// REGISTER_DISPATCH is defined differently on different platforms; on CPU:
#elif defined(CPU_CAPABILITY)
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define REGISTER_NO_AVX512_DISPATCH(name, fn_type)                             \
  REGISTER_AVX512_DISPATCH(name, static_cast<fn_type>(nullptr))

#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> decltype(fn) DispatchStub<decltype(fn), struct name>::arch = fn;

The implementation of every native operator lives under the native folder; through the steps above it is wrapped into the corresponding function pointer and takes part in the operator registration process.
