
MatMul and blocksparse matmul have incorrect precision for some shapes #1808

@Qu-Xiangjun

Description


I'm using Triton 2.0.0, PyTorch 2.0.0, Python 3.9.16 and CUDA 11.6 on a PC running CentOS release 7.4.1708 with an NVIDIA A100. I use the matmul and blocksparse/matmul ops in https://github.com/openai/triton/tree/main/python/triton/ops, with test code modeled on [test_matmul.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_matmul.py) and [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py).

When I compare the Triton matmul with torch.matmul, the results differ according to torch.allclose(atol=1e-5, rtol=0), as follows:

Matmul Test

The testing code is as follows:

```python
import torch
import triton

M, N, K = 2048, 2048, 2048
torch.manual_seed(0)
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
# compute torch
torch_output = torch.matmul(a, b)
# compute triton
triton_output = triton.ops.matmul(a, b)

# compare
diff = torch.sum(torch.abs(triton_output - torch_output))
print("total difference: {:10f}".format(diff))

if torch.allclose(triton_output, torch_output, atol=1e-5, rtol=0):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")
```

This code prints a total difference greater than 0.0, and torch.allclose returns False.
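Part of the mismatch may come from the check itself: with fp16 outputs, `atol=1e-5, rtol=0` effectively demands bit-identical results, because adjacent fp16 values near the output magnitude are far more than 1e-5 apart. A minimal pure-Python sketch (no GPU needed; it relies on `struct`'s `'e'` format, which is IEEE half precision). The std of about sqrt(K) assumes independent standard-normal inputs, and `to_fp16` / `fp16_spacing` are hypothetical helper names:

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fp16_spacing(x: float) -> float:
    """Gap between adjacent fp16 values near |x| (finite range only)."""
    if x == 0.0:
        return 2.0**-24  # subnormal spacing
    _, exp = math.frexp(abs(x))  # |x| = m * 2**exp with 0.5 <= m < 1
    # fp16 has 10 explicit mantissa bits; spacing never drops below the subnormal gap
    return max(2.0 ** (exp - 11), 2.0**-24)


# With randn inputs and K = 2048, each output element has a std of about sqrt(2048) ~ 45.
typical = math.sqrt(2048)
print(fp16_spacing(typical))               # 0.03125
print(abs(to_fp16(45.01) - 45.01) > 1e-5)  # True: one fp16 rounding already exceeds atol=1e-5
```

So a single last-bit difference between two kernels is enough to fail this `allclose` call, even when both results are equally accurate.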

I then observed the following:

  1. The difference increases as the shape grows. My guess was that this comes from error accumulating during the computation. But when I run this code with M, K, N = 4096, 4096, 4096 on my machine, it passes ✅ the allclose check with diff = 0.000000. Is it also related to the shape? Only some shapes trigger the problem.

  2. I also tried some special inputs with M, N, K = 2048, 2048, 2048.

    • With a = torch.ones and b = torch.ones, the check always passes ✅. So in some cases the problem does not depend on the shape.

    • With a = torch.ones and b = torch.randn, every row of the result matrix is identical, including the incorrect elements.
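Both observations in item 2 follow from the inputs. With a = torch.ones, C[i, j] is the sum over k of b[k, j], which does not depend on i, so identical rows (and identical wrong elements) are expected from any implementation. With both inputs all ones, every partial sum is an integer no larger than 2048, which fp32 stores exactly, so the summation order cannot change the result. A plausible explanation for the remaining differences, which this issue does not confirm, is that cuBLAS and the Triton kernel tile the K dimension differently and therefore add the same products in a different order; floating-point addition is not associative, so the fp32 accumulators can end in different low-order bits that occasionally flip an fp16 rounding. The sketch below mimics an fp32 accumulator in pure Python via `struct`; the tile size, the adversarial values and the helper names are illustrative assumptions, not Triton's actual blocking:

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest IEEE-754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sequential_sum_fp32(values):
    """Accumulate left to right, rounding the running sum to fp32 after every add."""
    acc = 0.0
    for v in values:
        acc = to_fp32(acc + to_fp32(v))
    return acc


def blocked_sum_fp32(values, block):
    """Sum each tile of `block` elements in fp32, then accumulate the tile partials in fp32."""
    partials = [sequential_sum_fp32(values[i:i + block]) for i in range(0, len(values), block)]
    return sequential_sum_fp32(partials)


# 2**24 is where fp32 stops representing every integer, so adding 1.0 to it is lost.
vals = [2.0**24, 1.0, 1.0, -2.0**24]
print(sequential_sum_fp32(vals))        # 0.0
print(blocked_sum_fp32(vals, block=2))  # 1.0
print(sum(vals))                        # 2.0 (exact in fp64)

# All-ones input: every partial sum is a small integer, exact in fp32 for any order.
ones = [1.0] * 2048
print(sequential_sum_fp32(ones), blocked_sum_fp32(ones, block=32))  # 2048.0 2048.0
```

Real matmul inputs are far less adversarial, which would be consistent with only a small fraction of elements differing, and only by the last fp16 bit.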

Blocksparse Matmul Test

The incorrect precision also appears in the blocksparse matmul function. The test code is as follows; it uses only the forward pass from [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py):

```python
import torch
import triton
import triton.testing


def sparsify_tensor(x, mask, block):
    # Pack the blocks that are nonzero in `mask` into a (Z, nnz, block, block) tensor
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret


def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None,
              dtype=torch.float32):
    # Build a reference tensor and an identical, independent copy for Triton
    if data is None:
        data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
    ref_ret = data
    ref_ret = ref_ret * alpha + beta
    ref_ret = ref_ret.half().to(dtype)
    if trans:
        ref_ret = ref_ret.t().requires_grad_()
    ref_ret = ref_ret.detach().requires_grad_()
    tri_ret = ref_ret.clone().detach().requires_grad_()
    return ref_ret, tri_ret


def mask_tensor(x, mask, block, value=0):
    # Overwrite the blocks that are zero in `mask` with `value`
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret


def test_blocksparsematmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
    seed = 0
    torch.manual_seed(seed)
    is_sdd = MODE == "sdd"
    is_dsd = MODE == "dsd"
    is_dds = MODE == "dds"
    do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK)
    do_mask = lambda x: mask_tensor(x, layout, BLOCK)

    # create shapes and a random block layout
    a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
    b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
    c_shape = (Z, H, M, N)
    shape = {
        "sdd": (M, N),
        "dsd": (a_shape[2], a_shape[3]),
        "dds": (b_shape[2], b_shape[3]),
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))

    # create data
    a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE)
    b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE)
    dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE)

    # compute [torch]
    a_ref = do_mask(a_ref) if is_dsd else a_ref
    b_ref = do_mask(b_ref) if is_dds else b_ref
    a_ref.retain_grad()
    b_ref.retain_grad()
    c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                         b_ref.transpose(2, 3) if TRANS_B else b_ref)
    c_ref = do_sparsify(c_ref) if is_sdd else c_ref
    # dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
    # c_ref.backward(dc_ref)
    # da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
    # db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad

    # triton result
    a_tri = do_sparsify(a_tri) if is_dsd else a_tri
    b_tri = do_sparsify(b_tri) if is_dds else b_tri
    a_tri.retain_grad()
    b_tri.retain_grad()
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
    c_tri = op(a_tri, b_tri)
    # dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
    # c_tri.backward(dc_tri)
    # da_tri = a_tri.grad
    # db_tri = b_tri.grad

    # compare
    print("--------------------------------------------------------------")
    total_op = 2 * M * N * K * Z * H
    perf = lambda ms: total_op * 1e-9 / (ms * 1e-3)  # GFLOP/s from a time in milliseconds

    print('''MODE={}, Z={}, H={}, M={}, N={}, K={}, total_op={}. '''
          .format(MODE, Z, H, M, N, K, total_op))

    diff = torch.sum(torch.abs(c_ref - c_tri))
    print('total diff = {:.10f}'.format(diff))

    if torch.allclose(c_ref, c_tri, atol=1e-5, rtol=0):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

    ms, _, _ = triton.testing.do_bench(lambda: op(a_tri, b_tri), rep=20)
    print('''Triton: GFLOPS: {:.3f}, time: {:.6f}ms.'''.format(perf(ms), ms))

    ms_torch, _, _ = triton.testing.do_bench(
        lambda: torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                             b_ref.transpose(2, 3) if TRANS_B else b_ref),
        rep=20
    )
    print('''Torch: GFLOPS: {:.3f}, time: {:.6f}ms.'''.format(perf(ms_torch), ms_torch))

    return perf(ms), perf(ms_torch), diff


test_blocksparsematmul('dds', False, False, 32, torch.float16, Z=1, H=2, M=64, K=4096, N=4096)
```

This code prints a total difference greater than 0.0, and torch.allclose returns False.
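For reading the benchmark numbers: GFLOP/s is the operation count times 1e-9, divided by the time in seconds (ms × 1e-3). The sketch below, with a hypothetical `dense_gflops` helper, works this out for the call above. It counts the dense operation `2 * M * N * K * Z * H`; in `dds` mode the kernel presumably skips the zero blocks of b (about half of them with `torch.randint(2, ...)`), so for Triton this is a dense-equivalent rate rather than the work actually done:

```python
def dense_gflops(M, N, K, Z, H, ms):
    """GFLOP/s for Z*H dense (M x K) @ (K x N) products that take `ms` milliseconds."""
    total_op = 2 * M * N * K * Z * H      # one multiply and one add per term
    return total_op * 1e-9 / (ms * 1e-3)  # FLOPs -> GFLOPs, milliseconds -> seconds


# The call above: Z=1, H=2, M=64, N=4096, K=4096 gives total_op = 2**32
print(2 * 64 * 4096 * 4096 * 1 * 2 == 2**32)                 # True
print(round(dense_gflops(64, 4096, 4096, 1, 2, ms=1.0), 3))  # 4294.967
```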

I then observed the following:

  • For small shapes such as M, N, K = 256, 256, 256, the check always passes ✅.
  • I tested M over the range [1, 1024] with N, K = 4096, 4096, and more than half of the values in that range print ❌ Triton and Torch differ.
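Since the failures grow with K, it may help to measure how far apart the two results actually are, rather than only whether they pass `atol=1e-5`. One common approach is to count the distance in fp16 units in the last place (ulps). For comparison, `torch.testing.assert_close` defaults to `rtol=1e-3, atol=1e-5` for float16. A minimal pure-Python sketch; `fp16_key` and `max_ulp_diff` are hypothetical helpers, and in a real test the tensors would first be flattened to Python floats (for example with `.flatten().tolist()`):

```python
import struct


def fp16_key(x: float) -> int:
    """Map an fp16 value to an integer so that adjacent fp16 values differ by exactly 1."""
    bits = struct.unpack("<H", struct.pack("<e", x))[0]
    # fp16 is sign-magnitude: mirror negative values below zero (+0.0 and -0.0 both map to 0)
    return -(bits & 0x7FFF) if bits & 0x8000 else bits


def max_ulp_diff(xs, ys):
    """Largest distance, in fp16 ulps, between paired elements of two equal-length sequences."""
    return max(abs(fp16_key(x) - fp16_key(y)) for x, y in zip(xs, ys))


ref = [1.0, -0.5, 45.0, 0.0]
tri = [1.0 + 2**-10, -0.5, 45.03125, -0.0]
print(max_ulp_diff(ref, tri))  # 1
```

If the maximum stays at one or two ulps across shapes, that would point toward accumulation-order differences rather than a kernel bug; a large ulp distance would be a stronger sign of an actual error.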

So what could be causing the incorrect precision, and how can the problem be solved?
