
ONNX: wrong operator for ceil_mode pooling when the last window is skipped #131272

@kentanabe

Description


๐Ÿ› Describe the bug

The output shape in the ONNX file exported from the following dummy model, which uses AvgPool2d with ceil_mode=True, is wrong.

This issue was reported in #71549 and fixed in PyTorch 2.0.0, but it regressed again in 2.1.0.

According to the AvgPool2d documentation, PyTorch skips the last window if it would start in the bottom padded region.
https://docs-preview.pytorch.org/pytorch/pytorch/120335/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
AveragePool in ONNX, on the other hand, does not skip it.
#116420 (comment)
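To make the shape mismatch concrete, here is a small sketch of the two output-size rules for one spatial dimension (dilation 1). `pytorch_pool_out` is modeled on PyTorch's ceil-mode check that drops a window starting in the end padding; `onnx_ceil_pool_out` applies the ceil_mode formula of ONNX AveragePool with no such step. Both function names are hypothetical helpers, not library APIs.

```python
def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division (avoids float rounding).
    return -(-a // b)


def pytorch_pool_out(size: int, kernel: int, stride: int, pad: int, ceil_mode: bool) -> int:
    """Output length along one dim, modeled on PyTorch's rule (dilation = 1)."""
    span = size + 2 * pad - kernel
    if not ceil_mode:
        return span // stride + 1
    out = ceil_div(span, stride) + 1
    # Drop the last window if it would start in the end (bottom/right) padding.
    if (out - 1) * stride >= size + pad:
        out -= 1
    return out


def onnx_ceil_pool_out(size: int, kernel: int, stride: int, pad_begin: int, pad_end: int) -> int:
    """ONNX AveragePool with ceil_mode=1: ceil(...) + 1, no window is dropped."""
    return ceil_div(size + pad_begin + pad_end - kernel, stride) + 1


# Height/width of the reproduction below: size 2, kernel 3, stride 3, padding 1.
print(pytorch_pool_out(2, 3, 3, 1, ceil_mode=True))   # 1
print(onnx_ceil_pool_out(2, 3, 3, 1, 1))              # 2
```

With these parameters the ceil-mode rule first yields two windows; the second would start at input index 2, which is already the bottom padding row, so PyTorch drops it while the plain ONNX formula keeps it. That matches the `[1, 3, 1, 1]` vs `[1, 3, 2, 2]` shapes reported below.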

To match PyTorch, a ceil_mode AvgPool2d needs to be split into a Pad operator followed by a floor_mode AveragePool operator, but PyTorch 2.1 and later simply export a single AveragePool operator with ceil_mode set.
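A quick check that the Pad + floor-mode split reproduces PyTorch for this particular configuration, using `torch.nn.functional.pad` and `torch.nn.functional.avg_pool2d` as stand-ins for the ONNX Pad and AveragePool nodes. This is a sketch for the reproduction's parameters only: when a ceil-mode window overhangs past the explicit padding, the averaging divisor for that edge window can differ, so a general export would need extra care.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(1, 3, 2, 2)

# Reference: PyTorch ceil-mode AvgPool2d with the parameters from the report.
ref = F.avg_pool2d(x, kernel_size=3, stride=3, padding=1, ceil_mode=True)

# Split form: explicit zero padding (left, right, top, bottom), then
# floor-mode average pooling with no implicit padding.
padded = F.pad(x, (1, 1, 1, 1), mode="constant", value=0.0)
alt = F.avg_pool2d(padded, kernel_size=3, stride=3, padding=0, ceil_mode=False)

print(ref.shape, alt.shape)
print(torch.allclose(ref, alt))
```

Both paths produce a single 3x3 window per channel whose sum covers the 2x2 input plus zeros and is divided by 9, so the shapes and values agree.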

```python
import torch
import torch.nn as nn
import onnxruntime

input_tensor = torch.rand(1, 3, 2, 2)
opset_version = 10


class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Ceil-mode average pooling: with a 2x2 input, PyTorch drops the second
        # window because it would start in the bottom/right padded region.
        self.layer = nn.AvgPool2d(
            kernel_size=(3, 3),
            stride=(3, 3),
            padding=(1, 1),
            ceil_mode=True)

    def forward(self, x):
        return self.layer(x)


dummy_model = DummyModel()
output_torch = dummy_model(input_tensor)

onnx_file_path = "dummy_model.onnx"
torch.onnx.export(model=dummy_model,
                  args=input_tensor,
                  f=onnx_file_path,
                  do_constant_folding=True,
                  opset_version=opset_version,
                  input_names=['input'],
                  output_names=['output'],
                  export_params=True)

# Compare PyTorch's shapes with the static shapes recorded in the exported graph.
onnxruntime_net = onnxruntime.InferenceSession(onnx_file_path)

print('PyTorch input', input_tensor.shape)
print('PyTorch output', output_torch.shape)
print('onnxruntime input', onnxruntime_net.get_inputs()[0].shape)
print('onnxruntime output', onnxruntime_net.get_outputs()[0].shape)
```

Output

1.13.0: WRONG

PyTorch input torch.Size([1, 3, 2, 2])
PyTorch output torch.Size([1, 3, 1, 1])
onnxruntime input [1, 3, 2, 2]
onnxruntime output [1, 3, 2, 2]

2.0.0: CORRECT

PyTorch input torch.Size([1, 3, 2, 2])
PyTorch output torch.Size([1, 3, 1, 1])
onnxruntime input [1, 3, 2, 2]
onnxruntime output [1, 3, 1, 1]

2.1.0: WRONG

PyTorch input torch.Size([1, 3, 2, 2])
PyTorch output torch.Size([1, 3, 1, 1])
onnxruntime input [1, 3, 2, 2]
onnxruntime output [1, 3, 2, 2]

ONNX model

[Graph screenshots of the exported models for 1.13.0, 2.0.0, and 2.1.0]

Versions

Collecting environment information...
PyTorch version: 2.1.0+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 12th Gen Intel(R) Core(TM) i7-12700K
CPU family: 6
Model: 151
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 2
BogoMIPS: 7219.19
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 12.5 MiB (10 instances)
L3 cache: 25 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] onnx==1.16.1
[pip3] onnx-graphsurgeon==0.5.2
[pip3] onnxruntime==1.18.1
[pip3] torch==2.1.0+cpu
[pip3] triton==2.1.0
[conda] Could not collect

Labels: module: onnx, triaged
