Cuda streams and torch.compile #92804

@sujoysaraswati

🐛 Describe the bug

Dynamo does not appear to support user-defined CUDA streams: ops executed inside a user stream context are not captured in any graph.
Example code:

from typing import List

import torch
import torch._dynamo as dynamo

# Enable Dynamo debug logging
dynamo.config.log_level = dynamo.config.logging.DEBUG
torch._dynamo.reset()

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Print every graph Dynamo captures, then run it unmodified
    print(gm.code)
    print(gm.graph)
    gm.graph.print_tabular()
    return gm.forward

s = torch.cuda.Stream()  # user-defined side stream

@dynamo.optimize(my_compiler)
def fn(t) -> torch.Tensor:
    tmp1 = torch.mul(t, 5)
    tmp2 = torch.add(tmp1, 2)
    with torch.cuda.stream(s):  # Dynamo breaks the graph here
        r = torch.relu(tmp2)
    return r

i = torch.Tensor([-2, 3]).to('cuda')
r = fn(i)
print(f"r = {r}")

Although the output is correct, Dynamo breaks the graph when it reaches the user stream context:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 420, in proxy_args_kwargs
    proxy_args = tuple(arg.as_proxy() for arg in args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 420, in <genexpr>
    proxy_args = tuple(arg.as_proxy() for arg in args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/base.py", line 206, in as_proxy
    raise NotImplementedError(str(self))
NotImplementedError: UserDefinedObjectVariable(Stream)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 307, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 974, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 435, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/variables/torch.py", line 476, in call_function
    *proxy_args_kwargs(args, kwargs),
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 427, in proxy_args_kwargs
    raise unimplemented(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function args: UserDefinedObjectVariable(Stream) 
[2023-01-23 04:54:12,810] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-01-23 04:54:12,811] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='call_function args: UserDefinedObjectVariable(Stream) ', user_stack=[<FrameSummary file <ipython-input-12-f1436f0fcd91>, line 20 in fn>])

It appears that no graph is captured for the ops in the user stream context: the graph the compiler receives, shown below, contains mul and add but not relu, the op that runs on the user stream in the example code.

def forward(self, t : torch.Tensor):
    mul = torch.mul(t, 5);  t = None
    add = torch.add(mul, 2);  mul = None
    return (add,)
    
graph():
    %t : torch.Tensor [#users=1] = placeholder[target=t]
    %mul : [#users=1] = call_function[target=torch.mul](args = (%t, 5), kwargs = {})
    %add : [#users=1] = call_function[target=torch.add](args = (%mul, 2), kwargs = {})
    return (add,)
opcode         name    target                                                  args       kwargs
-------------  ------  ------------------------------------------------------  ---------  --------
placeholder    t       t                                                       ()         {}
call_function  mul     <built-in method mul of type object at 0x7f2e333dbe40>  (t, 5)     {}
call_function  add     <built-in method add of type object at 0x7f2e333dbe40>  (mul, 2)   {}
output         output  output                                                  ((add,),)  {}
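To confirm the graph break without relying on the debug log, a backend that only records the graphs it receives can be passed to `torch.compile`. This is a minimal sketch assuming PyTorch 2.0 or later; `recording_backend`, `captured_graphs` and `side_stream` are illustrative names, not part of the original report. Unlike the original example, it synchronizes the side stream with `wait_stream` (so on affected versions the break may be reported at that call rather than at `torch.cuda.stream`), and it falls back to plain CPU execution when no GPU is present. How the ops are split across graphs depends on the PyTorch version, so no expected output is given.

```python
from typing import Callable, List

import torch

# Every graph Dynamo hands to the backend is appended here; a graph break
# shows up as more than one entry, or as ops missing from all entries.
captured_graphs: List[torch.fx.GraphModule] = []


def recording_backend(gm: torch.fx.GraphModule,
                      example_inputs: List[torch.Tensor]) -> Callable:
    captured_graphs.append(gm)
    return gm.forward  # run the captured graph unmodified


device = "cuda" if torch.cuda.is_available() else "cpu"
# Illustrative side stream; None on CPU-only machines.
side_stream = torch.cuda.Stream() if device == "cuda" else None


def fn(t: torch.Tensor) -> torch.Tensor:
    tmp = torch.add(torch.mul(t, 5), 2)
    if side_stream is None:  # CPU fallback: nothing stream-related to trace
        return torch.relu(tmp)
    # The side stream must not read tmp before the default stream produces it.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        r = torch.relu(tmp)
    # The default stream must not consume r before the side stream finishes.
    torch.cuda.current_stream().wait_stream(side_stream)
    return r


out = torch.compile(fn, backend=recording_backend)(
    torch.tensor([-2.0, 3.0], device=device))

# List the ops in each captured graph to see where relu ended up.
for idx, gm in enumerate(captured_graphs):
    ops = [str(node.target) for node in gm.graph.nodes
           if node.op == "call_function"]
    print(f"graph {idx}: {ops}")
```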

Is there a way to capture the user stream ops in a graph via dynamo?
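One possible workaround, assuming it is acceptable to keep stream management in eager mode, is to compile only the stream-free pieces of the function and switch streams in ordinary Python around them. This is a sketch, not an officially documented pattern; `pre`, `post`, `run` and `eager_backend` are hypothetical names. It also adds the cross-stream synchronization the original example lacks: there, `tmp2` is produced on the default stream but read on `s`, and `r` is read back on the default stream, with no `wait_stream` call in between.

```python
from typing import Callable, List, Optional

import torch


def eager_backend(gm: torch.fx.GraphModule,
                  example_inputs: List[torch.Tensor]) -> Callable:
    # Stand-in for the custom compiler in the report: run the graph as captured.
    return gm.forward


# Hypothetical split of `fn`: each piece holds only tensor ops, so neither
# contains the stream context that caused the break.
@torch.compile(backend=eager_backend)
def pre(t: torch.Tensor) -> torch.Tensor:
    return torch.add(torch.mul(t, 5), 2)


@torch.compile(backend=eager_backend)
def post(t: torch.Tensor) -> torch.Tensor:
    return torch.relu(t)


def run(t: torch.Tensor,
        stream: Optional[torch.cuda.Stream] = None) -> torch.Tensor:
    # Stream switching stays in eager Python; only pre/post are compiled.
    tmp = pre(t)  # runs on the current (default) stream
    if stream is None:
        return post(tmp)
    stream.wait_stream(torch.cuda.current_stream())  # tmp must be ready
    with torch.cuda.stream(stream):
        r = post(tmp)  # with this eager backend, kernels launch on `stream`
    tmp.record_stream(stream)  # tmp was allocated on the default stream
    torch.cuda.current_stream().wait_stream(stream)  # r must be ready
    r.record_stream(torch.cuda.current_stream())  # r is used on the default stream
    return r


device = "cuda" if torch.cuda.is_available() else "cpu"
side = torch.cuda.Stream() if device == "cuda" else None
out = run(torch.tensor([-2.0, 3.0], device=device), side)
print(out)
```

The trade-off is that each compiled piece is a separate graph, so the compiler cannot optimize across the stream boundary.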

Versions

PyTorch version: 2.0.0.dev20230122+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.8.10 (default, Nov 14 2022, 12:59:47) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.147+-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.0.0+0d7e753227
[pip3] torch==2.0.0.dev20230122+cu117
[pip3] torchaudio==0.13.1+cu116
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.14.1
[pip3] torchvision==0.14.1+cu116
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ptrblck @eqy @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames @rec @soumith @ngimel @bdhirsh @mlazos @yanboliang @anijain2305 @chunyuan-w @Xia-Weiwen @desertfire

Labels: feature, high priority, module: cuda, module: dynamo, months, oncall: distributed, oncall: pt2, triaged
