Conversation

@weiyangfb (Contributor) commented Nov 7, 2018

```python
>>> a = torch.randn(1, 2, 3, 4)
>>> a_flip = flip(a, [1, 3])
>>> a_flip.shape
torch.Size([2, 4, 1, 3])  # need to permute([2, 0, 3, 1])

# ===== original flip() impl only works for this case =====
>>> a = torch.randn(1, 2, 3, 4)
>>> a_flip = flip(a, [0, 2])
>>> a_flip.shape
torch.Size([1, 3, 2, 4])  # need to permute([0, 2, 1, 3])

>>> a = torch.randn(1, 2, 3, 4)
>>> a_flip = flip(a, [0, 3])
>>> a_flip.shape
torch.Size([1, 4, 2, 3])  # need to permute([0, 2, 3, 1])
```
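The permute orders noted in the comments above can be derived mechanically from the advanced-indexing rule: the indexed (flipped) dims come first in the result, followed by the remaining dims in order. A minimal Python sketch of that derivation (`restore_order` is an illustrative helper, not code from this PR):

```python
def restore_order(ndim, flip_dims):
    """Sketch: advanced indexing lays the result out as
    [flipped dims, ascending] + [remaining dims, ascending].
    Return the permute order that maps this back to 0..ndim-1."""
    flip_dims = sorted(flip_dims)
    rest = [d for d in range(ndim) if d not in flip_dims]
    result_order = flip_dims + rest  # dim layout after indexing
    # invert: permute_order[i] = position of original dim i in result_order
    return [result_order.index(d) for d in range(ndim)]

print(restore_order(4, [1, 3]))  # [2, 0, 3, 1]
print(restore_order(4, [0, 2]))  # [0, 2, 1, 3]
print(restore_order(4, [0, 3]))  # [0, 2, 3, 1]
```

These match the three permute orders in the examples above.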
- performance:

```python
>>> a = torch.randn(1000, 1000)
>>> %timeit -r 10 a.flip(0, 1)
15.2 ms ± 2.57 ms per loop (mean ± std. dev. of 10 runs, 100 loops each)
```

cc @fmassa @ssnl

```diff
@@ -30,6 +30,10 @@ Tensor flip_cpu(const Tensor& self, IntList dims) {

// check if the distance between any two flip dims is >= 2, in which case a permute
// of the output tensor is needed, because advanced indexing puts all non-adjacent
// indexed dims at the front of the result tensor
auto out_tensor = self.index(TensorList(final_indices));
auto out_sizes_idx = std::vector<int64_t>(out_tensor.dim());
```
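The indexing behavior described in that comment can be reproduced directly in NumPy, which follows the same advanced-indexing rule as ATen here: adjacent indexed dims keep their position, non-adjacent ones move to the front. A small illustration (NumPy used purely for demonstration):

```python
import numpy as np

a = np.zeros((1, 2, 3, 4))

# Adjacent indexed dims (1 and 2): the broadcast index dims stay in place.
i1 = np.array([1, 0])[:, None]  # shape (2, 1)
i2 = np.arange(3)[None, :]      # shape (1, 3); broadcasts with i1 to (2, 3)
adj = a[:, i1, i2, :]
print(adj.shape)                # (1, 2, 3, 4)

# Non-adjacent indexed dims (0 and 2): the index dims move to the front.
i0 = np.array([0])[:, None]     # shape (1, 1)
nonadj = a[i0, :, i2, :]
print(nonadj.shape)             # (1, 3, 2, 4)
```

The non-adjacent case yields exactly the `torch.Size([1, 3, 2, 4])` seen in the `flip(a, [0, 2])` example above.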


```diff
@@ -49,11 +53,10 @@ Tensor flip_cpu(const Tensor& self, IntList dims) {
    permute_order.emplace_back(i);
  }
}
auto out_tensor = self.index(TensorList(final_indices));
return out_tensor.permute(IntList(permute_order));
std::sort(out_sizes_idx.begin(), out_sizes_idx.end(),
```
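The sort in that hunk relies on a standard fact: sorting index positions by the dim each one points at yields the inverse of a permutation (the "sorting permute index" idea from the PR title). A hedged Python sketch of that idea (`inverse_permutation` is an illustrative name, not the PR's code):

```python
def inverse_permutation(order):
    """Sort positions 0..n-1 by the dim each points at; the result is
    the permutation that undoes `order` (i.e., an argsort of `order`)."""
    return sorted(range(len(order)), key=lambda i: order[i])

# After flipping dims [1, 3] of a (1, 2, 3, 4) tensor, advanced indexing
# lays the dims out as [1, 3, 0, 2]; the inverse restores the original order.
print(inverse_permutation([1, 3, 0, 2]))  # [2, 0, 3, 1]
```

This reproduces the `permute([2, 0, 3, 1])` noted in the first shape example.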


@weiyangfb weiyangfb changed the title fix flip() shape bug by sorting permute index [DONT MERGE] demonstrate cause of shape bug at flip() Nov 7, 2018
@weiyangfb weiyangfb force-pushed the flip_fix_permute_bug branch from a58ba40 to 22c1342 on November 7, 2018 23:16
@weiyangfb weiyangfb closed this Nov 8, 2018