[RFC] Supporting Eager Mode via torch.compile #115545

@EikanWang

Description

🚀 The feature, motivation and pitch

Motivation

PyTorch currently defines more than 2000 operations, and PyTorch users typically start from eager mode. If a new backend, such as Intel GPU, intends to support PyTorch eager mode, it has to implement all of these operations; otherwise, users may hit unimplemented-operator errors when running PyTorch on the new backend. This presents two substantial challenges: engineering effort and maintenance effort.

  • Engineering Effort - Implementing all ATen operations (over 2000 in number) for a new backend is a considerable task, requiring significant development resources.
  • Maintenance Effort - PyTorch operations may change their interfaces, for example by adding or removing parameters. Backend maintainers then need to adapt their backend-specific implementations to every such change.

Given these challenges, we propose an alternative technical pathway for eager mode support through torch.compile, offering two main advantages over the traditional implementation approach:

  • Reduced Implementation Requirement: A backend only needs to implement the operations that torch.compile does not yet support, or for which torch.compile explicitly falls back to ATen.
  • Device-Agnostic Interface: torch.compile allows generating performant backend code without the need for backend-specific implementation of each ATen operation.

Approach

We propose compiling single aten operations using torch.compile and registering them in PyTorch on the Python side. This approach simplifies backend implementation and maintenance. Here's an illustrative example:

# Example code demonstrating the proposed method
import torch

x = torch.empty(1, device="xpu").fill_(1)
y = torch.empty(1, device="xpu").fill_(2)

def wrapper_fn_sub(*args, **kwargs):
    with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False):
        opt_fn_sub = torch.compile(torch.ops.aten.sub)
        return opt_fn_sub(*args, **kwargs)

custom_op_lib_xpu_impl = torch.library.Library("aten", "IMPL")
custom_op_lib_xpu_impl.impl("sub.Tensor", wrapper_fn_sub, "XPU")
res = x - y

ref = torch.empty(1, device="xpu").fill_(-1)
assert all(res == ref)

Detailed Design

The torch.compile invocation needs to be wrapped in a general Python function and registered with the torch dispatcher. This registration mechanism alone is enough for functional correctness.

On top of that, we need to recover performance by mitigating the Python and torch.compile overhead, for which we propose a cache mechanism.

Hence, the detailed design focuses on registration and caching.

Registration

To trigger torch.compile to produce a C++/Triton kernel, we always need to register a Python kernel for each ATen operation, just like the above example code.

However, we do not need to keep invoking the Python kernel once its corresponding torch.compile-produced kernel is available.

In addition, the context switch between C++ and Python introduces extra overhead, such as acquiring the Python GIL, and a Python implementation is generally slower than its equivalent C++ implementation.

Therefore, we prefer to avoid running code on the Python side whenever possible.

To achieve this, we wrap the Python kernel in a C++ functor, just as torch already does for other Python kernels (PythonKernelHolder).

  • When a caller invokes torch.library.Library.impl to register a Python kernel, the Python kernel is wrapped in a PythonKernelHolder, which is then registered as a boxed functor via CppFunction::makeFromBoxedFunctor:
[](const py::object& self,
   const char* name,
   // TODO: empty string no longer works
   c10::DispatchKey dispatch,
   py::object func) {
  HANDLE_TH_ERRORS
  auto& lib = self.cast<torch::Library&>();
  if (func.is(py::module::import("torch.library")
                  .attr("fallthrough_kernel"))) {
    lib.impl(
        name,
        torch::dispatch(dispatch, CppFunction::makeFallthrough()),
        register_or_verify());
  } else {
    lib.impl(
        name,
        torch::dispatch(
            dispatch,
            CppFunction::makeFromBoxedFunctor(
                std::make_unique<PythonKernelHolder>(func, dispatch))),
        register_or_verify());
    python_registrations_[lib._resolve(name)].insert_or_assign(
        dispatch,
        std::make_shared<c10::SafePyObject>(
            func.release().ptr(), getPyInterpreter()));
  }
  END_HANDLE_TH_ERRORS_PYBIND
}

We can reuse this mechanism and customize it slightly to accelerate operation execution by introducing a cache. The following section elaborates on the detailed design of the cache mechanism.

Suppose the class introduced for this purpose is named AOTICompatiblePythonKernelHolder; its major responsibilities are as follows.

  • Lookup a global cache to check whether its torch.compile-ed kernel is available
  • Cache hit
    • Invoke the cached kernel directly
  • Cache miss
    • Invoke the registered Python kernel to trigger torch.compile to produce the kernel
    • Run the kernel produced by torch.compile and cache it

Its pseudo-code could be as follows.

class AOTICompatiblePythonKernelHolder : public c10::OperatorKernel {
 public:
  // TODO: We can add more information to the ctor to accelerate the cache lookup
  explicit AOTICompatiblePythonKernelHolder(c10::SafePyObject func)
      : func_(std::move(func)) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // aoti_kernel_cache is a global cache storing all the kernels produced by
    // torch.compile for eager mode.
    auto cache_key = aoti_kernel_cache.getCacheKey(op, stack);
    auto kernel_handle = aoti_kernel_cache.get(cache_key);

    if (kernel_handle) {
      // Cache hit: run the cached kernel directly; no Python is involved.
      (*kernel_handle)(stack);
    } else {
      // Cache miss: invoke the registered Python kernel to trigger torch.compile.
      auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
      py::gil_scoped_acquire g;
      auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
      auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
          func_.ptr(getPyInterpreter()),
          args_kwargs.first.ptr(),
          args_kwargs.second.ptr()));
      if (!obj) {
        throw python_error();
      }
      pushPyOutToStack(op, stack, obj, "AOTICompatiblePythonKernelHolder");
      // Load the kernel produced by torch.compile from disk into memory for the next run.
      aoti_kernel_cache.load(op, cache_key);
    }
  }

 private:
  // The registered Python kernel that wraps the torch.compile invocation.
  c10::SafePyObject func_;
};

In addition, all the kernels produced by torch.compile for eager mode should be wrapped in C++ (CppWrapper) and loaded via AOTIModelContainerRunner to mitigate the Python overhead.
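For reference, a minimal sketch of producing such a C++-wrapped kernel ahead of time is shown below. It assumes the torch._export.aot_compile entry point; the exact per-operation compilation and packaging path is part of this RFC's open design, not a decided API.

# Hedged sketch: AOT-compile a single ATen operation into a C++-wrapped shared
# library, which a C++ runner such as AOTIModelContainerRunner could then load.
# torch._export.aot_compile is an assumption here, not the final mechanism.
import torch

def sub_fn(a, b):
    return torch.ops.aten.sub.Tensor(a, b)

example_args = (torch.randn(8), torch.randn(8))
so_path = torch._export.aot_compile(sub_fn, example_args)
print("AOT-compiled kernel stored at:", so_path)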

Cache

The overhead of torch.compile is non-negligible (measured on an Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz; PyTorch commit e8a9d08); the breakdown is as follows, with a timing sketch after the list:

  • Initialization Phase:
    • There is a significant overhead of over 1 second during the initial run of torch.compile. This overhead is a one-time occurrence, triggered only during the first compilation.
  • Performance of Specific Aten Operations Post-Initialization:
    • On First Execution: After the initial torch.compile setup, the overhead for executing a specific Aten operation for the first time is approximately 115 milliseconds.
    • On Subsequent Executions Without Re-compilation: When the same Aten operation is executed again, without re-compiling, the overhead becomes negligible, aligning with the performance of the traditional eager mode.
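The numbers above can be reproduced in spirit with a small timing sketch like the following (CPU tensors, default Inductor backend, analogous to the registration example above; absolute values will differ by machine):

# Timing sketch: the first call pays the compilation cost, subsequent calls
# reuse the compiled kernel. Figures depend on the machine and PyTorch build.
import time
import torch

compiled_mul = torch.compile(torch.ops.aten.mul)
a = torch.randn(16, 16)
b = torch.randn(16, 16)

t0 = time.perf_counter()
compiled_mul(a, b)  # first call: triggers torch.compile
t1 = time.perf_counter()
compiled_mul(a, b)  # second call: no recompilation
t2 = time.perf_counter()

print(f"first call (with compilation): {t1 - t0:.3f} s")
print(f"subsequent call:               {(t2 - t1) * 1e6:.1f} us")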

Therefore, we introduce a cache mechanism to mitigate the overhead.

Note that the cache should be persistent: it avoids producing kernel code repeatedly for a particular ATen operation when the input parameters carry the same meta information as a previous compilation.

Taking a Mul operation as an example, the dispatch flow of the torch.compile-based kernel with the cache is illustrated below.
[Figure: dispatch flow for a Mul operation with the torch.compile kernel cache]

Cache Key

The parameters of an ATen operation are packed into a torch::jit::Stack. Therefore, we can unpack the parameters one by one and extract information from each c10::IValue to constitute the cache key for the operator.

In addition, each torch.compile-produced kernel is dedicated to a particular ATen operation, so the kernel knows the exact semantics of its implementation and its schema is informative. Based on this, the cache key can be constituted from the following factors:

  • Tensor number of dims
  • Tensor size per dim
  • Tensor stride per dim
  • Tensor data type
  • Tensor device
  • Optional
  • Scalar
  • Other constant values like float, double, bool, etc.

Based on these factors, there are two options to establish the cache key:

  • Option 1: Hash the factors and take the hash value as the cache key
  • Option 2: Leave all the meta information as it is and record it in a data structure

Option 1 is the common practice, but it may introduce additional overhead as the hash algorithm and hash key lookup might be time-consuming.

Option 2 introduces additional design and implementation complexity compared to Option 1. However, its overhead should be lower, as it only needs to compare the data fields of a cache entry.

We will evaluate the overhead and then determine which option is the best one.
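For illustration, a minimal Python sketch of both options is given below; the real implementation would live in C++ next to the dispatcher, and the helper names are made up for this example.

# Sketch of the two cache-key options. `args` are the inputs of one ATen call.
import torch

def extract_meta(arg):
    # Collect the meta information listed above for a single argument.
    if isinstance(arg, torch.Tensor):
        return ("Tensor", arg.dim(), tuple(arg.size()), tuple(arg.stride()),
                arg.dtype, arg.device)
    # Scalars, bools, None (optionals) and other constants participate by value.
    return ("Const", type(arg).__name__, arg)

def cache_key_option1(args):
    # Option 1: hash the combined meta information.
    return hash(tuple(extract_meta(a) for a in args))

def cache_key_option2(args):
    # Option 2: keep the meta information as-is and compare it field by field.
    return tuple(extract_meta(a) for a in args)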

Complex data structures such as tensor lists and integer lists will be supported gradually.

Cache Load

Regarding cache loading, there are also two options:

  • Option 1: Load the persistent cache eagerly as part of torch initialization
  • Option 2: Load the persistent cache lazily when the first torch.compile-based ATen operation is invoked

Loading the persistent cache introduces extra overhead and takes time. Therefore, Option 1 may slightly impact the user experience, since torch initialization already takes a while to finish. A further downside is that the loading may turn out to be useless if no torch.compile-based ATen operation is ever invoked.

Option 2 is a trade-off: it ensures the cache is loaded only when it is actually useful for the current process. However, initializing the cache while a model or Python script is running may impact runtime performance.

We prefer Option 1 from the runtime-performance perspective.
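The difference between the two strategies boils down to where the one-time deserialization cost is paid, as the Python-flavored sketch below illustrates; load_persistent_cache is a hypothetical stub standing in for the real deserialization logic.

# Sketch of eager vs. lazy cache loading. The real cache is a C++ structure;
# this Python version only illustrates the control flow.

def load_persistent_cache():
    # Hypothetical stub: the real implementation deserializes kernels from disk.
    return {}

_kernel_cache = None

def init_torch():
    # Option 1: pay the loading cost once, as part of torch initialization.
    global _kernel_cache
    _kernel_cache = load_persistent_cache()

def get_kernel_cache():
    # Option 2: pay the loading cost on the first torch.compile-based ATen call.
    global _kernel_cache
    if _kernel_cache is None:
        _kernel_cache = load_persistent_cache()
    return _kernel_cache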

Cache Lookup

The cache lookup mechanism depends on the cache key design, and the implementation should be straightforward regardless of which option we take:

  • Compare hash values if the cache key comes from hashing the combined meta information of all the input parameters
  • Compare the meta information field by field if the cache key is a data structure that records the original meta information of all the input parameters

There are two scenarios to handle: cache hit and cache miss.

  • Cache Hit
    • Convert the input parameters packed in the torch::jit::Stack into the inputs of AOTIModelContainerRunner, run the kernel, and return the result just like the current ATen C++ implementation.
  • Cache Miss
    • Invoke the registered Python kernel to trigger torch.compile to produce the kernel wrapped by CppWrapper
    • Build the cache key for the produced kernel and then add a cache entry
    • Launch the kernel by AOTIModelContainerRunner

From the C++ and Python perspective:

  • On cache hit: C++ (dispatcher) -> C++ (cache check) -> C++ (torch.compile generated operator)
  • On cache miss: C++ (dispatcher) -> C++ (cache check) -> Python (Trigger torch.compile and add an entry to the cache) -> C++ (torch.compile generated operator)

Cache Store

The cache mechanism generates a unique key for each kernel produced by torch.compile. Regardless of whether the cache key is a hash value or a data structure, it will be serialized to disk to accelerate the next startup, just as Inductor does for Triton kernel tuning.

Beyond that, we need to define how the kernels produced by torch.compile are organized on disk.

We will create a dedicated directory for each ATen operation. The name combines the qualified name and the overload name, so the directory could be something like {qualified_name}_{overload_name}. Take aten.add as an example; the directory name could be aten_add_int. The motivation is that the operation name then does not need to be part of the cache key, which avoids string comparisons.

Currently, the default Inductor kernel cache is placed at /tmp/torchinductor_{account_name}, which may be swept on every reboot. To avoid this penalty, we prefer to store the cache in a non-temporary folder.
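For illustration, a minimal sketch of how the directory name could be derived from an operator handle is given below; the cache root shown is only an example of a non-temporary location, not a decided default.

# Sketch of the proposed on-disk layout: one directory per ATen overload,
# named {qualified_name}_{overload_name}. Paths and helper names are
# illustrative only.
import os
import torch

def kernel_dir(op: torch._ops.OpOverload, cache_root: str) -> str:
    qualified_name = op._schema.name.replace("::", "_")    # e.g. "aten::add" -> "aten_add"
    overload_name = op._schema.overload_name or "default"  # e.g. "Tensor"
    return os.path.join(cache_root, f"{qualified_name}_{overload_name}")

# An example of a persistent (non-/tmp) cache root.
cache_root = os.path.expanduser("~/.cache/torch/aoti_eager")
print(kernel_dir(torch.ops.aten.add.Tensor, cache_root))  # .../aten_add_Tensor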

Summary

This document delves into the implementation of PyTorch Eager mode support using torch.compile. Currently, PyTorch has defined over 2000 operations, and for a new backend to support PyTorch Eager mode, it must implement these operations, or users might encounter unimplemented errors. To address this challenge, we propose the method of compiling single ATen operations using torch.compile and registering them. This approach allows for dynamic compilation and optimization of operations, offering more efficient support for different hardware backends.

The document details the registration, cache mechanism, cache key design, and steps that new backend maintainers need to take.

Overall, this proposal aims to simplify the maintenance of PyTorch backends while enhancing efficiency and user experience. Through this approach, it becomes easier to introduce new hardware backends to PyTorch.

Current Status and Challenges

We are still exploring this approach and have enabled the above example by providing an alternative registration API as a proof of concept. We will support more operations, plus both inference and training, to check whether there are further feature gaps.

Besides the feature implementation, there are some challenges.

  • The compilation overhead will significantly impact the user experience.
  • The feature requires robust dynamic-shape support. Otherwise, a single ATen operation has to be recompiled whenever the shapes of its input tensors change.

We may address these challenges by improving torch.compile's persistent disk cache and its dynamic-shape support.

Alternatives

No response

Additional context

Currently, registering a Python kernel for a particular ATen operation triggers the hermetic Python object assumption during ATen operation dispatching:

EnableHermeticPyObject g2;

However, this assumption is broken if the Python kernel invokes torch.compile to implement its logic. Taking the code snippet above as an example, torch.compile needs to build FakeTensors, and FakeTensor defines __torch_dispatch__. As a result, check_has_torch_dispatch returns True, while the hermetic mode requires it to be False to ensure that "operations in HermeticPyObject have equivalent C++ implementations."

static bool check_has_torch_dispatch(PyObject* obj) {
  PyTypeObject* tp = Py_TYPE(obj);
  if (THPVariable_CheckTypeExact(tp)) {
    return false;
  }
  py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__");
  return (
      attr.ptr() != nullptr &&
      attr.ptr() != torch::disabled_torch_dispatch_impl());
}

Therefore, it cannot pass the following check:

TORCH_INTERNAL_ASSERT(
    !check_has_torch_dispatch(obj),
    "While HermeticPyObject was enabled, we attempted to create a tensor "
    "subclass with __torch_dispatch__. This violates the invariant that "
    "operations in HermeticPyObject have equivalent C++ implementations. "
    "If your operator registered from Python operator registration isn't "
    "doing anything strange, there may be an internal PyTorch bug involving "
    "not appropriately disabling TorchDispatchMode before executing "
    "Python op registration.");

To address the issue, we can provide a dedicated registration API that marks a Python kernel as invoking torch.compile for its implementation.
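One possible shape of such an API is sketched below; impl_compiled is a made-up name used purely for illustration, and the final API (and the corresponding C++-side handling of the hermetic assumption) is an open design point of this RFC.

# Hypothetical registration helper (name and behavior are illustrative only).
# It wraps torch.library registration and records that the kernel is
# torch.compile-based, so the dispatcher could skip EnableHermeticPyObject
# for this kernel.
import torch

_COMPILED_EAGER_KERNELS = set()  # ops whose Python kernel is torch.compile-based

def impl_compiled(lib: torch.library.Library, op_name: str, fn, dispatch_key: str):
    # Record the marker; the C++ side would consult this information.
    _COMPILED_EAGER_KERNELS.add((op_name, dispatch_key))
    lib.impl(op_name, fn, dispatch_key)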

ATen Operation Parameter Type List

  • Tensor
  • bool
  • int64_t
  • TensorList
  • Scalar
  • c10::SymIntArrayRef
  • ::std::optional<Tensor>
  • IntArrayRef
  • double
  • c10::SymInt
  • ::std::optional<ScalarType>
  • ::std::optional<double>
  • ::std::optional<bool>
  • ::std::optional<Layout>
  • ::std::optional<Device>
  • ::std::optional<int64_t>
  • Dimname
  • ::std::optional<Generator>
  • c10::string_view
  • ::std::optional<c10::string_view>
  • OptionalIntArrayRef
  • ::std::optional<Scalar>
  • OptionalSymIntArrayRef
  • ::std::optional<MemoryFormat>
  • ::std::optional<c10::SymInt>
  • ScalarType
  • ArrayRef<Scalar>
  • DimnameList
  • ::std::optional<ArrayRef<double>>
  • ::std::array<bool,3>
  • ::std::optional<DimnameList>
  • c10::List<::std::optional<Tensor>>
  • ::std::array<bool,2>
  • Storage
  • ::std::array<bool,4>
  • Device
  • DeviceIndex
  • ITensorListRef
  • Stream
  • Layout
  • MemoryFormat

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519
