
Conversation

ArthurZucker
Contributor

@ArthurZucker ArthurZucker commented Dec 19, 2024

Adds support for transformers as a backend

Following huggingface/transformers#35235, a number of models should already be supported, and we are ramping up support for more.

Thanks @Isotr0py for the TP support, and @hmellor for his help as well!
This includes:

  • trust_remote_code=True support: any model on the Hub that implements attention the correct way can be natively supported!
  • tensor parallel support
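
For a concrete sense of what this enables, here is a minimal usage sketch. Note the hedge: the model_impl="transformers" knob reflects how current vLLM exposes this fallback and may not exist in this PR's initial version, and the model name is just an example.

```python
# Illustrative only: explicitly selecting the Transformers backend in vLLM.
# `model_impl="transformers"` may not be exposed in the first version of this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model
    model_impl="transformers",                 # use the Transformers modeling code
    trust_remote_code=True,                    # allow custom models from the Hub
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```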

ArthurZucker and others added 2 commits December 19, 2024 10:33
Co-authored-by: Isotr0py <41363108+Isotr0py@users.noreply.github.com>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Dec 19, 2024
@ywang96
Member

ywang96 commented Dec 19, 2024

Hello @ArthurZucker! This is very exciting!

I know this PR is still a draft, but could you provide some context on the scope of this effort? Is it to support any model on transformers?

@Isotr0py Isotr0py mentioned this pull request Dec 19, 2024
@ArthurZucker
Contributor Author

Yep, overall this should support any model that is supported in transformers where the cache is "simple", so for now that means most decoder models and encoder models for a single modality!
For multimodal models we might need a little bit of extra work, but I think LLaVA models should work out of the box!

We are refactoring our models to make sure this is propagated to as many models as possible!

@ArthurZucker
Contributor Author

Might not have time to finish this week, will make it ready for next week 🎄
This should be minimal (no support for LoRA, or at least I am not testing it! This might need to either call transformers' from_pretrained, or replace modules similarly to TP; see the sketch below)
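
To illustrate the "replace modules" idea mentioned above, here is a rough, hypothetical sketch (not the PR's actual code): walk the Transformers model and swap nn.Linear layers for tensor-parallel equivalents supplied by the caller.

```python
# Hypothetical sketch of module replacement for tensor parallelism; the real
# integration uses vLLM's own parallel linear layers and its column/row split rules.
import torch.nn as nn

def replace_linear_for_tp(module: nn.Module, make_parallel_linear) -> None:
    """Recursively swap nn.Linear children for a parallel variant.

    `make_parallel_linear(in_features, out_features, bias)` is a caller-supplied
    factory, e.g. wrapping vLLM's ColumnParallelLinear or RowParallelLinear.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name,
                    make_parallel_linear(child.in_features,
                                         child.out_features,
                                         child.bias is not None))
        else:
            replace_linear_for_tp(child, make_parallel_linear)
```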

@simon-mo simon-mo mentioned this pull request Jan 9, 2025
hmellor added 13 commits January 9, 2025 11:35
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
…orted

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@hmellor
Member

hmellor commented Jan 16, 2025

Benchmarks on A100 using the following command:

python benchmarks/benchmark_throughput.py --backend vllm --model meta-llama/Llama-3.1-8B-Instruct --dataset ShareGPT_V3_unfiltered_cleaned_split.json

Results:

Class             | Result
LlamaForCausalLM  | Throughput: 12.88 requests/s, 5325.05 total tokens/s, 2554.02 output tokens/s
TransformersModel | Throughput: 11.38 requests/s, 4705.90 total tokens/s, 2257.06 output tokens/s

hmellor and others added 2 commits January 16, 2025 19:08
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) February 1, 2025 06:22
Signed-off-by: Isotr0py <2037008807@qq.com>
@DarkLight1337
Member

Please fix the failing tests

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
@DarkLight1337
Member

Please also add the distributed transformers test to the distributed tests CI

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py and others added 2 commits February 2, 2025 10:50
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
@Isotr0py Isotr0py enabled auto-merge (squash) February 2, 2025 02:55
@youkaichao youkaichao disabled auto-merge February 3, 2025 13:30
@youkaichao youkaichao merged commit a1a2aaa into vllm-project:main Feb 3, 2025
69 of 71 checks passed
@DarkLight1337 DarkLight1337 mentioned this pull request Feb 3, 2025
@hmellor hmellor deleted the fix-history branch February 5, 2025 10:34
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now

@ArthurZucker May I know why we want to remove the batch dim here?

Member

@hmellor hmellor Aug 4, 2025


vLLM doesn't expect a batch dimension in the hidden states returned by the model's forward() method
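
A minimal sketch of the shape convention being discussed (illustrative shapes only): Transformers returns hidden states as [batch, num_tokens, hidden_size], while vLLM expects [num_tokens, hidden_size], hence the [0, ...] indexing.

```python
import torch

# Transformers-style output: batch dimension of 1 over the flattened token sequence.
hidden_states = torch.randn(1, 16, 4096)    # [batch=1, num_tokens, hidden_size]

# What vLLM expects from forward(): no batch dimension.
vllm_hidden_states = hidden_states[0, ...]  # [num_tokens, hidden_size]
assert vllm_hidden_states.shape == (16, 4096)
```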


@ArthurZucker Just curious, do you have plans to support models with MLA, such as DeepSeek? It seems this integration only supports normal attention right now.

Contributor Author


Yes, happy to work on any and all model compatibility!

@mergify mergify bot added the new-model label Aug 4, 2025
Labels
ci/build, documentation, new-model, ready