Description
I am using Triton 2.0.0, PyTorch 2.0.0, Python 3.9.16, and CUDA 11.6 on a PC running CentOS release 7.4.1708 with an NVIDIA A100. I am using the matmul and blocksparse/matmul ops from https://github.com/openai/triton/tree/main/python/triton/ops, with test code similar to [test_matmul.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_matmul.py) and [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py).
I found a problem when comparing the Triton matmul with torch.matmul: the results differ according to `torch.allclose(atol=1e-5, rtol=0)`, as follows.
Matmul Test
The testing code is as follows:
```python
import torch
import triton

M, N, K = 2048, 2048, 2048
torch.manual_seed(0)
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
# compute torch
torch_output = torch.matmul(a, b)
# compute triton
triton_output = triton.ops.matmul(a, b)
# compare
diff = torch.sum(torch.abs(triton_output - torch_output))
print("total difference: {:10f}".format(diff))
if torch.allclose(triton_output, torch_output, atol=1e-5, rtol=0):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")
```
This code prints a `total difference` greater than 0.0, and `torch.allclose` returns `False`.
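For context, here is a minimal sketch of an extra comparison (my addition, not part of the original test): it measures both fp16 results against an fp32 reference built by upcasting the inputs, and uses a relative tolerance; the atol/rtol values below are arbitrary choices.

```python
import torch
import triton

M, N, K = 2048, 2048, 2048
torch.manual_seed(0)
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)

# fp32 reference: upcast the inputs so the accumulation happens in float32
ref = torch.matmul(a.float(), b.float())

torch_out = torch.matmul(a, b).float()
triton_out = triton.ops.matmul(a, b).float()

# distance of each fp16 result from the fp32 reference
print("torch  vs fp32 ref (max abs err):", (torch_out - ref).abs().max().item())
print("triton vs fp32 ref (max abs err):", (triton_out - ref).abs().max().item())
# relative tolerance instead of a purely absolute one (values are arbitrary)
print("allclose(atol=1e-2, rtol=1e-2):",
      torch.allclose(triton_out, torch_out, atol=1e-2, rtol=1e-2))
```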
Then I observed some characteristics:

- The `diff` increases as the shape increases. I guess it may come from accumulated rounding error in the computation (see the sketch after this list). But when I run this code with `M, K, N = 4096, 4096, 4096` on my machine, it passes ✅ the `allclose` check and `diff = 0.000000`. Is it also related to the shape? Only some shapes trigger the problem.
- Moreover, I tried some special input data with the shape `M, N, K = 2048, 2048, 2048`.
- With `a = torch.ones`, `b = torch.ones`, the result always passes ✅, so the problem is not only related to the shape.
- With `a = torch.ones`, `b = torch.randn`, every row of the result matrix is identical, and the incorrect elements appear in the same positions in every row.
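To illustrate the accumulation guess in the first bullet, here is a small sketch that only changes where the K dimension is split and rounded to fp16 before the final sum. It is not what the Triton kernel actually does internally, only a demonstration that reordering the reduction already breaks `allclose(atol=1e-5, rtol=0)`:

```python
import torch

torch.manual_seed(0)
M, N, K = 2048, 2048, 2048
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)

# one full matmul vs. the same product assembled from fp16 partial results over K-blocks
full = torch.matmul(a, b)
parts = [torch.matmul(a[:, k:k + 256], b[k:k + 256, :]) for k in range(0, K, 256)]
blocked = torch.stack(parts).sum(dim=0)

print("max abs difference:", (full - blocked).abs().max().item())
print("allclose(atol=1e-5, rtol=0):",
      torch.allclose(full, blocked, atol=1e-5, rtol=0))
```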
Blocksparse Matmul Test
The precision issue also appears in the blocksparse matmul function. The test code is as follows, using only the forward-pass part of [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py):
```python
import torch
import triton


def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret


def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None,
              dtype=torch.float32):
    if data is None:
        data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
    ref_ret = data
    ref_ret = ref_ret * alpha + beta
    ref_ret = ref_ret.half().to(dtype)
    if trans:
        ref_ret = ref_ret.t().requires_grad_()
    ref_ret = ref_ret.detach().requires_grad_()
    tri_ret = ref_ret.clone().detach().requires_grad_()
    return ref_ret, tri_ret


def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret


def test_blocksparsematmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
    seed = 0
    torch.manual_seed(seed)
    is_sdd = MODE == "sdd"
    is_dsd = MODE == "dsd"
    is_dds = MODE == "dds"
    do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK)
    do_mask = lambda x: mask_tensor(x, layout, BLOCK)
    # create inputs / op
    a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
    b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
    c_shape = (Z, H, M, N)
    shape = {
        "sdd": (M, N),
        "dsd": (a_shape[2], a_shape[3]),
        "dds": (b_shape[2], b_shape[3]),
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
    # create data
    a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE)
    b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE)
    dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE)
    # compute [torch]
    a_ref = do_mask(a_ref) if is_dsd else a_ref
    b_ref = do_mask(b_ref) if is_dds else b_ref
    a_ref.retain_grad()
    b_ref.retain_grad()
    c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                         b_ref.transpose(2, 3) if TRANS_B else b_ref)
    c_ref = do_sparsify(c_ref) if is_sdd else c_ref
    # dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
    # c_ref.backward(dc_ref)
    # da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
    # db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
    # compute [triton]
    a_tri = do_sparsify(a_tri) if is_dsd else a_tri
    b_tri = do_sparsify(b_tri) if is_dds else b_tri
    a_tri.retain_grad()
    b_tri.retain_grad()
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
    c_tri = op(a_tri, b_tri)
    # dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
    # c_tri.backward(dc_tri)
    # da_tri = a_tri.grad
    # db_tri = b_tri.grad
    # compare
    print("--------------------------------------------------------------")
    perf = lambda ms: 2 * M * N * K * Z * H * 1e9 / (ms * 1e-3)
    total_op = 2 * M * N * K * Z * H
    print('MODE={}, Z={}, H={}, M={}, N={}, K={}, total_op={}.'
          .format(MODE, Z, H, M, N, K, total_op))
    diff = torch.sum(torch.abs(c_ref - c_tri))
    print('total diff = {:.10f}'.format(diff))
    if torch.allclose(c_ref, c_tri, atol=1e-5, rtol=0):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")
    ms, _, _ = triton.testing.do_bench(lambda: op(a_tri, b_tri), rep=20)
    print('Triton: GFLOPS: {:.3f}, time: {:.6f}ms.'.format(perf(ms), ms))
    ms_torch, _, _ = triton.testing.do_bench(
        lambda: torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                             b_ref.transpose(2, 3) if TRANS_B else b_ref),
        rep=20
    )
    print('Torch: GFLOPS: {:.3f}, time: {:.6f}ms.'.format(perf(ms_torch), ms_torch))
    return perf(ms), perf(ms_torch), diff


test_blocksparsematmul('dds', False, False, 32, torch.float16, Z=1, H=2, M=64, K=4096, N=4096)
```
This code also prints a `total diff` greater than 0.0, and `torch.allclose` returns `False`.
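As an extra check (my addition, not part of the test above), the snippet below can be pasted into `test_blocksparsematmul` right after `c_tri = op(a_tri, b_tri)`. It reuses the local names from that function and compares both results against a reference computed with the inputs upcast to float32:

```python
# Extra check, to be placed inside test_blocksparsematmul after c_tri is computed.
# Assumes the local names defined above (a_ref, b_ref, c_ref, c_tri, is_sdd,
# do_sparsify, TRANS_A, TRANS_B).
c_fp32 = torch.matmul((a_ref.transpose(2, 3) if TRANS_A else a_ref).float(),
                      (b_ref.transpose(2, 3) if TRANS_B else b_ref).float())
c_fp32 = do_sparsify(c_fp32) if is_sdd else c_fp32
print("torch  vs fp32 ref (max abs err):", (c_ref.float() - c_fp32).abs().max().item())
print("triton vs fp32 ref (max abs err):", (c_tri.float() - c_fp32).abs().max().item())
```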
Then I observed some characteristics:

- For small shapes such as `M, N, K = 256, 256, 256`, the code always passes ✅.
- I tested M over the range [1, 1024] with `N, K = 4096, 4096`, and more than half of the range prints `❌ Triton and Torch differ` (a sketch of this sweep is below).
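The M sweep from the last bullet looks roughly like the sketch below (a reconstruction: it steps by the block size, 32, only to keep the sketch short, while the actual scan covered every M in [1, 1024]):

```python
# Sweep M and record the shapes where the Triton and Torch results differ.
# Reuses test_blocksparsematmul as defined above.
fail_shapes = []
for m in range(32, 1025, 32):
    _, _, diff = test_blocksparsematmul('dds', False, False, 32, torch.float16,
                                        Z=1, H=2, M=m, K=4096, N=4096)
    if diff.item() > 0:
        fail_shapes.append(m)
print("M values where Triton and Torch differ:", fail_shapes)
```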
So what could be causing the precision mismatch, and how can I solve the problem?