
JIT "optimizations" take way too long  #57894

@PiotrDabkowski

Description

🐛 Bug

The initial calls to a loaded TorchScript model take extremely long: up to 1000x longer than the non-JIT run (20 s vs 0.02 s). I understand that some "optimizations" happen under the hood, but this is a really broken user experience for a number of reasons:

  1. It is not clear what causes the slow execution (e.g. whether it is a user error), and the docs do not mention that the initial execution can be extremely slow.
  2. The workaround, torch.jit.optimized_execution(False), is not mentioned in the docs.
  3. The optimisation seems to be re-run for every new input shape, which is a problem for variable-length input (e.g. sound waveforms).
  4. Why is it 1000x slower than normal? Can the optimisations be run explicitly by the user rather than magically under the hood? If they run implicitly, they should not make the run 1000x slower (or this should at least be clearly documented).
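To make point 3 concrete, here is a toy model in plain Python of an executor that specialises on input shape and pays a one-off optimisation cost the first time it sees each new shape. It is a hypothetical sketch, not TorchScript's implementation: `ShapeSpecializingExecutor` is an invented name, and the simulated costs simply borrow the 20 s / 0.02 s figures reported above.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]


@dataclass
class ShapeSpecializingExecutor:
    """Toy executor that "optimises" once per distinct input shape.

    Hypothetical illustration only; this is not how TorchScript is implemented.
    """

    fn: Callable[[List[float]], List[float]]
    compile_cost: float = 20.0  # simulated seconds for the one-off optimisation
    run_cost: float = 0.02      # simulated seconds for an ordinary call
    cache: Dict[Shape, Callable[[List[float]], List[float]]] = field(default_factory=dict)
    clock: float = 0.0          # total simulated time spent so far

    def __call__(self, x: List[float]) -> List[float]:
        shape = (len(x),)
        if shape not in self.cache:
            # First call with this shape: pay the optimisation cost and cache the plan.
            self.clock += self.compile_cost
            self.cache[shape] = self.fn
        self.clock += self.run_cost
        return self.cache[shape](x)


if __name__ == "__main__":
    exe = ShapeSpecializingExecutor(lambda xs: [2 * v for v in xs])
    # Variable-length "waveforms": three distinct lengths across five calls.
    for n in [16000, 16000, 24000, 16000, 31999]:
        exe([0.0] * n)
    print(len(exe.cache), round(exe.clock, 2))
```

With variable-length audio almost every clip has a new length, so a per-shape cache like this one rarely hits and most calls pay the one-off cost.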

To Reproduce

Load a TorchScript model and run it with variable-sized input. Notice the extremely slow initial runtimes.
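A minimal reproduction along these lines, under stated assumptions: `TinyConvNet` is a hypothetical stand-in for the reporter's model (the actual model is not included in the issue), the model is saved and reloaded through an in-memory buffer to mimic loading a TorchScript file, and timings will vary by model and hardware. It times each call for a few waveform lengths, first with the default executor and then inside `torch.jit.optimized_execution(False)`, the workaround named in this report.

```python
import io
import time

import torch


class TinyConvNet(torch.nn.Module):
    """Hypothetical stand-in for the reporter's audio model (not from the issue)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(1, 8, kernel_size=9, padding=4)
        self.proj = torch.nn.Conv1d(8, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(torch.relu(self.conv(x)))


def load_scripted() -> torch.jit.ScriptModule:
    # Script, save to an in-memory buffer and load back, like loading a .pt file.
    buf = io.BytesIO()
    torch.jit.save(torch.jit.script(TinyConvNet().eval()), buf)
    buf.seek(0)
    return torch.jit.load(buf)


def time_calls(model, lengths, optimize: bool):
    """Return (length, seconds) for one call per requested waveform length."""
    timings = []
    with torch.no_grad(), torch.jit.optimized_execution(optimize):
        for n in lengths:
            x = torch.randn(1, 1, n)  # batch=1, channels=1, n samples
            start = time.perf_counter()
            model(x)
            timings.append((n, time.perf_counter() - start))
    return timings


if __name__ == "__main__":
    lengths = [16000, 16000, 16000, 24000, 24000, 24000]
    for optimize in (True, False):
        model = load_scripted()  # fresh model so each mode starts cold
        for n, secs in time_calls(model, lengths, optimize):
            print(f"optimize={optimize} len={n}: {secs * 1000:.1f} ms")
```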

Expected behavior

The optimisation is fast, or it does not happen implicitly when the model is called. The optimisation is not re-run for inputs of a different shape. The behaviour is clearly documented.
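Until the optimisation can be triggered explicitly, one user-side mitigation is to warm the model up at load time on the input lengths you expect, so the cost is paid before the first real request. This is a sketch under assumptions: `warm_up` is a hypothetical helper, not a PyTorch API, and the number of calls the executor needs before it settles on an optimised plan is an assumption (set by `repeats`), not something this issue documents. It does not help for lengths that were not warmed up, which is exactly the variable-length problem described above.

```python
import time
from typing import Dict, Iterable

import torch


def warm_up(model: torch.nn.Module, lengths: Iterable[int], repeats: int = 2) -> Dict[int, float]:
    """Hypothetical helper: run dummy calls per expected input length at load time.

    `repeats` is an assumed number of calls; the executor may need more than one
    call per shape before it stops re-optimising. Returns the duration (seconds)
    of the last warm-up call for each length.
    """
    last: Dict[int, float] = {}
    with torch.no_grad():
        for n in lengths:
            x = torch.zeros(1, 1, n)  # dummy mono waveform of length n
            for _ in range(repeats):
                start = time.perf_counter()
                model(x)
                last[n] = time.perf_counter() - start
    return last


if __name__ == "__main__":
    model = torch.jit.script(torch.nn.Conv1d(1, 1, kernel_size=3, padding=1))
    for n, secs in warm_up(model, [16000, 24000]).items():
        print(f"len={n}: last warm-up call {secs * 1000:.1f} ms")
```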

Environment

PyTorch version: 1.9.0.dev20210501+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 465.19.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.4
[pip3] pytorch-lightning==1.2.10
[pip3] pytorch-lightning-bolts==0.3.0
[pip3] torch==1.9.0.dev20210501+cu111
[pip3] torch-stft==0.1.4
[pip3] torchaudio==0.9.0.dev20210501
[pip3] torchext==0.0.6
[pip3] torchmetrics==0.2.0
[pip3] torchtext==0.10.0.dev20210501
[pip3] torchvision==0.10.0.dev20210501+cu111
[conda] Could not collect

cc @gmagogsfm

Metadata

Assignees: No one assigned
Labels: oncall: jit