🐛 Describe the bug
Dynamo doesn't seem to support user-defined CUDA streams and doesn't capture the ops inside the user stream context in a graph.
Example code:
from typing import List

import torch
import torch._dynamo as dynamo

dynamo.config.log_level = dynamo.config.logging.DEBUG
torch._dynamo.reset()

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print(gm.code)
    print(gm.graph)
    gm.graph.print_tabular()
    return gm.forward

s = torch.cuda.Stream()

@dynamo.optimize(my_compiler)
def fn(t) -> torch.Tensor:
    tmp1 = torch.mul(t, 5)
    tmp2 = torch.add(tmp1, 2)
    with torch.cuda.stream(s):
        r = torch.relu(tmp2)
    return r

i = torch.Tensor([-2, 3]).to('cuda')
r = fn(i)
print(f"r = {r}")
Even though the output is correct, Dynamo breaks the graph when it encounters the user stream context -
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 420, in proxy_args_kwargs
proxy_args = tuple(arg.as_proxy() for arg in args)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 420, in <genexpr>
proxy_args = tuple(arg.as_proxy() for arg in args)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/base.py", line 206, in as_proxy
raise NotImplementedError(str(self))
NotImplementedError: UserDefinedObjectVariable(Stream)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 307, in wrapper
return inner_fn(self, inst)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 974, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 435, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/torch.py", line 476, in call_function
*proxy_args_kwargs(args, kwargs),
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 427, in proxy_args_kwargs
raise unimplemented(
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/exc.py", line 71, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function args: UserDefinedObjectVariable(Stream)
[2023-01-23 04:54:12,810] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-01-23 04:54:12,811] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='call_function args: UserDefinedObjectVariable(Stream) ', user_stack=[<FrameSummary file <ipython-input-12-f1436f0fcd91>, line 20 in fn>])
It seems no graph is captured for the ops in the user stream context: since Dynamo cannot proxy the Stream object into the FX graph, the captured graph stops right before the context manager, and no graph containing relu (the op inside the user stream context in the example code) appears below -
def forward(self, t : torch.Tensor):
    mul = torch.mul(t, 5);  t = None
    add = torch.add(mul, 2);  mul = None
    return (add,)

graph():
    %t : torch.Tensor [#users=1] = placeholder[target=t]
    %mul : [#users=1] = call_function[target=torch.mul](args = (%t, 5), kwargs = {})
    %add : [#users=1] = call_function[target=torch.add](args = (%mul, 2), kwargs = {})
    return (add,)
opcode name target args kwargs
------------- ------ ------------------------------------------------------ --------- --------
placeholder t t () {}
call_function mul <built-in method mul of type object at 0x7f2e333dbe40> (t, 5) {}
call_function add <built-in method add of type object at 0x7f2e333dbe40> (mul, 2) {}
output output output ((add,),) {}
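For completeness, the graph break can also be confirmed programmatically. The following is a rough sketch (my addition, not from the original report) that assumes the torch._dynamo.explain helper available in this nightly; its exact return signature may differ between versions, and it is applied to an undecorated copy of fn:

# Sketch (assumption): list captured graphs and graph-break reasons via explain().
def fn_plain(t) -> torch.Tensor:
    tmp1 = torch.mul(t, 5)
    tmp2 = torch.add(tmp1, 2)
    with torch.cuda.stream(s):
        r = torch.relu(tmp2)
    return r

explanation, out_guards, graphs, ops_per_graph, break_reasons, verbose = dynamo.explain(fn_plain, i)
print(explanation)            # summary: number of graphs / graph breaks
for reason in break_reasons:
    print(reason)             # expected to include the Stream-related break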
Is there a way to capture the user stream ops in a graph via dynamo?
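As a possible interim workaround (my own sketch, not something suggested in this report), moving the stream context outside the compiled function avoids the graph break entirely, at the cost of running every op in the function on the user stream rather than just the relu:

# Workaround sketch (hypothetical): keep the Stream out of the traced region so
# Dynamo never has to proxy it.
s = torch.cuda.Stream()

@dynamo.optimize(my_compiler)
def fn_no_stream(t) -> torch.Tensor:
    tmp1 = torch.mul(t, 5)
    tmp2 = torch.add(tmp1, 2)
    return torch.relu(tmp2)            # no stream context inside the compiled region

i = torch.Tensor([-2, 3]).to('cuda')
with torch.cuda.stream(s):             # the whole compiled graph now runs on s
    r = fn_no_stream(i)
torch.cuda.current_stream().wait_stream(s)  # sync before using r on the default stream
print(f"r = {r}")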
Versions
PyTorch version: 2.0.0.dev20230122+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31
Python version: 3.8.10 (default, Nov 14 2022, 12:59:47) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.147+-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.0.0+0d7e753227
[pip3] torch==2.0.0.dev20230122+cu117
[pip3] torchaudio==0.13.1+cu116
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.14.1
[pip3] torchvision==0.14.1+cu116
[conda] Could not collect
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ptrblck @eqy @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames @rec @soumith @ngimel @bdhirsh @mlazos @yanboliang @anijain2305 @chunyuan-w @Xia-Weiwen @desertfire