Description
🐛 Describe the bug
import torch
pool = torch.nn.AvgPool2d(4, stride=4, padding=2, ceil_mode=True)
x = torch.ones(1, 1, 9, 9)
print(pool(x).shape)
Expected output shape: (1, 1, 4, 4)
Actual output shape: (1, 1, 3, 3)
Per the documentation, the output size should be ceil((9 + 2*2 - 4) / 4 + 1) = 4.
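A minimal sketch of the arithmetic, for comparison. `documented_out_size` applies the formula quoted above. `adjusted_out_size` also drops a final window that would start in the right padding. Both function names are hypothetical. The extra rule is an assumption based on the "last pooling window must start inside the image" check in ATen's pooling shape code. It is not part of the formula quoted above. Dilation is assumed to be 1.

```python
def documented_out_size(n: int, k: int, s: int, p: int) -> int:
    """Output length per the docs formula with ceil_mode=True (dilation assumed 1)."""
    # Integer form of ceil((n + 2p - k) / s) + 1
    return (n + 2 * p - k + s - 1) // s + 1


def adjusted_out_size(n: int, k: int, s: int, p: int) -> int:
    """Docs formula plus the ASSUMED rule that a window starting in the right padding is dropped."""
    out = documented_out_size(n, k, s, p)
    # The last window starts at (out - 1) * s in padded coordinates. The input
    # occupies padded indices p .. n + p - 1, so an index >= n + p is right padding.
    if (out - 1) * s >= n + p:
        out -= 1
    return out


# Values from the report: 9x9 input, kernel 4, stride 4, padding 2
print(documented_out_size(9, 4, 4, 2))  # 4
print(adjusted_out_size(9, 4, 4, 2))    # 3
```

Under this assumed rule, the fourth window would start at padded index 12. That is past the last input element at padded index 10, which would account for the 3x3 result.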
Versions
Collecting environment information...
PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: Could not collect
Libc version: glibc-2.31
Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-88-generic-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 8
On-line CPU(s) list: 0-7
Thread(s) per core: 1
Core(s) per socket: 1
Socket(s): 8
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Intel(R) Xeon(R) Gold 6444Y
Stepping: 8
CPU MHz: 3599.999
BogoMIPS: 7199.99
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB
L1i cache: 256 KiB
L2 cache: 16 MiB
L3 cache: 360 MiB
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] onnx==1.15.0
[pip3] onnxruntime==1.16.3
[pip3] onnxruntime-extensions==0.9.0
[pip3] torch==2.1.2
[pip3] torchvision==0.16.2
[pip3] triton==2.1.0
[conda] numpy 1.26.2 pypi_0 pypi
[conda] torch 2.1.2 pypi_0 pypi
[conda] torchvision 0.16.2 pypi_0 pypi
[conda] triton 2.1.0 pypi_0 pypi
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki