Conversation

@anjali411 anjali411 (Contributor) commented Jun 3, 2021

Stack from ghstack:

Benchmarks for addmm: #59380 (comment)

Disabled the fast path for bmm and baddbmm due to this bug: #64103

Using `conj_physical` in the code sample below, as well as in the backward formulas, gives a 30% decrease in peak memory usage:

```
import torch

# x is a lazy conjugate view; y and z are transposed conjugate views.
# No conjugated copy is materialized up front.
x = torch.randn(2, 3, dtype=torch.cdouble, device='cuda', requires_grad=True).conj()
y = torch.randn(3, 2, dtype=torch.cdouble, device='cuda', requires_grad=True).transpose(0, 1).conj()
z = torch.randn(3, 3, dtype=torch.cdouble, device='cuda', requires_grad=True).transpose(0, 1).conj()

out = torch.addmm(x, y, z)
out.sum().backward()

print(torch.cuda.memory_summary())
```
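For context, a minimal sketch (my illustration, not code from this PR) of the property the fast path exploits: `conj()` on a complex tensor returns a lazy view with a conjugate bit set, and only `conj_physical()` materializes the conjugated values:

```
import torch

t = torch.randn(2, 2, dtype=torch.cdouble)
v = t.conj()            # lazy view: shares storage with t
print(v.is_conj())      # True; no conjugated copy exists yet
m = v.conj_physical()   # materializes the conjugated values
print(m.is_conj())      # False; m owns its own conjugated buffer
```

A BLAS-backed matmul can consume the conjugate bit directly as a conjugate(-transpose) operand instead of materializing a copy, which is where the memory savings come from.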

Differential Revision: [D28898374](https://our.internmc.facebook.com/intern/diff/D28898374)

fixes #64176

This was referenced Jun 3, 2021
@facebook-github-bot facebook-github-bot (Contributor) commented Jun 3, 2021

💊 CI failures summary and remediations

As of commit aea9e83 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



anjali411 added a commit that referenced this pull request Jun 3, 2021
ghstack-source-id: e33a6f0
Pull Request resolved: #59380
anjali411 added a commit that referenced this pull request Jun 3, 2021
ghstack-source-id: ade5a04
Pull Request resolved: #59380
anjali411 added a commit that referenced this pull request Jun 3, 2021
ghstack-source-id: 1816c2f
Pull Request resolved: #59380
anjali411 added a commit that referenced this pull request Jun 4, 2021
ghstack-source-id: 3d23b7a
Pull Request resolved: #59380

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

anjali411 added a commit that referenced this pull request Jun 4, 2021
ghstack-source-id: 2b88040
Pull Request resolved: #59380
anjali411 added a commit that referenced this pull request Jun 7, 2021
ghstack-source-id: 96e6e68
Pull Request resolved: #59380
anjali411 added a commit that referenced this pull request Jun 23, 2021
ghstack-source-id: 982fb51
Pull Request resolved: #59380
anjali411 added a commit that referenced this pull request Jul 8, 2021
ghstack-source-id: 854b7a3
Pull Request resolved: #59380
@@ -1031,13 +1031,22 @@ static void addmm_impl_cpu_(
const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0];
const int64_t ldc = c.strides()[transpose_c ? 0 : 1];

if (a.is_conj() && !transpose_a) {
a.conj_physical_();
A Collaborator commented:
This potentially modifies an input tensor's values.
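A minimal illustration of that concern (hypothetical repro, not from this PR): an in-place `conj_physical_()` through a tensor that shares storage with the caller's tensor changes the caller's values too.

```
import torch

x = torch.randn(3, dtype=torch.cdouble)
before = x.clone()
alias = x.view(3)        # shares storage with x
alias.conj_physical_()   # physically conjugates the shared buffer
print(torch.equal(x, before))  # False: x was mutated through the alias
```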

anjali411 added a commit that referenced this pull request Jul 9, 2021
ghstack-source-id: c6f0af7
Pull Request resolved: #59380
… inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 30, 2021
…addbmm} when the inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
… inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 30, 2021
…addbmm} when the inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
… inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 30, 2021
…addbmm} when the inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 30, 2021
ghstack-source-id: 96627d3
Pull Request resolved: #59380
@@ -1411,17 +1411,18 @@ Tensor from_file(c10::string_view filename, c10::optional<bool> shared, c10::opt
Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
auto memory_format =
optional_memory_format.value_or(MemoryFormat::Preserve);
Tensor self;
The Contributor Author commented:

fixes #64176
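A quick sanity check for the linked issue (my sketch; the exact repro is in #64176): cloning a lazily conjugated view should yield the same logical values as physically conjugating.

```
import torch

x = torch.randn(2, 2, dtype=torch.cdouble)
c = x.conj().clone()
# torch.equal compares logical values, so this should hold however
# the conjugate bit is resolved internally by clone().
assert torch.equal(c, x.conj_physical())
```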


@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Comment on lines 6170 to 6171
def _slice(tensor, fn):
    return fn(tensor)[:, ::2]
A Collaborator commented:

Suggested change:

```
-def _slice(tensor, fn):
-    return fn(tensor)[:, ::2]
+def _slice(tensor, fn):
+    return fn(tensor)[..., ::2]
```

This can then be used for batched tensors as well.
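To make that concrete, a small sketch (my example; the shapes are illustrative) showing that `[..., ::2]` slices the last dimension for both 2-D and batched 3-D inputs, whereas `[:, ::2]` would hit the second dimension of the batched tensor instead:

```
import torch

def _slice(tensor, fn):
    return fn(tensor)[..., ::2]

t2 = torch.randn(4, 6)
t3 = torch.randn(2, 4, 6)              # batched
print(_slice(t2, lambda t: t).shape)   # torch.Size([4, 3])
print(_slice(t3, lambda t: t).shape)   # torch.Size([2, 4, 3])
```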

Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).transpose(-1, 1)
A Collaborator commented:
nit: use transpose(-1, -2) or transpose(1, 2); mixing negative and positive indices looks confusing (same in the t_b function).
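For reference, a quick check (my sketch) that all three spellings coincide for a 3-D tensor, so the suggestion is purely about readability:

```
import torch

t = torch.randn(2, 3, 4)
a = t.transpose(-1, 1)    # the confusing mixed-sign spelling
b = t.transpose(1, 2)
c = t.transpose(-1, -2)
assert a.shape == b.shape == c.shape == torch.Size([2, 4, 3])
assert torch.equal(a, b) and torch.equal(b, c)
```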

… inputs are conjugate transpose" (commit message repeats the PR description above, now with "fixes #64176") [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 31, 2021
…addbmm} when the inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
… inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 31, 2021
…addbmm} when the inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
… inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 31, 2021
…addbmm} when the inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 31, 2021
ghstack-source-id: 8b83796
Pull Request resolved: #59380
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

… inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]
anjali411 added a commit that referenced this pull request Sep 1, 2021
ghstack-source-id: 39d1c5f
Pull Request resolved: #59380
anjali411 added a commit that referenced this pull request Sep 1, 2021
…addbmm} when the inputs are conjugate transpose" (commit message repeats the PR description above) [ghstack-poisoned]

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Labels: cla signed; module: complex (related to complex number support in PyTorch)
6 participants