Print statements inside kernel print incorrect value of int64 tensors #4060

@georg-wolflein

I came across a bug using int64 tensors. Here's a minimal reproduction.

MWE:

import torch
import triton
import triton.language as tl


@triton.jit
def ndscore_kernel(ptr):
    value = tl.load(ptr)
    print("value in kernel", value)
    tl.store(ptr, value + 1)


ptr = torch.tensor(42, dtype=torch.int64).cuda()
print("value before kernel", ptr.item())
ndscore_kernel[(1,)](ptr)
print("value after kernel", ptr.item())

Output:

value before kernel 42
pid (0, 0, 0) idx () value in kernel: 0
[...]
pid (0, 0, 0) idx () value in kernel: 0
value after kernel 43

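Why does the kernel print 0 instead of 42? The value after the kernel is 43, so the load and store operate on the correct value; only the printed value is wrong.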

Observations:

  • Changing the dtype of ptr to torch.int32 makes the kernel print the expected value: pid (0, 0, 0) idx () value in kernel: 42
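
One possible line of investigation, stated purely as a hypothesis that nothing in this report confirms: the int64 value 42 is stored as a low 32-bit word of 42 followed by a high 32-bit word of 0, so a print path that handled a 64-bit argument with a 32-bit width, or misaligned it, could plausibly show 0. The sketch below only illustrates that split using Python's struct module. It does not call Triton, and the names raw, low_word, and high_word are illustrative.

```python
import struct

# Pack 42 as a little-endian signed 64-bit integer. This mirrors how an
# int64 scalar is laid out in memory on a little-endian machine.
raw = struct.pack("<q", 42)

# Reinterpret the same 8 bytes as two signed 32-bit words.
low_word, high_word = struct.unpack("<ii", raw)

print("low 32 bits: ", low_word)   # 42
print("high 32 bits:", high_word)  # 0
```

If this guess were worth testing, a value such as 2**32 + 42 would likely be more informative than 42, since its two 32-bit words are both non-zero and differ from each other.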
