PyTorch dispatcher walkthrough

Page Maintainers: @bdhirsh

Codegen + Structured Kernels Overview

Ed has a really great overview of code-generation and why we have it in PyTorch: check out his podcast: https://pytorch-dev-podcast.simplecast.com/episodes/code-generation.

This document goes over our codegen subsystem + structured kernels in more detail, and will have you use gdb to jump through the different code-generated files that are part of a call into torch.add().

What it is

We have a code-generation pipeline that runs as part of the PyTorch build - it reads in some yaml files, and spits out a bunch of C++ files.

Why we have it

So, why do we have codegen? One big motivating factor is to reduce boilerplate. PyTorch has a lot of operators, and there’s a lot of stuff that should “just work” for every operator. We don’t want to make someone hand-write all of that functionality whenever a new operator is added. Instead, we code-generate it.

A (non-exhaustive) list of functionality (we need all of this for every operator, so multiply by ~2000):

  • bindings to python
  • The frontend C++ API
  • autograd support
  • registering kernels to the dispatcher (see the sketch after this list)
  • other stuff
    • special logic for factory functions
    • torch.jit.trace functionality
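
To make the "registering kernels to the dispatcher" bullet concrete, here is a minimal hand-written sketch using the public TORCH_LIBRARY API (the operator name myops::my_add and the kernel my_add_cpu are made up for illustration). The codegen emits the equivalent registrations for every entry in native_functions.yaml, as we'll see later with RegisterCPU.cpp:

// Hedged sketch: manually defining and registering a single op with the dispatcher.
// myops::my_add and my_add_cpu are made-up names, used only for illustration.
#include <torch/library.h>
#include <ATen/ATen.h>

at::Tensor my_add_cpu(const at::Tensor& self, const at::Tensor& other) {
  return self + other;  // stand-in for a real hand-written kernel
}

// Declare the operator schema...
TORCH_LIBRARY(myops, m) {
  m.def("my_add(Tensor self, Tensor other) -> Tensor");
}

// ...and register a CPU kernel for it with the dispatcher.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("my_add", TORCH_FN(my_add_cpu));
}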

Inputs

We have a yaml file, native_functions.yaml, which describes metadata about each operator that gets consumed by the codegen: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml

We’re going to focus on the operator torch.add(a, b, out=c), which corresponds to the yaml entry add.out:

- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  ufunc_inner_loop:
    Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf)
    ScalarOnly: add (Bool)
  dispatch:
    SparseCPU: add_out_sparse_cpu
    SparseCUDA: add_out_sparse_cuda
    SparseCsrCPU: add_out_sparse_compressed_cpu
    SparseCsrCUDA: add_out_sparse_compressed_cuda
    MkldnnCPU: mkldnn_add_out
    MPS: add_out_mps
  tags: pointwise

There’s public documentation on each of the different pieces of yaml here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/README.md

The codegen is written in a functional style using python dataclasses to represent the different inputs/intermediates/outputs. For example, each entry in native_functions.yaml is represented in the codegen as a NativeFunction object: https://github.com/pytorch/pytorch/blob/6596a3f23dfe1ea4175637fa979bcbfbff397737/torchgen/model.py#L427

Finally, one of the main entry points to the codegen is torchgen/gen.py (there’s also a separate entry point for the autograd codegen pipeline). You can see the part of the file where we generate the C++ API, for example Functions.h (https://github.com/pytorch/pytorch/blob/f8e14f3b46e68a5271a8c57ce749ad8057d77ddd/torchgen/gen.py#L1781). It reads in a template file, aten/src/ATen/templates/Functions.h (https://github.com/pytorch/pytorch/blob/f8e14f3b46e68a5271a8c57ce749ad8057d77ddd/aten/src/ATen/templates/Functions.h), and generates the file build/aten/src/ATen/Functions.h.
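
For reference, the generated entry for add in build/aten/src/ATen/Functions.h looks roughly like this (paraphrased from a recent build, not copied verbatim):

// namespace at
// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
inline at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
    return at::_ops::add_Tensor::call(self, other, alpha);
}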

Exercise 0: Full Stack Trace of torch.add

For this exercise you’ll need to have PyTorch built with debug symbols. I usually do that with USE_CUDA=0 DEBUG=1 python setup.py develop (the USE_CUDA=0 is there because we don’t need CUDA for this, and building with it takes a long time).

We’re going to run a small python program using gdb to view the full stack trace. Create a python script, tmp.py, with the following:

import torch
a = torch.tensor([1, 1])
b = torch.tensor([1, 1])
c = torch.add(a, b)

Run gdb python (or lldb python -- tmp.py) to start up the debugger. We’re going to set a breakpoint in the add kernel - to do that, at the gdb prompt, type break structured_ufunc_add_CPU::impl (or b structured_ufunc_add_CPU::impl in lldb). Then run your script inside of gdb with run tmp.py (or r in lldb).

The debugger should pause inside of the add kernel. Type bt to view the current stack trace.

Ignoring the first ~10 function calls through the python interpreter, you should see a stack trace that looks something like the following:

* thread #1, name = 'python', stop reason = breakpoint 1.1
  * frame #0: 0x00007fffd38ed42a libtorch_cpu.so`at::native::structured_ufunc_add_CPU::impl(this=0x00007fffffffae60, self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0, out=0x00007fffffffb190) at UfuncCPU_add.cpp:30:11
    frame #1: 0x00007fffd2ae81aa libtorch_cpu.so`at::(anonymous namespace)::wrapper_CPU_add_Tensor(self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at RegisterCPU.cpp:1576:8
    frame #2: 0x00007fffd2c9079d libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, const at::Tensor&, const c10::Scalar&), at::(anonymous namespace)::wrapper_CPU_add_Tensor>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, const at::Tensor&, const c10::Scalar&> >, at::Tensor(const at::Tensor&, const at::Tensor&, const c10::Scalar&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &, const at::Tensor &, const c10::Scalar &) [inlined] operator(args#2=0x00007fffffffbbb0, args#1=0x00007fffffffbbd8, args#0=0x00007fffffffbbe0, this=0x0000555556633ac0) at WrapFunctionIntoFunctor.h:13:72
    frame #3: 0x00007fffd2c90759 libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, const at::Tensor&, const c10::Scalar&), at::(anonymous namespace)::wrapper_CPU_add_Tensor>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, const at::Tensor&, const c10::Scalar&> >, at::Tensor(const at::Tensor&, const at::Tensor&, const c10::Scalar&)>::call(functor=0x0000555556633ac0, (null)=(repr_ = 32769), args#0=0x00007fffffffbbe0, args#1=0x00007fffffffbbd8, args#2=0x00007fffffffbbb0) at make_boxed_from_unboxed_functor.h:468:63
    frame #4: 0x00007fffd1f5dec7 libtorch_cpu.so`at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar const&>(unboxed_kernel_func=0x00007fffd2c906ee, functor=0x0000555556633ac0, dispatchKeySet=(repr_ = 32769), (null)=0x00007fffffffbbe0, (null)=0x00007fffffffbbd8, (null)=0x00007fffffffbbb0) at KernelFunction_impl.h:52:72
    frame #5: 0x00007fffd1e1ea24 libtorch_cpu.so`at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, c10::Scalar const&)> const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::Scalar const&) const at KernelFunction_impl.h:104:87
    frame #6: 0x00007fffd1e1e9aa libtorch_cpu.so`at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar const&>(this=0x00007fffe61a1de0, op=0x00007fffe61c7db0, currentDispatchKeySet=(repr_ = 32769), (null)=0x00007fffffffbbe0, (null)=0x00007fffffffbbd8, (null)=0x00007fffffffbbb0) const at Dispatcher.h:712:102
    frame #7: 0x00007fffd2332b7c libtorch_cpu.so`at::_ops::add_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::Scalar const&) [inlined] c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, c10::Scalar const&)>::redispatch(args#2=0x00007fffffffbbb0, args#1=0x00007fffffffbbd8, args#0=0x00007fffffffbbe0, currentDispatchKeySet=(repr_ = 32769), this=<unavailable>) const at Dispatcher.h:532:126
    frame #8: 0x00007fffd2332acd libtorch_cpu.so`at::_ops::add_Tensor::redispatch(dispatchKeySet=(repr_ = 32769), self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at Operators_2.cpp:1049:60
    frame #9: 0x00007fffd502cbf2 libtorch_cpu.so`at::redispatch::add(dispatchKeySet=(repr_ = 32769), self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at RedispatchFunctions.h:607:83
    frame #10: 0x00007fffd4ef7650 libtorch_cpu.so`operator(__closure=0x00007fffffffb5c0) at VariableType_2.cpp:5969:85
    frame #11: 0x00007fffd4ef7b7c libtorch_cpu.so`torch::autograd::VariableType::(anonymous namespace)::add_Tensor(ks=(repr_ = 274877939713), self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at VariableType_2.cpp:5970:6
    frame #12: 0x00007fffd4ff0ad9 libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&), torch::autograd::VariableType::(anonymous namespace)::add_Tensor>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &, const at::Tensor &, const c10::Scalar &) [inlined] operator(args#3=0x00007fffffffbbb0, args#2=0x00007fffffffbbd8, args#1=0x00007fffffffbbe0, args#0=(repr_ = 274877939713), this=0x0000555557b02710) at WrapFunctionIntoFunctor.h:13:72
    frame #13: 0x00007fffd4ff0a80 libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&), torch::autograd::VariableType::(anonymous namespace)::add_Tensor>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&)>::call(functor=0x0000555557b02710, dispatchKeySet=(repr_ = 274877939713), args#0=0x00007fffffffbbe0, args#1=0x00007fffffffbbd8, args#2=0x00007fffffffbbb0) at make_boxed_from_unboxed_functor.h:485:79
    frame #14: 0x00007fffd1f5dec7 libtorch_cpu.so`at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar const&>(unboxed_kernel_func=0x00007fffd4ff0a0b, functor=0x0000555557b02710, dispatchKeySet=(repr_ = 274877939713), (null)=0x00007fffffffbbe0, (null)=0x00007fffffffbbd8, (null)=0x00007fffffffbbb0) at KernelFunction_impl.h:52:72
    frame #15: 0x00007fffd233293b libtorch_cpu.so`at::_ops::add_Tensor::call(at::Tensor const&, at::Tensor const&, c10::Scalar const&) at KernelFunction_impl.h:104:87
    frame #16: 0x00007fffd23328ac libtorch_cpu.so`at::_ops::add_Tensor::call(at::Tensor const&, at::Tensor const&, c10::Scalar const&) at Dispatcher.h:694:97
    frame #17: 0x00007fffd2332695 libtorch_cpu.so`at::_ops::add_Tensor::call(at::Tensor const&, at::Tensor const&, c10::Scalar const&) [inlined] c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, c10::Scalar const&)>::call(args#2=0x00007fffffffbbb0, args#1=0x00007fffffffbbd8, args#0=0x00007fffffffbbe0, this=<unavailable>) const at Dispatcher.h:527:97
    frame #18: 0x00007fffd233257b libtorch_cpu.so`at::_ops::add_Tensor::call(self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at Operators_2.cpp:1042:38
    frame #19: 0x00007fffe74b677c libtorch_python.so`at::Tensor::add(this=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) const at TensorBody.h:1664:79
    frame #20: 0x00007fffe75f1164 libtorch_python.so`operator(__closure=0x00007fffffffbaad, self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at python_torch_functions_2.cpp:1400:39
    frame #21: 0x00007fffe75f1777 libtorch_python.so`torch::autograd::THPVariable_add(self_=0x0000000000000000, args=0x00007ffd68f74b80, kwargs=0x0000000000000000) at python_torch_functions_2.cpp:1402:33

That’s a lot of function calls! We’re going to walk through the main pieces that are relevant to codegen and where they live. For each piece, I’ve listed the relevant frame numbers from the gdb stack trace.

Tip: In lldb, if you are curious about the full path of a source file listed in the frame backtrace (for example, the one in frame 12), you can first switch to that frame with f 12 and then show its source info with so i.

(1) Python Bindings

#21: torch::autograd::THPVariable_add

This is the first stop that we hit after going through the python interpreter: python bindings. This is the code that interfaces directly with cpython to bind our C++ functions to python.

You can see a snippet of the function below: its job is basically to take all of the PyObjects that it was handed from CPython, parse them into actual C++ types (like at::Tensor), and call into the C++ API. It does that by calling into the Tensor add method: self.add(other, alpha).

static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "add(Tensor input, Scalar alpha, Tensor other, *, Tensor out=None)|deprecated",
    "add(Tensor input, Tensor other, *, Scalar alpha=1, Tensor out=None)",
  }, /*traceable=*/true);

  ParsedArgs<4> parsed_args;
  auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  ...
        auto dispatch_add = [](const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) -> at::Tensor {
          pybind11::gil_scoped_release no_gil;
          return self.add(other, alpha);
        };
        return wrap(dispatch_add(_r.tensor(0), _r.tensor(1), _r.scalar(2)));
  ...
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

These are all codegen’d and live in torch/csrc/autograd/generated/python_torch_functions_2.cpp.

(2) C++ API

#18: at::_ops::add_Tensor::call
#19: at::Tensor::add

The next stop is the C++ method API, which is one of the top-level APIs for calling into the dispatcher. The dispatcher then looks at all of the arguments + any thread-local state to figure out which kernel to dispatch to. https://github.com/pytorch/pytorch/wiki/PyTorch-dispatcher-walkthrough has some more details about the dispatch key calculation process.

In build/aten/src/ATen/core/TensorBody.h:

//namespace at
inline at::Tensor Tensor::add(const at::Tensor & other, const at::Scalar & alpha) const {
    return at::_ops::add_Tensor::call(const_cast<Tensor&>(*this), other, alpha);
}

In build/aten/src/ATen/Operators_2.cpp:

static C10_NOINLINE c10::TypedOperatorHandle<add_Tensor::schema> create_add_Tensor_typed_handle() {
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow(add_Tensor::name, add_Tensor::overload_name)
      .typed<add_Tensor::schema>();
}

at::Tensor add_Tensor::call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
    static auto op = create_add_Tensor_typed_handle();
    return op.call(self, other, alpha);
}
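
As a quick usage sketch (assuming a standard ATen build), both generated C++ entry points for add funnel into at::_ops::add_Tensor::call(), which hands things off to the dispatcher:

// Hedged usage sketch of the two generated C++ entry points shown above.
#include <ATen/ATen.h>

void cpp_api_example() {
  at::Tensor a = at::ones({2});
  at::Tensor b = at::ones({2});
  at::Tensor c1 = at::add(a, b);  // free-function API (generated Functions.h)
  at::Tensor c2 = a.add(b);       // Tensor method API (generated TensorBody.h)
}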

(3) Autograd kernel

#11: torch::autograd::VariableType::(anonymous namespace)::add_Tensor

After a bunch of dispatcher-related frames, we eventually land in the autograd add kernel. The autograd kernel:

  • saves some metadata for autograd
  • re-invokes the dispatcher by calling at::redispatch::add(ks & c10::after_autograd_keyset, self_, other_, alpha);

In torch/csrc/autograd/generated/VariableType_2.cpp:

// namespace at::VariableType
at::Tensor add_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  ...
}

// Register `add_Tensor` so that it can be found
TORCH_LIBRARY_IMPL(aten, Autograd, m) {
  ...
  m.impl("add.Tensor",TORCH_FN(VariableType::add_Tensor));
  ...
}

The autograd kernel ends up calling back into the C++ API (by calling at::redispatch::add), which then calls back into the dispatcher and calculates the next kernel to dispatch to.
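
To make the redispatch step concrete, here is a minimal sketch (not the generated VariableType code; the autograd bookkeeping is omitted, and the include paths are assumptions) of a kernel that masks the current DispatchKeySet with c10::after_autograd_keyset so that the next dispatch skips Autograd and lands on the backend kernel:

// Sketch only: a kernel in this position receives the current DispatchKeySet
// as its first argument, does its autograd bookkeeping (omitted here), and
// then redispatches with the Autograd keys masked out.
#include <ATen/ATen.h>
#include <ATen/RedispatchFunctions.h>

at::Tensor add_then_redispatch(  // hypothetical name, for illustration
    c10::DispatchKeySet ks,
    const at::Tensor& self,
    const at::Tensor& other,
    const at::Scalar& alpha) {
  return at::redispatch::add(ks & c10::after_autograd_keyset, self, other, alpha);
}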

(4) CPU kernel

#0: at::native::structured_ufunc_add_CPU::impl
#1: at::(anonymous namespace)::wrapper_CPU_add_Tensor

After a few more function hops through the dispatcher, we eventually dispatch to the CPU add kernel, which has to actually carry out the computation. The code for the cpu kernel (and the code that registers the kernel to the dispatcher) looks like this:

In build/aten/src/ATen/RegisterCPU.cpp:

at::Tensor wrapper_CPU_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  structured_ufunc_add_CPU_functional op;
  op.meta(self, other, alpha);
  op.impl(self, other, alpha, op.outputs_[0]);
  return std::move(op.outputs_[0]);
}

TORCH_LIBRARY_IMPL(aten, CPU, m) {
  m.impl("add.Tensor", TORCH_FN(wrapper_CPU_add_Tensor));
}

This code looks a little funky; it calls into meta() and impl() functions that are defined elsewhere. This is because add is implemented as a structured kernel - a new way of implementing operators in PyTorch.

That code is scaffolding around the actual add kernel: the call to op.impl() corresponds directly to the add kernel defined in build/aten/src/ATen/UfuncCPU_add.cpp:

TORCH_IMPL_FUNC(ufunc_add_CPU)(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out) {
  add_stub(device_type(), *this, alpha);
}

(Note: there’s a bit more indirection (through add_stub) before reaching the main part of the add kernel, which lives in build/aten/src/ATen/UfuncCPUKernel_add.cpp:)

void add_kernel(TensorIteratorBase& iter, const at::Scalar & alpha) {
  AT_DISPATCH_SWITCH(iter.common_dtype(), "add_stub",
    ...
    AT_DISPATCH_CASE(at::ScalarType::Long,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )
    ...
  )
}

The add_kernel above is also code-generated; it lives in build/aten/src/ATen/UfuncCPUKernel_add.cpp.

So, the code above calls into the hand-written ufunc::add inner loop, the result lands in a freshly allocated output tensor, and we’re done!

Takeaway

The main takeaways from the exercise above are:

  • A lot of stuff happens when you call an operator
  • ...most of which is code-generated! A lot of this logic is really similar across PyTorch’s ~2000 operators, and ripe for abstracting over (through something like code generation).

Sometimes when you’re working on / debugging a feature, it can be useful to know which bits of logic are codegen’d, and where that logic lives.

Structured Kernels

Another big part of the codegen is “structured kernels” - a new way of writing kernels in PyTorch, which uses some clever factoring + a bunch of codegen to reduce the amount of boilerplate required when writing kernels. torch.add is implemented as a “structured kernel”, so we’re going to walk through the bits of it related to structured kernels.

The process of implementing an operator as a structured kernel involves writing two functions:

  • A “meta” function, which asserts that the inputs have the correct shape/dtype and figures out what size the output tensor should be.
  • An “impl” function, which does the actual computation. There will be a separate impl() function for every backend (CPU, CUDA, XLA, etc.).

The codegen is responsible for taking these two functions and plugging them together in the right way to create all 3 variants of the operator for you (a short usage sketch follows the list):

  • at::add() (functional version)
  • at::add_() (inplace version)
  • at::add_out() (out= version)
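
Here is a short usage sketch of those three variants (assuming the out-first at::add_out overload from the generated Functions.h):

// Hedged usage sketch of the three generated variants of add.
#include <ATen/ATen.h>

void add_variants() {  // hypothetical name, for illustration
  at::Tensor a = at::ones({2});
  at::Tensor b = at::ones({2});
  at::Tensor out = at::empty({2});

  at::Tensor c = at::add(a, b);  // functional: allocates a fresh output
  a.add_(b);                     // inplace: writes the result into `a`
  at::add_out(out, a, b);        // out=: writes the result into `out`
}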

Helpful reading: this presentation on structured kernels, including a diagram on the class hierarchy (which will be useful in the exercise further down). https://drive.google.com/file/d/16qPvpCF4Jbh7ss2lCQMk5hmcyzJvUyQj/view?usp=sharing

See also: the structured kernels RFC contains a more detailed overview of what they are and what the codegen creates: https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md

Structured Kernel codegen output example: torch.add

The CPU kernel for the torch.add operator lives in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/BinaryOps.cpp#L16, and has two components:

Meta function:

// expands to structured_add_Tensor::meta() { ... }
TORCH_META_FUNC2(add, Tensor) (
  const Tensor& self, const Tensor& other, const Scalar& alpha
) {
  build_borrowing_binary_op(maybe_get_output(), self, other);
  native::alpha_check(dtype(), alpha);
}

Impl function:

// expands to structured_add_out::impl() { ... }
TORCH_IMPL_FUNC(add_out) (
  const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result
) {
  add_stub(device_type(), *this, alpha);
  TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype());
}

So, the code above implements the two functions structured_add_Tensor::meta() and structured_add_out::impl(), but where are they declared? The codegen creates declarations for them.

In NativeMetaFunctions.h:

// namespace at::meta
struct TORCH_API structured_add_Tensor : public TensorIteratorBase {
    void meta(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
};

In NativeFunctions.h:

// namespace at::native
struct TORCH_API structured_add_out : public at::meta::structured_add_Tensor {
    void impl(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out);
};

You can see that the codegen generated declarations for the two functions, and we hand-implemented them ourselves in BinaryOps.cpp. But how does the codegen use them?

The logic that stitches them together lives in the code-generated file RegisterCPU.cpp, and looks like this:

// functional version
at::Tensor wrapper_CPU_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  structured_ufunc_add_CPU_functional op;
  op.meta(self, other, alpha);
  op.impl(self, other, alpha, op.outputs_[0]);
  return std::move(op.outputs_[0]);
}

// inplace version
at::Tensor & wrapper_CPU_add__Tensor(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  structured_ufunc_add_CPU_inplace op(self);
  op.meta(self, other, alpha);
  op.impl(self, other, alpha, op.outputs_[0]);
  if (op.proxy_outputs_[0].has_value()) op.outputs_[0].get().copy_(*op.proxy_outputs_[0]);
  return self;
}

// out= version
at::Tensor & wrapper_CPU_add_out_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
  structured_ufunc_add_CPU_out op(out);
  op.meta(self, other, alpha);
  op.impl(self, other, alpha, op.maybe_get_output(0));
  if (op.proxy_outputs_[0].has_value()) op.outputs_[0].get().copy_(*op.proxy_outputs_[0]);
  return out;
}

// registering the 3 kernels above to the dispatcher, under the CPU Dispatch Key.
TORCH_LIBRARY_IMPL(aten, CPU, m) {
  ...
  m.impl("add.Tensor", TORCH_FN(wrapper_CPU_add_Tensor));
  m.impl("add.out", TORCH_FN(wrapper_CPU_add_out_out));
  m.impl("add_.Tensor", TORCH_FN(wrapper_CPU_add__Tensor));
}

This is the "final" output - the 3 operators that we needed. The codegen created 3 new kernels, each of which call into our meta() and impl() functions. The only difference between the 3 is that they use different classes, each of which has a different implementation of set_output(). You can also find the definition of all 3 of these classes in RegisterCPU.cpp, but below is the example for structured_ufunc_add_CPU_functional:

struct structured_ufunc_add_CPU_functional final : public at::native::structured_ufunc_add_CPU {
    void set_output_strided(
        int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
        TensorOptions options, DimnameList names
    ) override {
        outputs_[output_idx] = create_out(sizes, strides, options);
        if (!names.empty()) {
          namedinference::propagate_names(outputs_[output_idx], names);
        }
        // super must happen after, so that downstream can use maybe_get_output
        // to retrieve the output
        at::native::structured_ufunc_add_CPU::set_output_raw_strided(output_idx, sizes, strides, options, names);
    }
    void set_output_raw_strided(
        int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
        TensorOptions options, DimnameList names
    ) override {
        outputs_[output_idx] = create_out(sizes, strides, options);
        if (!names.empty()) {
          namedinference::propagate_names(outputs_[output_idx], names);
        }
        // super must happen after, so that downstream can use maybe_get_output
        // to retrieve the output
        at::native::structured_ufunc_add_CPU::set_output_raw_strided(output_idx, sizes, strides, options, names);
    }
    const Tensor& maybe_get_output(int64_t output_idx) override {
      return outputs_[output_idx];
    }
    std::array<Tensor, 1> outputs_;
};

You can see that it has its own definitions of set_output_strided() and set_output_raw_strided() - in this case, it’s implementing the functional at::add kernel, so it needs to allocate a new tensor as the output (it does that using the create_out() helper).

That class corresponds to one of the leaves of the class hierarchy - a picture of the full class hierarchy can be found in the linked presentation (https://drive.google.com/file/d/16qPvpCF4Jbh7ss2lCQMk5hmcyzJvUyQj/view?usp=sharing)
