
ONNX: Wrong output shape for ceil_mode Pooling #71549

@kentanabe


🐛 Describe the bug

The output shape in the ONNX file exported from the following dummy model, which uses `nn.AvgPool2d` with `ceil_mode=True`, is wrong.

```python
import onnxruntime
import torch
import torch.nn as nn

input_tensor = torch.rand(1, 3, 2, 2)
opset_version = 10  # the mismatch appears for opset >= 10


class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ceil_mode=True is what triggers the wrong exported shape
        self.layer = nn.AvgPool2d(
            kernel_size=(3, 3),
            stride=(3, 3),
            padding=(1, 1),
            ceil_mode=True)

    def forward(self, x):
        return self.layer(x)


dummy_model = DummyModel()
output_torch = dummy_model(input_tensor)
print('PyTorch input', input_tensor.shape)
print('PyTorch output', output_torch.shape)

onnx_file_path = "dummy_model.onnx"
torch.onnx.export(model=dummy_model,
                  args=input_tensor,
                  f=onnx_file_path,
                  do_constant_folding=True,
                  opset_version=opset_version,
                  input_names=['input'],
                  output_names=['output'],
                  export_params=True)

# Print the input/output shapes declared in the exported ONNX graph
onnxruntime_net = onnxruntime.InferenceSession(
    onnx_file_path, providers=["CPUExecutionProvider"])
print('onnxruntime input', onnxruntime_net.get_inputs()[0].shape)
print('onnxruntime output', onnxruntime_net.get_outputs()[0].shape)
```

Output

```
PyTorch input torch.Size([1, 3, 2, 2])
PyTorch output torch.Size([1, 3, 1, 1])
onnxruntime input [1, 3, 2, 2]
onnxruntime output [1, 3, 2, 2]
```

The output shape declared in the ONNX file ([1, 3, 2, 2]) is larger than the output shape produced by the PyTorch model ([1, 3, 1, 1]).
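To see where the extra row and column come from, here is a minimal sketch that computes the pooled length of one spatial dimension in two ways, assuming dilation 1 and equal padding on both sides. `torch_pool_out` mirrors the logic of PyTorch's ATen pooling shape helper, which in ceil mode drops a final window that would start inside the right padding. `onnx_pool_out` applies the `ceil_mode` formula from the ONNX AveragePool operator spec for explicit pads, which (as far as I can tell for the opsets involved here) has no such adjustment. Both function names are hypothetical helpers, not APIs of either library.

```python
def torch_pool_out(size, kernel, stride, pad, ceil_mode):
    """Pooled length of one dim, following PyTorch's shape logic (dilation=1)."""
    numerator = size + 2 * pad - kernel
    if ceil_mode:
        out = -(-numerator // stride) + 1  # integer ceil division, then +1
        # PyTorch drops the last window if it would start in the right padding
        if (out - 1) * stride >= size + pad:
            out -= 1
    else:
        out = numerator // stride + 1  # floor division, then +1
    return out


def onnx_pool_out(size, kernel, stride, pad, ceil_mode):
    """Pooled length of one dim per the ONNX AveragePool formula (explicit pads)."""
    numerator = size + 2 * pad - kernel  # pad_shape = begin pad + end pad
    if ceil_mode:
        return -(-numerator // stride) + 1  # ceil(numerator / stride + 1)
    return numerator // stride + 1  # floor(numerator / stride + 1)


# The configuration from this report: size 2, kernel 3, stride 3, padding 1
print(torch_pool_out(2, 3, 3, 1, True), onnx_pool_out(2, 3, 3, 1, True))  # 1 2

# Input lengths (1..10) where the two formulas disagree for this configuration
mismatches = [
    size for size in range(1, 11)
    if torch_pool_out(size, 3, 3, 1, True) != onnx_pool_out(size, 3, 3, 1, True)
]
print(mismatches)  # [2, 5, 8]
```

For the report's height and width of 2, the first formula gives 1 and the second gives 2, which matches the [1, 3, 1, 1] vs. [1, 3, 2, 2] shapes above. With this kernel, stride and padding, the sweep suggests every input length of the form 3k + 2 is affected, while `ceil_mode=False` gives identical results by construction.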

Remarks

This happens only with ONNX opset version 10 or later.

Versions

```
Collecting environment information...
PyTorch version: 1.10.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Nov 26 2021, 20:14:08) [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.10.74.3-microsoft-standard-WSL2-x86_64-with-glibc2.29
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.22.1
[pip3] torch==1.10.1
[conda] Could not collect
```


Labels: module: onnx, onnx-triaged, triaged
