hongxiayang
Collaborator

@hongxiayang hongxiayang commented Jul 24, 2024

As observed while working on this fix (#130994), 128 threads per block seems quite low. This PR increases the default to improve performance, and also lightly refactors the code to replace the hard-coded 128 for easier maintenance.

After increasing the default max threads per block from 128 to 256, the "CUDA total" time for aten::index_select dropped from 44.820ms to 33.608ms when profiling the embedding script below:

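import torch
from torch import profiler

input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
w = torch.randn([16032, 16384], device="cuda")

with profiler.profile(record_shapes=True) as prof:
    x = torch.nn.functional.embedding(input, w)

print(prof.key_averages().table())  # per-operator summary of the profiled run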

I tested raising the default from 128 to 256, 512, and 1024 on several different types of devices, and observed the "CUDA total" time dropping further, with larger latency improvements, as the number increases. Below is one example of the latency improvement ratio:
Max threads per block | Latency improvement vs. 128
128 | 1x
256 | 1.33x
512 | 1.44x
1024 | 1.49x
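
For reference, the 1.33x figure for 256 is consistent with the two "CUDA total" timings quoted earlier in this description. A minimal sketch of that arithmetic follows; the helper name `improvement_ratio` and the constant names are hypothetical and only illustrate how such a ratio can be derived from two profiler readings.

```python
# Timings quoted in the description (aten::index_select, "CUDA total").
BASELINE_MS = 44.820   # max threads per block = 128
CANDIDATE_MS = 33.608  # max threads per block = 256


def improvement_ratio(baseline_ms: float, candidate_ms: float) -> float:
    """Return how many times faster the candidate is than the baseline."""
    if candidate_ms <= 0:
        raise ValueError("candidate_ms must be positive")
    return baseline_ms / candidate_ms


ratio_256 = improvement_ratio(BASELINE_MS, CANDIDATE_MS)
print(f"{ratio_256:.2f}x")  # prints "1.33x"
```

The same helper can be applied to the timings for 512 and 1024 once they are read from the profiler output; those raw timings are not listed in this description, so they are not reproduced here.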

To be conservative, 512 is used as the new default max for non-mi300x devices, which is 1.44x faster than using 128 with the above profiling script.

On mi300x, using 1024 is 1.61x faster than using 128 with the same profiling script, and using 512 is 1.57x faster.
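
The defaults described above can be summarized as a small selection rule. The sketch below is only an illustration in Python, not the code changed by this PR: the function name `default_max_threads_per_block`, its parameters, and the use of an architecture string starting with `gfx942` to identify mi300x are all assumptions. Keeping 128 on the non-ROCm path is also an assumption.

```python
def default_max_threads_per_block(is_rocm: bool, gcn_arch: str = "") -> int:
    """Pick a default max threads per block (hypothetical helper).

    is_rocm:  True when running on a ROCm build (assumed input).
    gcn_arch: architecture string of the device (assumed input); mi300x is
              assumed here to report a name beginning with "gfx942".
    """
    if not is_rocm:
        # Assumption: the non-ROCm default is left at the previous value.
        return 128
    if gcn_arch.startswith("gfx942"):
        # mi300x: 1024 was reported as 1.61x faster than 128.
        return 1024
    # Other ROCm devices: the conservative choice from the description.
    return 512


print(default_max_threads_per_block(True, "gfx942"))  # prints 1024
print(default_max_threads_per_block(True, "gfx90a"))  # prints 512
```

Routing every launch site through one helper like this, rather than repeating a literal, mirrors the stated goal of replacing the hard-coded 128.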

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo

pytorch-bot bot commented Jul 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131713

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 348162f with merge base d355678:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/rocm (Trigger "default" config CI on ROCm), module: rocm (AMD GPU support for Pytorch), and release notes: cuda (release notes category) labels Jul 24, 2024
Collaborator

@jeffdaily jeffdaily left a comment

One nit change and one necessary change requested, but otherwise LGTM if CI is green.

@jeffdaily jeffdaily changed the title from "[ROCm] performance optimization by increasing the default max threads…" to "[ROCm] performance optimization for index select" Jul 25, 2024
@hongxiayang hongxiayang marked this pull request as ready for review July 26, 2024 01:48
@hongxiayang hongxiayang requested a review from eqy as a code owner July 26, 2024 01:48
@hongxiayang hongxiayang requested a review from jeffdaily July 26, 2024 01:48
@hongxiayang
Collaborator Author

cc @xw285cornell

@hongxiayang
Collaborator Author

hongxiayang commented Jul 26, 2024

@eqy Please let me know whether you are comfortable with me changing the default to 256 for the Nvidia case. We can change all occurrences of 128 to 256 as a follow-up.

@hongxiayang hongxiayang requested a review from syed-ahmed as a code owner July 30, 2024 17:51
Collaborator

@syed-ahmed syed-ahmed left a comment

Requested a small change. Otherwise, LGTM.

@hongxiayang hongxiayang requested a review from syed-ahmed July 30, 2024 19:05
@hongxiayang
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jul 31, 2024
@pytorchmergebot
Collaborator

Merge failed

Reason: Approvers from one of the following sets are needed:

  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet)
Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@jithunnair-amd
Collaborator

@malfet Can you please review and approve this PR?

@malfet
Contributor

malfet commented Jul 31, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

Labels
ciflow/rocm (Trigger "default" config CI on ROCm), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: rocm (AMD GPU support for Pytorch), open source, release notes: cuda (release notes category)
8 participants