Collaborator

@hongxiayang hongxiayang commented Jul 17, 2024

…with large index

Fixes #130806
When an output size of 2147483648 (=131072*16384) is expected, as in the above issue, the call threw the following error:
RuntimeError: HIP error: invalid configuration argument

What happened was that the second parameter passed to hipLaunchKernel was the invalid value {2147483648,1,1}.
Found two issues in Indexing.cu:

1: ptrdiff_t was used, but it is a signed integer type, and outTotalSize >= 2147483648 can cause overflow in this computation: https://github.com/pytorch/pytorch/blame/39493aa93419532957e6e5ee97cae842b53b8b59/aten/src/ATen/native/cuda/Indexing.cu#L1367
2: On ROCm, std::min -> ::min did not work as expected when outTotalSize>=2147483648

As a result, 2147483648 was sent to hipLaunchKernel as the number of threads per block, which the GPU does not support. The original code intended to use 128 threads per block. That value is debatable, since performance would not be good on the latest powerful GPUs (perhaps a TODO item for a later perf update), but at least it would not cause the invalid configuration argument error.
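
The thread does not spell out the exact mechanism, so the following is only a sketch of one way a block dimension of 2147483648 could arise. It assumes the min that gets selected operates on 32-bit signed int and that the result is stored in an unsigned 32-bit block dimension. The helper names are hypothetical, and Python's ctypes is used here only to model C integer conversions.

```python
import ctypes

MAX_THREADS = 128  # threads per block the original code intended to use


def narrow_to_int32(value: int) -> int:
    """Model converting a 64-bit value to a 32-bit signed int (two's-complement wrap)."""
    return ctypes.c_int32(value).value


def block_dim_via_int_min(out_total_size: int) -> int:
    """Assumed failure path: min() evaluated on 32-bit signed ints,
    with the result stored in an unsigned 32-bit block dimension."""
    clamped = min(narrow_to_int32(out_total_size), MAX_THREADS)
    return ctypes.c_uint32(clamped).value


def block_dim_via_64bit_compare(out_total_size: int) -> int:
    """Comparison done at full width, as with `<` on 64-bit operands."""
    return out_total_size if out_total_size < MAX_THREADS else MAX_THREADS


if __name__ == "__main__":
    out_total_size = 131072 * 16384  # the size from the issue
    print(out_total_size)
    print(block_dim_via_int_min(out_total_size))
    print(block_dim_via_64bit_compare(out_total_size))
```

Under these assumptions, the narrowed path reproduces the {2147483648,1,1} value reported above, while the full-width comparison yields the intended 128.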

[Test]
Ran the same code snippet from the issue and printed the output, its dim, and numel(). The result now looks like this:

output=tensor([[ 0.4044, -0.0244, -0.6865,  ..., -0.7800,  0.1175,  1.6726],
        [-1.0866, -0.1609,  0.3538,  ...,  1.9105,  0.7882,  1.1583],
        [-2.2079,  0.3736,  0.3610,  ..., -0.2658, -0.0459,  1.3077],
        ...,
        [ 0.8753, -0.7482, -0.1978,  ...,  0.9016,  1.1501, -0.5178],
        [-1.5845, -0.6277,  1.4520,  ...,  0.5733, -2.1198, -0.0915],
        [-0.6310, -1.0239, -0.1910,  ...,  0.4309,  0.1630,  0.3239]],
       device='cuda:0'), dim=2, numel=2147483648

Added a large tensor unit test too.

/pytorch# pytest test/nn/test_embedding.py -k test_large_tensors
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.9.19, pytest-7.3.2, pluggy-1.4.0
rootdir: /dockerx/development/pytorch
configfile: pytest.ini
plugins: flakefinder-1.1.0, rerunfailures-14.0, xdist-3.3.1, xdoctest-1.1.0, cpp-2.3.0, hypothesis-5.35.1
collected 288 items / 287 deselected / 1 selected                                                                                                                                        
Running 1 items in this shard

test/nn/test_embedding.py .                                                                                                                                                        [100%]

=========================================================================== 1 passed, 287 deselected in 3.16s ============================================================================


pytorch-bot bot commented Jul 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130994

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 42d6034 with merge base fdd0a7f:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Jul 17, 2024
Collaborator

@jeffdaily jeffdaily left a comment

Admittedly, I don't yet understand what went wrong and why this fixes it. But I do see some use of long and some of int64_t, so perhaps settle on using int64_t?

@hongxiayang
Collaborator Author

> Admittedly, I don't yet understand what went wrong and why this fixes it. But I do see some use of long and some of int64_t, so perhaps settle on using int64_t?

I updated the PR description to explain what went wrong. I hope that helps.

@hongxiayang hongxiayang marked this pull request as ready for review July 19, 2024 03:20
@hongxiayang hongxiayang requested a review from eqy as a code owner July 19, 2024 03:20
@hongxiayang
Collaborator Author

There is more opportunity to refactor this code, but I will leave that for future work since this fix is urgent.

@hongxiayang hongxiayang requested a review from jeffdaily July 19, 2024 03:28
@facebook-github-bot
Contributor

@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Collaborator

@eqy eqy left a comment

Is it possible to add a test for this case (perhaps decorated with @largeTensorTest, if necessary)?

@hongxiayang hongxiayang requested a review from eqy July 19, 2024 17:09
@hongxiayang
Collaborator Author

hongxiayang commented Jul 19, 2024

Added unit test.

Collaborator

@jeffdaily jeffdaily left a comment

Approve if CI signal is good.

@hongxiayang
Collaborator Author

> Approve if CI signal is good.

The only failing check is that the Meta-internal diff is out of sync with the external PR, because I added a unit test (as suggested by @eqy) after @xw285cornell imported it. @xw285cornell, please import again to sync them. Thanks.

@facebook-github-bot
Contributor

@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@xw285cornell
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 20, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@xw285cornell
Contributor

@hongxiayang @jeffdaily thanks for the fix!

Just wondering about "2: On ROCm, std::min -> ::min did not work as expected when outTotalSize>=2147483648": should we fix this in ROCm?

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / win-vs2019-cpu-py3 / test (default, 1, 3, windows.4xlarge.nonephemeral)


@facebook-github-bot
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.


int64_t selfReduceDimSize = self_.size(dim);
ptrdiff_t numIndex = index.numel();
int64_t selfNumel = self_.numel();
uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
Thanks @hongxiayang for the fix!
In the description you mentioned issues 1 and 2, and the diff appears to mainly fix 1 by using uint64_t, which has a larger representable range. I am wondering whether we should actually fix 2 (On ROCm, std::min -> ::min did not work as expected) instead. uint64_t has a much larger range, but in the extreme, if LLM context lengths continue to grow, could sizes eventually exceed what uint64_t can represent?

Collaborator Author

This pull request fixed 1 by using uint64_t and worked around 2 by using "<" instead of the "min" function. The complete fix should be made in the lower-level library, so that it supports 64-bit integers in addition to "int". I am following up on 2 now to find out why that support was not available for PyTorch to use.
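
As a rough sense of scale for the range question above, the following sketch compares the failing size against the 32-bit signed and 64-bit unsigned limits. The helper names are hypothetical, and the calculation only concerns representable element counts, not what any particular kernel or device supports.

```python
UINT64_MAX = 2**64 - 1  # largest value a uint64_t can hold
INT32_MAX = 2**31 - 1   # largest value a 32-bit signed int can hold


def fits_in(value: int, max_value: int) -> bool:
    """True if a non-negative value is representable up to max_value."""
    return 0 <= value <= max_value


def headroom(value: int, max_value: int) -> int:
    """How many times `value` fits into the representable range (integer division)."""
    return max_value // value


if __name__ == "__main__":
    out_total_size = 131072 * 16384  # the size from the issue
    print(fits_in(out_total_size, INT32_MAX))
    print(fits_in(out_total_size, UINT64_MAX))
    print(headroom(out_total_size, UINT64_MAX))
```

The failing size is exactly one past the 32-bit signed maximum, while uint64_t leaves a factor of roughly 8.6 billion before the same kind of limit would be reached.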

DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
pytorch#130994)

…with large index

Fixes pytorch#130806
When an output size of 2147483648 (=131072*16384) is expected, as in the above issue, the call threw the following error:
RuntimeError: HIP error: invalid configuration argument

What happened was that the second parameter passed to hipLaunchKernel was the invalid value {2147483648,1,1}.
Found two issues in Indexing.cu:

1: ptrdiff_t was used, but it is a signed integer type, and outTotalSize >= 2147483648 can cause overflow when doing [this](https://github.com/pytorch/pytorch/blame/39493aa93419532957e6e5ee97cae842b53b8b59/aten/src/ATen/native/cuda/Indexing.cu#L1367):
2: On ROCm, std::min -> ::min did not work as expected when outTotalSize>=2147483648

As a result, 2147483648 was sent to hipLaunchKernel as the number of threads per block, which the GPU does not support. The original code intended to use 128 threads per block. That value is debatable, since performance would not be good on the latest powerful GPUs (perhaps a TODO item for a later perf update), but at least it would not cause the `invalid configuration argument` error.

[Test]
Ran the same code snippet from the [issue](pytorch#130806) and printed the output, its dim, and numel(). The result now looks like this:
```
output=tensor([[ 0.4044, -0.0244, -0.6865,  ..., -0.7800,  0.1175,  1.6726],
        [-1.0866, -0.1609,  0.3538,  ...,  1.9105,  0.7882,  1.1583],
        [-2.2079,  0.3736,  0.3610,  ..., -0.2658, -0.0459,  1.3077],
        ...,
        [ 0.8753, -0.7482, -0.1978,  ...,  0.9016,  1.1501, -0.5178],
        [-1.5845, -0.6277,  1.4520,  ...,  0.5733, -2.1198, -0.0915],
        [-0.6310, -1.0239, -0.1910,  ...,  0.4309,  0.1630,  0.3239]],
       device='cuda:0'), dim=2, numel=2147483648
```

Added a large tensor unit test too.
```
/pytorch# pytest test/nn/test_embedding.py -k test_large_tensors
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.9.19, pytest-7.3.2, pluggy-1.4.0
rootdir: /dockerx/development/pytorch
configfile: pytest.ini
plugins: flakefinder-1.1.0, rerunfailures-14.0, xdist-3.3.1, xdoctest-1.1.0, cpp-2.3.0, hypothesis-5.35.1
collected 288 items / 287 deselected / 1 selected
Running 1 items in this shard

test/nn/test_embedding.py .                                                                                                                                                        [100%]

=========================================================================== 1 passed, 287 deselected in 3.16s ============================================================================
```
Pull Request resolved: pytorch#130994
Approved by: https://github.com/jeffdaily, https://github.com/xw285cornell
@hongxiayang
Collaborator Author

> @hongxiayang @jeffdaily thanks for the fix!
>
> Just wondering, "2: On ROCm, std::min -> ::min did not work as expected when outTotalSize>=2147483648", shall we fix this in rocm?

Will follow up on this. Thanks.

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
pytorch#130994)

pytorchmergebot pushed a commit that referenced this pull request Jul 31, 2024
As observed while working on this fix (#130994), 128 threads per block seems quite low. This PR increases the default to improve performance, and also slightly refactors the code to replace the hard-coded 128 for easier maintenance.

By increasing the default max threads per block from 128 to 256, I saw the "CUDA total" time of `aten::index_select` drop from 44.820ms to 33.608ms when profiling the embedding script below:
```
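import torch
from torch import profiler  # assumed to be torch.profiler; the original omits imports

# 131072 indices into a 16032 x 16384 table -> 2147483648 output elements
input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
w = torch.randn([16032, 16384], device="cuda")

with profiler.profile(record_shapes=True) as prof:
    x = torch.nn.functional.embedding(input, w)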
```
I tested raising the default from 128 to 256, 512, and 1024 on several different types of devices and observed the "CUDA total" time dropping, with latency improving further as the number increases. Below is one example of the latency improvement ratio:

Max threads per block | Latency improvement vs. 128
--- | ---
128 | 1x
256 | 1.33x
512 | 1.44x
1024 | 1.49x

To be conservative, this PR uses 512 as the new default max for non-MI300X devices, which is 1.44x faster than 128 with the profiling script above.

On MI300X, using 1024 is 1.61x faster than 128 with the same profiling script, and using 512 is 1.57x faster.
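
The reported ratio for 256 threads can be checked against the two timings quoted above, and the example ratios also show how much each further doubling adds. This is only a small arithmetic sketch over the numbers in this message; the function names are hypothetical.

```python
def speedup(baseline_ms: float, new_ms: float) -> float:
    """Latency improvement ratio: how many times faster the new timing is."""
    return baseline_ms / new_ms


# Example ratios quoted in the message, keyed by max threads per block.
ratios = {128: 1.00, 256: 1.33, 512: 1.44, 1024: 1.49}


def incremental_gains(table: dict) -> dict:
    """Gain of each setting relative to the next smaller one."""
    keys = sorted(table)
    return {b: table[b] / table[a] for a, b in zip(keys, keys[1:])}


if __name__ == "__main__":
    print(f"{speedup(44.820, 33.608):.2f}x")  # timings for 128 vs. 256 threads
    for threads, gain in incremental_gains(ratios).items():
        print(f"{threads}: {gain:.3f}x over the previous setting")
```

On these example numbers, each doubling helps less than the one before, so most of the measured gain comes from the first step from 128 to 256.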

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Pull Request resolved: #131713
Approved by: https://github.com/jeffdaily, https://github.com/syed-ahmed, https://github.com/malfet
@pruthvistony pruthvistony added this to the 2.4.1 milestone Aug 13, 2024
@hongxiayang
Collaborator Author

@pytorchbot cherry-pick --onto release/2.4 -c critical

pytorchbot pushed a commit that referenced this pull request Aug 13, 2024
#130994)


(cherry picked from commit 637ab85)
@pytorchbot
Collaborator

Cherry picking #130994

The cherry pick PR is at #133346 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated:


atalman pushed a commit that referenced this pull request Aug 14, 2024
#133346)

fix for launching kernel invalid config error when calling embedding … (#130994)


(cherry picked from commit 637ab85)

Co-authored-by: hongxyan <hongxyan@amd.com>
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Aug 15, 2024
pytorch#133346)

fix for launching kernel invalid config error when calling embedding … (pytorch#130994)


(cherry picked from commit 637ab85)

Co-authored-by: hongxyan <hongxyan@amd.com>
pytorchmergebot pushed a commit that referenced this pull request Aug 20, 2025
Currently, std::min -> ::min does not work as expected on ROCm when input values are >= 2147483648.

Replace `std::min` with a ternary expression.
Alternatively, `std::min` can be given an explicit type: `std::min<int64_t>`.

Fixes this test on ROCm:
test_sort_and_select.py::TestSortAndSelectCUDA::test_sort_large_cuda_float16
which failed with:
RuntimeError: Cannot sort dimension of length 8192

A similar PR that fixed large tensors on ROCm: #130994

Pull Request resolved: #161054
Approved by: https://github.com/jeffdaily
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: cuda release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[rocm] F.embedding reports invalid configuration argument
9 participants