Closed
Description
🐛 Describe the bug
Gradient accumulation does not seem to work with torch.distributions.categorical. See the following snippet:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

if __name__ == "__main__":
    torch.manual_seed(3)
    torch.backends.cudnn.deterministic = True

    # setup
    x = torch.rand(2, 4)
    layer = nn.Linear(4, 4)
    actions = torch.tensor([1, 2])

    # whole batch
    layer.weight.grad = None
    logprobs = Categorical(logits=layer(x)).log_prob(actions)
    logprobs.mean().backward()
    print(layer.weight.grad.sum())

    # gradient accumulation
    layer.weight.grad = None
    Categorical(logits=layer(x[0])).log_prob(actions[0]).mean().backward()
    Categorical(logits=layer(x[1])).log_prob(actions[1]).mean().backward()
    print(layer.weight.grad.sum() / 2)
Output:
tensor(1.6391e-07)
tensor(1.4901e-08)
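Both printed values are within float32 rounding of zero, so the summed gradient may not be a reliable way to compare the two runs. For a softmax-based log-probability, the gradient with respect to the logits is the one-hot action vector minus the softmax probabilities. That sums to zero over the classes, so `layer.weight.grad.sum()` is expected to be zero up to rounding in both cases. A minimal sketch of an elementwise comparison instead, assuming the same setup as the snippet; the `atol=1e-6` tolerance and the names `batch_grad`, `accum_grad`, and `grads_match` are illustrative choices, not part of the original report:

```python
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

torch.manual_seed(3)
x = torch.rand(2, 4)
layer = nn.Linear(4, 4)
actions = torch.tensor([1, 2])

# Whole batch: gradient of the mean log-probability.
layer.zero_grad()
Categorical(logits=layer(x)).log_prob(actions).mean().backward()
batch_grad = layer.weight.grad.clone()

# Accumulation: divide each per-sample term by the batch size (2)
# so that the accumulated gradient corresponds to the batch mean.
layer.zero_grad()
for i in range(2):
    (Categorical(logits=layer(x[i])).log_prob(actions[i]) / 2).backward()
accum_grad = layer.weight.grad.clone()

# Compare the full gradient tensors elementwise, not their sums.
grads_match = torch.allclose(batch_grad, accum_grad, atol=1e-6)
print(grads_match)
print(batch_grad.sum(), accum_grad.sum())
```

If `grads_match` is true, the two approaches agree and the difference between the printed sums is floating-point noise, not a sign that accumulation fails.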
Versions
Collecting environment information...
PyTorch version: 1.10.2+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Pop!_OS 21.10 (x86_64)
GCC version: (Ubuntu 11.2.0-7ubuntu2) 11.2.0
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.34
Python version: 3.9.5 (default, Jul 19 2021, 13:27:26) [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.15.8-76051508-generic-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 11.3.109
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2060
Nvidia driver version: 470.86
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.21.5
[pip3] torch==1.10.2
[conda] Could not collect