Add fast path for torch.{mm, addmm, bmm, baddbmm} when the inputs are conjugate transpose #59380
Conversation
💊 Dr. CI, as of commit aea9e83: 💚 Looks good so far! There are no failures yet.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Differential Revision: [D28898374](https://our.internmc.facebook.com/intern/diff/D28898374)
```diff
@@ -1031,13 +1031,22 @@ static void addmm_impl_cpu_(
   const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0];
   const int64_t ldc = c.strides()[transpose_c ? 0 : 1];

+  if (a.is_conj() && !transpose_a) {
+    a.conj_physical_();
```
This potentially modifies an input tensor's values.
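To see why this is a problem, here is a minimal Python sketch of the same hazard (the tensor names are illustrative, not from the PR): `conj()` returns a lazy view that shares storage with its input, so an in-place `conj_physical_()` on that view rewrites the caller's data.

```python
import torch

x = torch.tensor([1 + 2j, 3 - 4j])
a = x.conj()         # lazy conjugate view: only sets the conj bit, shares storage with x
a.conj_physical_()   # in-place: writes conjugated values through the shared storage
print(x)             # tensor([1.-2.j, 3.+4.j]) -- the caller's tensor was silently modified
```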
```diff
@@ -1411,17 +1411,18 @@ Tensor from_file(c10::string_view filename, c10::optional<bool> shared, c10::opt
 Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
   auto memory_format =
       optional_memory_format.value_or(MemoryFormat::Preserve);
+  Tensor self;
```
fixes #64176
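For context on the `clone` path touched here, a small sketch of the `MemoryFormat::Preserve` default visible in the hunk (shapes are arbitrary for illustration): for a dense, non-overlapping source, `clone` keeps the source's strides instead of forcing a contiguous result.

```python
import torch

src = torch.randn(2, 3, 4).transpose(0, 1)  # dense but non-contiguous view, shape (3, 2, 4)
out = src.clone()                           # memory_format defaults to Preserve

print(src.stride())                         # (4, 12, 1)
print(out.stride())                         # (4, 12, 1): strides preserved
assert out.data_ptr() != src.data_ptr()     # but the clone owns fresh storage
```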
test/test_linalg.py (outdated)
```python
def _slice(tensor, fn):
    return fn(tensor)[:, ::2]
```

Suggested change:

```python
def _slice(tensor, fn):
    return fn(tensor)[..., ::2]
```
This can be used for batched tensors as well.
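A quick illustration of the reviewer's point (shapes made up for the example): `[:, ::2]` always strides dim 1, whereas `[..., ::2]` strides the last dim no matter how many batch dims come first.

```python
import torch

t2 = torch.randn(4, 6)     # unbatched
t3 = torch.randn(2, 4, 6)  # batched

print(t2[:, ::2].shape)    # torch.Size([4, 3])    -- last dim halved, as intended
print(t3[:, ::2].shape)    # torch.Size([2, 2, 6]) -- strides dim 1, not the last dim
print(t3[..., ::2].shape)  # torch.Size([2, 4, 3]) -- last dim halved for any rank
```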
test/test_linalg.py (outdated)
```python
Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).transpose(-1, 1)
```
nit: `transpose(-1, -2)` or `transpose(1, 2)`; this mixing of negative and positive indices looks confusing (same in the `t_b` function).
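For reference, a tiny sketch (shape chosen arbitrarily) confirming that all three spellings select the same pair of dims on a 3-D tensor, so using all-negative or all-positive indices is purely a readability improvement:

```python
import torch

t = torch.empty(2, 3, 4)
# On a 3-D tensor, dim -1 is dim 2 and dim -2 is dim 1, so these are equivalent:
assert t.transpose(-1, 1).shape == t.transpose(-1, -2).shape == t.transpose(1, 2).shape  # (2, 4, 3)
```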
Stack from ghstack:

Benchmarks for addmm: #59380 (comment)

Disabled fastpath for bmm and baddbmm due to this bug: #64103

30% decrease in peak memory usage when `conj_physical` is used in the code sample below as well as in the backward formulas:

```python
import torch

x = torch.randn(2, 3, dtype=torch.cdouble, device='cuda', requires_grad=True).conj()
y = torch.randn(3, 2, dtype=torch.cdouble, device='cuda', requires_grad=True).transpose(0, 1).conj()
z = torch.randn(3, 3, dtype=torch.cdouble, device='cuda', requires_grad=True).transpose(0, 1).conj()

out = torch.addmm(x, y, z)
out.sum().backward()
print(torch.cuda.memory_summary())
```

Differential Revision: [D28898374](https://our.internmc.facebook.com/intern/diff/D28898374)

fixes #64176
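As a rough, hypothetical way to observe the fast path from Python (this is not the benchmark script behind the numbers above; the size and timing method are arbitrary): compare `mm` on a lazy conjugate-transpose view against the same product with the conjugation materialized first via `resolve_conj()`.

```python
import time
import torch

n = 2048
A = torch.randn(n, n, dtype=torch.cfloat)
B = torch.randn(n, n, dtype=torch.cfloat)
A_H = A.transpose(0, 1).conj()  # lazy conj-transpose view: no copy of A is made

t0 = time.perf_counter()
out_fast = torch.mm(A_H, B)                # with the fast path, BLAS gets op(A) = 'C' directly
t1 = time.perf_counter()
out_ref = torch.mm(A_H.resolve_conj(), B)  # materializes the conjugate first (extra copy)
t2 = time.perf_counter()

assert torch.allclose(out_fast, out_ref)
print(f"lazy view: {t1 - t0:.4f}s   materialized first: {t2 - t1:.4f}s")
```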