Closed
🚀 The feature, motivation and pitch
Right now, if you try to use torch.triu on a bfloat16 CUDA tensor (I hit it while training a simple network with AMP), you'll get an error that triu is not supported:
In [8]: torch.__version__
Out[8]: '2.0.1+cu117'
In [9]: torch.arange(4).reshape(2,2).bfloat16().cuda().triu()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[9], line 1
----> 1 torch.arange(4).reshape(2,2).bfloat16().cuda().triu()
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
It would be nice to have bfloat16 support for this op.
Alternatives
Can be replaced with multiplication by a precomputed torch.ones(n, n).triu() mask:
self.register_buffer("triu0", torch.ones(n, n).triu()) # __init__
...
y = x * self.triu0 # forward()
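Putting the two snippets above together, a minimal sketch of the workaround as a module might look like this (the class name `TriuMask` and the fixed size `n` are assumptions for illustration; the mask is built in float32, which the CUDA kernel does support, and cast to the input's dtype in `forward`):

```python
import torch


class TriuMask(torch.nn.Module):
    """Workaround sketch: emulate torch.triu for dtypes the CUDA kernel
    rejects (e.g. bfloat16) by multiplying with a precomputed 0/1 mask."""

    def __init__(self, n: int):
        super().__init__()
        # Built in float32, where triu() is implemented everywhere;
        # register_buffer makes it follow the module's .cuda()/.to() moves.
        self.register_buffer("triu0", torch.ones(n, n).triu())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast the mask to x's dtype so the result stays in bfloat16
        # (0/1 values are exact in every floating dtype).
        return x * self.triu0.to(x.dtype)
```

Note that this materializes the full n×n mask and does a multiply instead of a masked copy, so it costs extra memory and FLOPs compared with a native triu kernel, but it keeps AMP/bfloat16 training running.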
Additional context
No response