
Conversation

@weiyangfb (Contributor) commented Oct 30, 2018

- with this change:
>>> t = torch.randn(1, 3, 4, 5)
>>> t.flip(1, 3).shape
torch.Size([1, 3, 4, 5])
- performance:
====== with this PR ======
>>> a = torch.randn(1000, 1000)
>>> %timeit -r 100 a.flip(0, 1)
1.98 ms ± 579 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)

====== Perf at previous PR #7873 ======
100 loops, best of 3: 11 ms per loop

@weiyangfb changed the title from "[wip] fix flip() shape bug in CPU" to "fix flip() shape bug in CPU" on Oct 31, 2018
@ssnl (Collaborator) left a comment

So what is the original cause of the bug? Also, could you benchmark on larger tensors? OMP doesn't kick in until numel reaches 1000.
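For example, a quick check on a much larger input could look like this in IPython; the sizes below are only illustrative and are not taken from this PR:

```python
import torch

# Roughly 1.6e7 elements, well past a small parallelization threshold.
a = torch.randn(4000, 4000)

%timeit -r 100 a.flip(0, 1)
```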


dim_list_to_bitset(dims, total_dims);

const int64_t numel = in_tensor.numel();
auto strides = in_tensor.strides();
auto strides_v = strides.vec();
auto strides_t = at::CPU(kLong).tensorFromBlob(strides_v.data(), {static_cast<int64_t>(strides_v.size())});

@weiyangfb (Contributor, Author) commented:

@ssnl Thanks! The case I showed has numel = 10^6:

====== with this PR ======
>>> a = torch.randn(1000, 1000)
>>> %timeit -r 100 a.flip(0, 1)
1.98 ms ± 579 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)

@weiyangfb (Contributor, Author) commented:

The previous bug was caused by a misuse of advanced indexing that I still can't figure out. Since a customized kernel is faster, I just use a kernel instead.
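For context, a minimal sketch of what flipping via advanced indexing looks like; `flip_via_indexing` is a hypothetical helper for illustration only, not the code that was removed:

```python
import torch

def flip_via_indexing(t, dims):
    # Build one index per dimension; reverse the range for flipped dims,
    # then gather the elements with advanced indexing.
    idx = [torch.arange(s - 1, -1, -1) if d in dims else torch.arange(s)
           for d, s in enumerate(t.shape)]
    grid = torch.meshgrid(*idx)  # broadcast per-dim indices to t.shape ('ij' order)
    return t[grid]

t = torch.randn(1, 3, 4, 5)
out = flip_via_indexing(t, {1, 3})
print(out.shape)                       # torch.Size([1, 3, 4, 5])
print(torch.equal(out, t.flip(1, 3)))  # True
```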

@ssnl (Collaborator) commented Oct 31, 2018

@weiyangfb Ah you are right about the benchmark. Sorry about it!

@ssnl (Collaborator) commented Oct 31, 2018

Maybe worth looking into the advanced indexing issue further in the future. We may have a bug there.


dim_list_to_bitset(dims, total_dims); // returned bitset is not used; here we only check the correctness of dims

maybe_wrap_dims(flip_dims_v, total_dims);

auto sizes = in_tensor.sizes();
auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims_v.data(), {static_cast<int64_t>(flip_dims_v.size())});

@weiyangfb (Contributor, Author) commented Oct 31, 2018

> Maybe worth looking into the advanced indexing issue further in the future. We may have a bug there.

Yeah, definitely! Can I land this PR to fix the bug on the user side first? We can keep the issue open until I figure out the root cause.

@ssnl (Collaborator) left a comment

Yeah, that doesn't need to be done in this PR :)

Tensor out_tensor = at::empty_like(in_tensor);

// create contiguous strides for input tensor
Tensor stride_contiguous = at::zeros({total_dims}, kLong);

void inline flip_cpu_kernel(
const int64_t total_dims,
const int64_t* stride_contiguous_d,
const std::bitset<dim_bitset_size>& flip_dims_b,

int64_t temp = cur_indices;
cur_indices = cur_indices / stride_contiguous_d[d];
rem = temp - cur_indices * stride_contiguous_d[d];
if (flip_dims_b[d]) cur_indices = in_tensor.size(d) - 1 - cur_indices;
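For readers following the fragment above, here is a rough Python sketch of the same index arithmetic, assuming a contiguous input; the function name and the contiguous-only simplification are illustrative, not the exact kernel:

```python
def flipped_source_index(i, sizes, flip_dims):
    # Contiguous strides: stride[d] = product of sizes[d+1:].
    strides = [1] * len(sizes)
    for d in range(len(sizes) - 2, -1, -1):
        strides[d] = strides[d + 1] * sizes[d + 1]

    src, cur = 0, i
    for d, (size, stride) in enumerate(zip(sizes, strides)):
        idx, cur = divmod(cur, stride)   # index along dimension d, plus remainder
        if d in flip_dims:
            idx = size - 1 - idx         # mirror the flipped dimensions
        src += idx * stride              # rebuild the linear source offset
    return src

# e.g. out.view(-1)[i] = in.view(-1)[flipped_source_index(i, list(in.shape), {0, 1})]
```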

@weiyangfb (Contributor, Author) commented:

@ssnl I guess I found the root cause of the bug: #13682. I still prefer this PR since it is faster.

@weiyangfb (Contributor, Author) commented:

Can I get a stamp on this? cc @ssnl

@ssnl (Collaborator) commented Nov 7, 2018

wait... so which one should I look at?

@weiyangfb (Contributor, Author) commented:

@ssnl Sorry about the confusion. You should look at this one. I just changed the title of the other one: #13682

@facebook-github-bot (Contributor) left a comment

@weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ssnl (Collaborator) left a comment

Sorry for the late review.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Nov 8, 2018
Summary:
- a workaround for #13292; a complete fix requires investigating the root cause of the advanced indexing issue
- this PR brings the `flip()` CUDA implementation over to the CPU kernel
- with this change:
```
>>> t = torch.randn(1, 3, 4, 5)
>>> t.flip(1, 3).shape
torch.Size([1, 3, 4, 5])
```
- performance:
```
====== with this PR ======
>>> a = torch.randn(1000, 1000)
>>> %timeit -r 100 a.flip(0, 1)
1.98 ms ± 579 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)

====== Perf at previous PR #7873 ======
100 loops, best of 3: 11 ms per loop
```
Pull Request resolved: pytorch/pytorch#13344

Differential Revision: D12968003

Pulled By: weiyangfb

fbshipit-source-id: 66f434049d143a0575a35b5c983b3e0577a1a28d
@ezyang added the merged label Jun 25, 2019

Successfully merging this pull request may close these issues.

torch.flip incorrect behavior