🚀 The feature, motivation and pitch
Motivation
By now, PyTorch has defined more than 2,000 operations, and PyTorch users typically start from eager mode. If a new backend intends to support PyTorch eager mode, that backend, such as Intel GPU, has to implement all of these operations; otherwise, users may encounter unimplemented-operator errors when they run PyTorch on the new backend. This scenario presents two substantial challenges: engineering effort and maintenance effort.
- Engineering Effort: Implementing all ATen operations (over 2,000) for a new backend is a considerable task, requiring significant development resources.
- Maintenance Effort: PyTorch operations may change their interfaces, such as adding or removing parameters. The backend maintainers then need to adapt their backend-specific implementations to these changes.
Given these challenges, we propose an alternate technical pathway for eager mode support through torch.compile, offering two main advantages over the traditional implementation approach:
- Reduced Implementation Requirement: Backends only need to implement fallback operations that are not yet supported by torch.compile or where torch.compile explicitly falls back to ATen.
- Device-Agnostic Interface: torch.compile can generate performant backend code without a backend-specific implementation of each ATen operation.
Approach
We propose compiling single ATen operations using torch.compile and registering them in PyTorch on the Python side. This approach simplifies backend implementation and maintenance. Here's an illustrative example:
# Example code demonstrating the proposed method
import torch

x = torch.empty(1, device="xpu").fill_(1)
y = torch.empty(1, device="xpu").fill_(2)

def wrapper_fn_sub(*args, **kwargs):
    # Re-enable the Python dispatch key so torch.compile can intercept the op.
    with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False):
        opt_fn_sub = torch.compile(torch.ops.aten.sub)
        res = opt_fn_sub(*args, **kwargs)
        return res

# Register the torch.compile-backed kernel as the XPU implementation of aten::sub.Tensor.
custom_op_lib_xpu_impl = torch.library.Library("aten", "IMPL")
custom_op_lib_xpu_impl.impl("sub.Tensor", wrapper_fn_sub, "XPU")

res = x - y
ref = torch.empty(1, device="xpu").fill_(-1)
assert torch.equal(res, ref)
Detailed Design
The torch.compile invocation needs to be wrapped as a general Python function and then registered to the torch dispatcher. This mechanism ensures robust functionality. Meanwhile, we need to improve performance by mitigating the Python and torch.compile overhead, and we propose a cache mechanism for this purpose.
So, the detailed design focuses on registration and caching.
Registration
To trigger torch.compile to produce a C++/Triton kernel, we always need to register a Python kernel for each ATen operation, just like the example code above. However, we do not need to invoke the Python kernel every time once its corresponding torch.compile-produced kernel is available.
In addition, the context switch between Python and C++ introduces extra overhead, such as the Python GIL, and a Python implementation generally performs worse than its equivalent C++ implementation. Therefore, we prefer to avoid running code in the Python world.
To achieve this goal, we wrap the Python kernel as a C++ function/class, just as torch already does for other Python kernels (PythonKernelHolder).
- When callers invoke torch.library.Library.impl to register a Python kernel, the Python kernel is wrapped by PythonKernelHolder, which is then cast to a BoxedFunctor.
[](const py::object& self,
const char* name,
// TODO: empty string no longer works
c10::DispatchKey dispatch,
py::object func) {
HANDLE_TH_ERRORS
auto& lib = self.cast<torch::Library&>();
if (func.is(py::module::import("torch.library")
.attr("fallthrough_kernel"))) {
lib.impl(
name,
torch::dispatch(dispatch, CppFunction::makeFallthrough()),
register_or_verify());
} else {
lib.impl(
name,
torch::dispatch(
dispatch,
CppFunction::makeFromBoxedFunctor(std::make_unique<PythonKernelHolder>(func, dispatch))),
register_or_verify());
python_registrations_[lib._resolve(name)].insert_or_assign(
dispatch,
std::make_shared<c10::SafePyObject>(
func.release().ptr(), getPyInterpreter()));
}
END_HANDLE_TH_ERRORS_PYBIND
}
We can reuse this mechanism and customize it slightly to accelerate the operation by introducing a cache. The following section elaborates on the detailed design of the cache mechanism.
Suppose a class named AOTICompatiblePythonKernelHolder serves this purpose; its major features are as follows.
- Look up a global cache to check whether its torch.compile-ed kernel is available
- Cache hit
  - Invoke the cached kernel directly
- Cache miss
  - Invoke the registered Python kernel to trigger torch.compile to produce the kernel
  - Run the kernel produced by torch.compile and cache it
Its pseudo-code could be as follows.
class AOTICompatiblePythonKernelHolder : public c10::OperatorKernel {
 public:
  // TODO: We can add more information to the ctor to accelerate the cache lookup
  AOTICompatiblePythonKernelHolder() = default;

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // aoti_kernel_cache is a global cache that stores all the kernels produced
    // by torch.compile for eager mode.
    auto cache_key = aoti_kernel_cache.getCacheKey(op, stack);
    auto kernel_handle = aoti_kernel_cache.get(cache_key);
    if (kernel_handle) {
      // Cache hit: invoke the cached kernel directly.
      py::gil_scoped_acquire g;
      auto res = (*kernel_handle)(stack);
      pushPyOutToStack(
          op, stack, py::reinterpret_steal<py::object>(res), "PythonKernelHolder");
    } else {
      // Cache miss: fall back to the registered Python kernel.
      py::gil_scoped_acquire g;
      // Pseudo-code helper: look up the Python kernel registered for this op
      // (e.g. from python_registrations_, keyed by op name and dispatch key).
      auto func = getRegisteredPythonKernel(op, keyset);
      auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
      auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
      // Invoke the Python kernel to trigger torch.compile and get the result.
      auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
          func.ptr(), args_kwargs.first.ptr(), args_kwargs.second.ptr()));
      if (!obj) {
        throw python_error();
      }
      pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
      // Load the kernel produced by torch.compile from disk so the next run hits the cache.
      aoti_kernel_cache.load(op, cache_key);
    }
  }
};
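For reference, the aoti_kernel_cache used above could expose an interface along the following lines. This is a minimal sketch: AOTIEagerKernelCache, CachedKernel, and KernelHandle are illustrative names for this proposal, not existing PyTorch types.
// Minimal sketch of the global cache assumed by the pseudo-code above.
#include <memory>
#include <string>
#include <unordered_map>

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>

// Wraps a CppWrapper/AOTInductor-compiled kernel; invoked with the boxed
// arguments from the dispatcher stack (see "Cache Lookup" below).
struct CachedKernel;
using KernelHandle = std::shared_ptr<CachedKernel>;

class AOTIEagerKernelCache {
 public:
  // Build a key from the operator schema and the meta information of the
  // inputs on the JIT stack (dims, sizes, strides, dtype, device, ...).
  std::string getCacheKey(const c10::OperatorHandle& op, const torch::jit::Stack* stack);

  // Return the cached kernel if the key is known, nullptr otherwise.
  KernelHandle get(const std::string& cache_key) {
    auto it = kernels_.find(cache_key);
    return it == kernels_.end() ? nullptr : it->second;
  }

  // After torch.compile has produced a CppWrapper kernel on disk, load the
  // shared library into memory and associate it with the key.
  void load(const c10::OperatorHandle& op, const std::string& cache_key);

 private:
  std::unordered_map<std::string, KernelHandle> kernels_;
};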
In addition, all the kernels produced by torch.compile for eager mode should be wrapped in C++ (CppWrapper) and loaded by AOTIModelContainerRunner to mitigate the Python overhead.
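As a rough illustration, loading and running such a CppWrapper artifact could look like the following. The .so path and the use of the CPU runner are assumptions for illustration; how eager-mode kernels reach the runner is part of this proposal rather than current PyTorch behavior.
// Sketch: load a CppWrapper/AOTInductor-compiled kernel and run it on CPU.
#include <iostream>
#include <vector>

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#include <torch/torch.h>

int main() {
  // Hypothetical cache location for the compiled aten::sub.Tensor kernel.
  torch::inductor::AOTIModelContainerRunnerCpu runner(
      "/path/to/aoti_eager_cache/aten_sub_Tensor/kernel.so");

  std::vector<torch::Tensor> inputs = {torch::full({1}, 1.0), torch::full({1}, 2.0)};

  // The compiled kernel behaves like the original aten op: sub(a, b) == -1.
  std::vector<torch::Tensor> outputs = runner.run(inputs);
  std::cout << outputs[0] << std::endl;
  return 0;
}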
Cache
The overhead of torch.compile is non-negligible (measured on an Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz; PyTorch commit: e8a9d08):
- Initialization Phase: There is a significant overhead of over 1 second during the initial run of torch.compile. This overhead is a one-time cost, incurred only during the first compilation.
- Performance of Specific ATen Operations Post-Initialization:
  - On First Execution: After the initial torch.compile setup, the overhead of executing a specific ATen operation for the first time is approximately 115 milliseconds.
  - On Subsequent Executions Without Re-compilation: When the same ATen operation is executed again without re-compiling, the overhead becomes negligible, matching the performance of the traditional eager mode.
Therefore, we introduce a cache mechanism to mitigate the overhead.
By the way, the cache should be a persistent cache. It avoids producing kernel code multiple times for a particular ATen operation when the input parameters have the same meta information as when the operation was first compiled.
Taking a Mul operation as an example, from the dispatch perspective the torch.compile-based kernel with the cache works as described in the following subsections.
Cache Key
The parameters of an ATen operation are packed into a torch::jit::Stack. Therefore, we can unpack the parameters one by one and extract information from each c10::IValue to constitute the cache key for the operator.
In addition, each torch.compile-ed kernel is dedicated to a particular ATen operation, which means the kernel knows the exact semantics of its implementation, and its schema is informative. Based on this, the cache key could be constituted from the following factors:
- Tensor number of dims
- Tensor size per dim
- Tensor stride per dim
- Tensor data type
- Tensor device
- Optional
- Scalar
- Other constant values like float, double, bool, etc.
- …
Based on these factors, there are two options to establish the cache key:
- Option 1: Hash the factors and take the hash value as the cache key
- Option 2: Leave all the meta information as it is and record it in a data structure
Option 1 is the common practice, but it may introduce additional overhead, as the hash algorithm and hash-key lookup might be time-consuming.
Option 2 introduces additional complexity in design and implementation compared to Option 1. However, its overhead should be lower than Option 1's, as it only needs to compare the data fields of a cache entry.
We will evaluate the overhead of both and then determine which option is best.
As for complex data structures like tensor lists and integer lists, we will support them gradually.
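As a minimal sketch of Option 1, the hash-based key could be computed roughly as follows. The helper names and the hash-combining scheme are assumptions for illustration; only c10::hash_combine is an existing utility.
// Sketch of Option 1: hash the meta information of every input IValue.
#include <functional>

#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <c10/util/hash.h>

// Fold the tensor meta information listed above (dims, sizes, strides, dtype,
// device) into the running hash value.
static size_t hashTensorMeta(const at::Tensor& t, size_t seed) {
  seed = c10::hash_combine(seed, static_cast<size_t>(t.dim()));
  for (auto s : t.sizes()) {
    seed = c10::hash_combine(seed, static_cast<size_t>(s));
  }
  for (auto s : t.strides()) {
    seed = c10::hash_combine(seed, static_cast<size_t>(s));
  }
  seed = c10::hash_combine(seed, static_cast<size_t>(t.scalar_type()));
  seed = c10::hash_combine(seed, static_cast<size_t>(t.device().type()));
  return seed;
}

// Walk the boxed arguments of an aten op and combine them into one cache key.
static size_t computeCacheKey(const torch::jit::Stack& stack) {
  size_t seed = 0;
  for (const c10::IValue& iv : stack) {
    if (iv.isTensor()) {
      seed = hashTensorMeta(iv.toTensor(), seed);
    } else if (iv.isInt()) {
      seed = c10::hash_combine(seed, std::hash<int64_t>{}(iv.toInt()));
    } else if (iv.isDouble()) {
      seed = c10::hash_combine(seed, std::hash<double>{}(iv.toDouble()));
    } else if (iv.isBool()) {
      seed = c10::hash_combine(seed, std::hash<bool>{}(iv.toBool()));
    } else {
      // Scalars, optionals, tensor lists, etc. would be handled incrementally.
    }
  }
  return seed;
}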
Cache Load
Regarding the cache loading, there are also two different options:
- Option 1: Load the persistent cache eagerly as part of torch initialization
- Option 2: Load the persistent cache lazily, when the first torch.compile-based ATen operation is invoked
Loading the persistent cache may introduce additional overhead and take a long time; therefore, Option 1 may slightly impact the user experience, since torch already takes a long time to finish its initialization. A further side effect is that the loading may turn out to be useless if no torch.compile-based ATen operation is ever invoked.
Compared to Option 1, Option 2 is a trade-off: it ensures the cache is only loaded when it is actually useful for the current process. However, it may impact runtime performance, since the cache is initialized while a model/Python script is running.
We prefer Option 1 from the performance perspective.
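For illustration, Option 2 could be implemented with a one-time lazy initialization such as the following; Option 1 would simply call the same loading routine during torch initialization instead. All names and the cache directory are assumptions of this sketch.
// Sketch of Option 2: load the persistent cache exactly once, the first time a
// torch.compile-based aten kernel is dispatched.
#include <mutex>
#include <string>

// Hypothetical: scan the cache directory, deserialize the cache keys, and load
// the CppWrapper kernels into the in-memory cache.
static void loadPersistentCacheFromDisk(const std::string& cache_dir) {
  (void)cache_dir;  // placeholder body in this sketch
}

void ensureEagerKernelCacheLoaded() {
  static std::once_flag loaded;
  std::call_once(loaded, [] {
    // A non-temp location, so the cache survives reboots (see "Cache Store").
    loadPersistentCacheFromDisk("/var/cache/torch/aoti_eager");
  });
}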
Cache Lookup
The cache lookup mechanism depends on the cache key design, and the implementation should be straightforward regardless of which option we take:
- Compare the hash values if the cache key comes from hashing the combined meta information of all the input parameters
- Compare the meta information field by field if the cache key is a data structure that records the original meta information of all the input parameters
There are two scenarios we need to handle: cache hit and cache miss.
- Cache Hit
  - Convert the input parameters packed in the torch::jit::Stack into the input parameters of AOTIModelContainerRunner and then return the result, just like the current ATen C++ implementation.
- Cache Miss
  - Invoke the registered Python kernel to trigger torch.compile to produce the kernel wrapped by CppWrapper
  - Build the cache key for the produced kernel and then add a cache entry
  - Launch the kernel via AOTIModelContainerRunner
From the C++ and Python perspective:
- On cache hit: C++ (dispatcher) -> C++ (cache check) -> C++ (torch.compile-generated kernel)
- On cache miss: C++ (dispatcher) -> C++ (cache check) -> Python (trigger torch.compile and add an entry to the cache) -> C++ (torch.compile-generated kernel)
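A rough sketch of the cache-hit path is shown below. The function name is hypothetical and the handling of non-tensor arguments is simplified; it assumes the cached kernel is exposed through the existing AOTIModelContainerRunnerCpu API.
// Sketch of the cache-hit path: convert the boxed torch::jit::Stack arguments
// into tensors, run the cached AOTInductor kernel, and push the outputs back.
#include <vector>

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>

void runCachedKernel(
    torch::inductor::AOTIModelContainerRunnerCpu& runner,
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  const auto num_args = op.schema().arguments().size();
  auto ivalues = torch::jit::pop(*stack, num_args);

  // Collect the tensor inputs; scalars and other constants are assumed to be
  // baked into the compiled kernel in this simplified sketch.
  std::vector<at::Tensor> inputs;
  for (const c10::IValue& iv : ivalues) {
    if (iv.isTensor()) {
      inputs.push_back(iv.toTensor());
    }
  }

  // C++ (dispatcher) -> C++ (cache check) -> C++ (torch.compile-generated kernel)
  auto outputs = runner.run(inputs);
  for (auto& out : outputs) {
    torch::jit::push(*stack, std::move(out));
  }
}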
Cache Store
The cache mechanism generates a unique key for each kernel produced by torch.compile. Regardless of whether the cache key is a hash value or a data structure, it will be serialized to disk to accelerate the next bootup, just as Inductor does for Triton kernel tuning.
Beyond that, we need to highlight how the kernels produced by torch.compile are organized on disk.
We will create a dedicated directory for each ATen operation. The name combines the qualified name and the overload name, so the directory could be something like {qualified_name}_{overload_name}. Taking aten.add as an example, the directory name could be aten_add_int. The motivation is that we then do not need to add the operation name to the cache key, which avoids string comparisons.
Currently, the default Inductor kernel cache is placed at /tmp/torchinductor_{account_name}, which may be swept on each boot. To avoid this penalty, we'd prefer to store the cache in a non-temp folder.
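A small sketch of the directory naming derived from the operator schema follows; the base directory is an assumption, chosen only to illustrate a non-temp location.
// Sketch: map an operator to its cache directory, {qualified_name}_{overload_name},
// e.g. "aten::add" with overload "int" -> "aten_add_int".
#include <string>

#include <ATen/core/dispatch/Dispatcher.h>

std::string kernelCacheDir(const c10::OperatorHandle& op) {
  const auto& schema = op.schema();
  std::string dir = schema.name();  // e.g. "aten::add"
  auto pos = dir.find("::");
  if (pos != std::string::npos) {
    dir.replace(pos, 2, "_");  // -> "aten_add"
  }
  if (!schema.overload_name().empty()) {
    dir += "_" + schema.overload_name();  // -> "aten_add_int"
  }
  // Hypothetical persistent base directory, so the cache survives reboots.
  return "/var/cache/torch/aoti_eager/" + dir;
}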
Summary
This document delves into the implementation of PyTorch eager mode support using torch.compile. Currently, PyTorch has defined over 2,000 operations, and for a new backend to support PyTorch eager mode, it must implement these operations, or users might encounter unimplemented-operator errors. To address this challenge, we propose compiling single ATen operations using torch.compile and registering them. This approach allows for dynamic compilation and optimization of operations, offering more efficient support for different hardware backends.
The document details the registration, cache mechanism, cache key design, and steps that new backend maintainers need to take.
Overall, this proposal aims to simplify the maintenance of PyTorch backends while enhancing efficiency and user experience. Through this approach, it becomes easier to introduce new hardware backends to PyTorch.
Current Status and Challenges
We are exploring this approach and have enabled the above example by providing an alternative registration API for the POC. We will support more operations, for both inference and training, to check whether there are further feature gaps.
Besides the feature implementation, there are some challenges.
- The compilation overhead will significantly impact the user experience.
- The feature requires robust dynamic-shape support. Otherwise, a single ATen operation has to be recompiled whenever the shapes of its input tensors change.
We may address these challenges by improving the persistent disk cache of torch.compile and its dynamic-shape support.
Alternatives
No response
Additional context
Currently, registering a Python kernel for a particular ATen operation triggers the hermetic Python object assumption when the ATen operation is dispatched.
pytorch/torch/csrc/utils/python_dispatch.cpp
Line 174 in b88be16
EnableHermeticPyObject g2;
However, this assumption is broken if the Python kernel invokes torch.compile to implement its logic. Taking the above code snippet as an example, torch.compile needs to build FakeTensor objects, and FakeTensor defines __torch_dispatch__. This means check_has_torch_dispatch will return True, but the logic requires it to be False to ensure "that operations in HermeticPyObject have equivalent C++ implementations."
pytorch/torch/csrc/autograd/python_variable.cpp
Lines 220 to 229 in b88be16
static bool check_has_torch_dispatch(PyObject* obj) {
  PyTypeObject* tp = Py_TYPE(obj);
  if (THPVariable_CheckTypeExact(tp)) {
    return false;
  }
  py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__");
  return (
      attr.ptr() != nullptr &&
      attr.ptr() != torch::disabled_torch_dispatch_impl());
}
Therefore, it cannot pass the following check:
pytorch/torch/csrc/autograd/python_variable.cpp
Lines 1962 to 1970 in b88be16
TORCH_INTERNAL_ASSERT(
    !check_has_torch_dispatch(obj),
    "While HermeticPyObject was enabled, we attempted to create a tensor "
    "subclass with __torch_dispatch__. This violates the invariant that "
    "operations in HermeticPyObject have equivalent C++ implementations. "
    "If your operator registered from Python operator registration isn't "
    "doing anything strange, there may be an internal PyTorch bug involving "
    "not appropriately disabling TorchDispatchMode before executing "
    "Python op registration.");
To address the issue, we can provide a dedicated registration API to indicate that a Python kernel invokes torch.compile for its implementation, as sketched below.
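A rough sketch of such a registration binding follows. It mirrors the Library.impl binding quoted earlier; the entry-point name _impl_with_aoti_compile and the use of AOTICompatiblePythonKernelHolder are assumptions of this proposal, not existing PyTorch APIs. The key difference from the default path is that the holder's call path would not enable EnableHermeticPyObject, so the registered Python kernel can build FakeTensors and invoke torch.compile.
// Hypothetical pybind binding for the dedicated registration API (sketch only).
m.def(
    "_impl_with_aoti_compile",
    [](const py::object& self,
       const char* name,
       c10::DispatchKey dispatch,
       py::object func) {
      HANDLE_TH_ERRORS
      auto& lib = self.cast<torch::Library&>();
      lib.impl(
          name,
          torch::dispatch(
              dispatch,
              // Boxed functor that looks up / produces the torch.compile kernel,
              // without asserting the hermetic Python object invariant.
              CppFunction::makeFromBoxedFunctor(
                  std::make_unique<AOTICompatiblePythonKernelHolder>())),
          register_or_verify());
      // Remember the Python kernel so the cache-miss path can invoke it.
      python_registrations_[lib._resolve(name)].insert_or_assign(
          dispatch,
          std::make_shared<c10::SafePyObject>(
              func.release().ptr(), getPyInterpreter()));
      END_HANDLE_TH_ERRORS_PYBIND
    });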
ATen Operation Parameter Type List
- Tensor
- bool
- int64_t
- TensorList
- Scalar
- c10::SymIntArrayRef
- ::std::optional<Tensor>
- IntArrayRef
- double
- c10::SymInt
- ::std::optional<ScalarType>
- ::std::optional<double>
- ::std::optional<bool>
- ::std::optional<Layout>
- ::std::optional<Device>
- ::std::optional<int64_t>
- Dimname
- ::std::optional<Generator>
- c10::string_view
- ::std::optional<c10::string_view>
- OptionalIntArrayRef
- ::std::optional<Scalar>
- OptionalSymIntArrayRef
- ::std::optional<MemoryFormat>
- ::std::optional<c10::SymInt>
- ScalarType
- ArrayRef<Scalar>
- DimnameList
- ::std::optional<ArrayRef<double>>
- ::std::array<bool,3>
- ::std::optional<DimnameList>
- c10::List<::std::optional<Tensor>>
- ::std::array<bool,2>
- Storage
- ::std::array<bool,4>
- Device
- DeviceIndex
- ITensorListRef
- Stream
- Layout
- MemoryFormat
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519