Streaming model outputs #1236


Merged: 37 commits merged into main on Apr 24, 2025
Conversation

@aymeric-roucher (Collaborator) commented Apr 22, 2025:

Implement streaming model outputs, to let users see their model's thoughts displayed live.

Tested for:

  • OpenAI
  • InferenceProviders
  • LiteLLM
  • TransformersModel

Streaming was not implemented for these models and is left for future PRs:

  • VLLMModel
  • MLXModel
  • AzureOpenAIServerModel
  • AmazonBedrockServerModel
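
For illustration, a minimal usage sketch of the new flag (hypothetical: the model id and task are placeholders; stream_outputs is the constructor argument added in this PR):

# Hypothetical usage sketch of the stream_outputs flag added in this PR.
# The model id and task are placeholders; adjust to your own setup.
from smolagents import CodeAgent, InferenceClientModel

model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")  # example model id
agent = CodeAgent(tools=[], model=model, stream_outputs=True)

# With stream_outputs=True, the model's thoughts are displayed live as they are generated.
agent.run("What is the 10th Fibonacci number?")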

@aymeric-roucher force-pushed the streaming-model-outputs branch from aaba9d4 to fc4546a on April 22, 2025 at 14:13
@@ -184,7 +185,7 @@ class MultiStepAgent(ABC):
def __init__(
self,
tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage],
model: Model,
@aymeric-roucher (Collaborator, author) commented Apr 22, 2025:

If we add the option to stream model outputs, model cannot just be a Callable returning a ChatMessage.
This means we'll have to edit the parts of the doc that show how to create a Model, to explain how to inherit from the base Model class instead of directly creating a Callable.
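
For instance, the documented pattern could look roughly like this (a sketch only: the import path, the ChatMessage constructor, and the generate signature are assumed from this PR's diff, and my_backend_completion is a hypothetical stand-in):

# Sketch of a custom model inheriting from the base Model class instead of
# being a bare Callable; names and signatures are assumptions, not the final docs.
from smolagents.models import ChatMessage, MessageRole, Model


def my_backend_completion(messages) -> str:
    # Hypothetical stand-in for a call to your own inference backend.
    return "final_answer(42)"


class MyCustomModel(Model):
    def generate(self, messages, stop_sequences=None, **kwargs) -> ChatMessage:
        # Wrap the raw text from the backend into the ChatMessage the agent expects.
        text = my_backend_completion(messages)
        return ChatMessage(role=MessageRole.ASSISTANT, content=text)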

@@ -340,7 +342,7 @@ def run(

def _run(
self, task: str, max_steps: int, images: List["PIL.Image.Image"] | None = None
) -> Generator[ActionStep | AgentType, None, None]:
) -> Generator[ActionStep | FinalAnswerStep, None, None]:
@aymeric-roucher (Collaborator, author):

Fixing this type hint.

@@ -350,12 +352,14 @@ def _run(
if self.planning_interval is not None and (
self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0
):
planning_step = self._create_planning_step(
planning_step = self._generate_planning_step(
@aymeric-roucher (Collaborator, author):

"Generate" is better IMO since an LLM output is generated in this function: it's not simply about creating an empty object.

@@ -375,9 +379,6 @@ def _run(
yield action_step
yield FinalAnswerStep(handle_agent_output_types(final_answer))

def _create_action_step(self, step_start_time: float, images: List["PIL.Image.Image"] | None) -> ActionStep:
@aymeric-roucher (Collaborator, author):

This method is not useful anymore and obscures the workflow

except Exception as e:
raise AgentParsingError(f"Error while generating or parsing output:\n{e}", self.logger) from e
if self.stream_outputs:
raise NotImplementedError("Stream outputs are not yet implemented for ToolCallingAgent")
@aymeric-roucher (Collaborator, author):

Streaming output with ToolCallingAgent implies streaming ChoiceDeltaToolCallFunction objects from various APIs, which is worth another PR.

@@ -44,7 +44,7 @@ class MemoryStep:
def dict(self):
return asdict(self)

def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
def to_messages(self, summary_mode: bool = False) -> List[Message]:
@aymeric-roucher (Collaborator, author):

Harmonize the API for all to_messages methods

response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
raw=response,
)
return self.postprocess_message(first_message, tools_to_call_from)

def generate_stream(
@aymeric-roucher (Collaborator, author):

New generate_stream methods. Once we've set up streaming for ToolCallingAgent, the generate method will simply be able to call generate_stream and return the final completion.
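
For illustration, that future simplification could look roughly like this (a sketch assuming a subclass implements generate_stream and that CompletionDelta exposes a content field, as in this PR's diff):

# Hypothetical sketch: once streaming also works for ToolCallingAgent, generate()
# could simply consume generate_stream() and assemble the final completion.
from smolagents.models import ChatMessage, MessageRole, Model


class StreamingBackedModel(Model):
    def generate(self, messages, **kwargs) -> ChatMessage:
        # Accumulate the streamed text deltas into one final ChatMessage.
        content = ""
        for delta in self.generate_stream(messages, **kwargs):
            if delta.content:
                content += delta.content
        return ChatMessage(role=MessageRole.ASSISTANT, content=content)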

@aymeric-roucher marked this pull request as ready for review on April 22, 2025 at 18:23
@albertvillanova (Member) left a comment:

Thanks for the contributions! There's a lot of great work here. Having so many changes bundled into a single PR does make it a bit challenging to review thoroughly, but I appreciate the effort.

These are just some initial comments; I'll continue reviewing the rest of the PR shortly, once you tell me no more changes are coming in...

@@ -377,7 +391,26 @@ def __call__(
Returns:
`ChatMessage`: A chat message object containing the model's response.
"""
pass # To be implemented in child classes!
raise NotImplementedError("This method must be implemented in child classes")
@albertvillanova (Member):

What about defining Model as an abstract class and decorating this method as abstractmethod?

@aymeric-roucher (Collaborator, author):

For generate we could! For generate_stream, however, it will sometimes be implemented by child classes, sometimes not, so making it an abstract method would prevent proper initialization. Do we prefer to make only generate an abstract method, or keep the common implementation that only raises NotImplementedError in both methods?
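
For reference, a minimal sketch of the variant under discussion, assuming only generate is made abstract (the return types are quoted forward references so the stub stays self-contained):

# Sketch of the design being discussed: Model as an ABC where generate is
# abstract, while generate_stream keeps a default NotImplementedError so that
# child classes without streaming support can still be instantiated.
from abc import ABC, abstractmethod
from typing import Generator


class Model(ABC):
    @abstractmethod
    def generate(self, messages, **kwargs) -> "ChatMessage":
        ...

    def generate_stream(self, messages, **kwargs) -> Generator["CompletionDelta", None, None]:
        # Optional capability: only some child classes implement streaming.
        raise NotImplementedError("Streaming is not implemented for this model class")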

@albertvillanova (Member):

I think we can delete generate_stream here (see comment above).

aymeric-roucher and others added 2 commits April 23, 2025 11:41
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@aymeric-roucher (Collaborator, author):

@albertvillanova it's only minor changes now, you can review

@albertvillanova (Member) left a comment:

Some comments to maintain backward compatibility: users may pass a Callable as model.

aymeric-roucher and others added 5 commits April 23, 2025 15:11
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
@albertvillanova (Member) left a comment:

Another batch of comments...
Sorry, it's difficult to go through more than 1,000 modified lines...

Comment on lines +457 to +462


def has_implemented_method(instance, parent_class, method_name: str) -> bool:
instance_method = getattr(instance.__class__, method_name, None)
parent_method = getattr(parent_class, method_name, None)
return instance_method is not parent_method
@albertvillanova (Member):

No longer used:

Suggested change
def has_implemented_method(instance, parent_class, method_name: str) -> bool:
instance_method = getattr(instance.__class__, method_name, None)
parent_method = getattr(parent_class, method_name, None)
return instance_method is not parent_method

@aymeric-roucher (Collaborator, author):

It was a mistake not to be using it: we do need it as a check in the init. Just reintroduced it!

@albertvillanova (Member):

I think this is a very hacky way to check if the model has the generate_stream method.

My suggestion:

  • as this method is optional, the parent Model should not have it (see discussion about setting generate as abstractmethod, but not generate_stream: Streaming model outputs #1236 (comment)).
  • we can remove this hacky method
  • we can just check whether the model has a generate_stream method: hasattr(self.model, "generate_stream")

**completion_kwargs, stream=True, stream_options={"include_usage": True}
):
if event.choices:
if event.choices[0].delta is None:
@albertvillanova (Member):

Have you tested this? I'm wondering whether event.choices[0].delta can be None or is always a class instance.

Anyway, maybe we could add some tests for generate_stream.

@aymeric-roucher (Collaborator, author):

Just added tests for generate_stream in LiteLLMModel, InferenceClientModel, and TransformersModel.
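
For illustration, a test in that spirit might look roughly like this (hypothetical: the model id, message format, and assertions are placeholders and may differ from the tests actually added in the PR):

# Hypothetical sketch of a generate_stream test; not the exact test added in the PR.
from smolagents import InferenceClientModel


def test_generate_stream_yields_text_deltas():
    model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")  # example model id
    messages = [{"role": "user", "content": [{"type": "text", "text": "Say hello."}]}]

    chunks = [delta.content for delta in model.generate_stream(messages) if delta.content]

    # Streaming should produce at least one non-empty text delta that joins into a string.
    assert len(chunks) > 0
    assert isinstance("".join(chunks), str)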

@albertvillanova (Member):

I have manually checked your tests for Transformers and InferenceClient and the condition event.choices[0].delta is None is never fulfilled.

@aymeric-roucher (Collaborator, author):

Have you checked it with LiteLLMModel, gpt-4?

Comment on lines +402 to +413
def parse_tool_calls(self, message: ChatMessage) -> ChatMessage:
"""Sometimes APIs do not return the tool call as a specific object, so we need to parse it."""
message.role = MessageRole.ASSISTANT # Overwrite role if needed
if not message.tool_calls:
assert message.content is not None, "Message contains no content and no tool calls"
message.tool_calls = [
get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key)
]
assert len(message.tool_calls) > 0, "No tool call was found in the model output"
for tool_call in message.tool_calls:
tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
return message
@albertvillanova (Member):

Why do we need this for streaming when we didn't need it before? Maybe I'm missing something...

@aymeric-roucher (Collaborator, author):

It's more of a simplification: cf. this comment.

Comment on lines +483 to +490
from vllm import LLM # type: ignore
from vllm.transformers_utils.tokenizer import get_tokenizer # type: ignore

self.model_kwargs = {
**(model_kwargs or {}),
"model": model_id,
}
self.model_kwargs = model_kwargs or {}
super().__init__(**kwargs)
self.model_id = model_id
self.model = LLM(**self.model_kwargs)
self.model = LLM(model=model_id, **self.model_kwargs)
assert self.model is not None
@albertvillanova (Member):

I think these changes are not related to streaming. But is the assert necessary here? I mean, any model is prone to receiving None as model_id...

self.tokenizer = get_tokenizer(model_id)
self._is_vlm = False # VLLMModel does not support vision models yet.

def cleanup(self):
import gc

import torch
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
from vllm.distributed.parallel_state import ( # type: ignore
@albertvillanova (Member):

Why this # type: ignore? We are not enforcing static type checking...

@aymeric-roucher (Collaborator, author) commented Apr 23, 2025:

It will improve readability for everyone using static type checking. If you're against that, we can also remove it!

Comment on lines 530 to 532
for message in messages:
if not isinstance(message["content"], str):
message["content"] = message["content"][0]["text"]
@albertvillanova (Member):

Why do we need this now and not before?

@aymeric-roucher (Collaborator, author) commented Apr 23, 2025:

It was a dirty fix for an error that I missed: just fixed it.

@albertvillanova (Member) left a comment:

Another batch of reviews.

Thanks for your contribution.

for event in self.client.completion(**completion_kwargs, stream=True, stream_options={"include_usage": True}):
if event.choices:
if event.choices[0].delta is None:
if not event.choices[0].finish_reason == "stop":
@albertvillanova (Member):

Simplify the logic:

Suggested change
if not event.choices[0].finish_reason == "stop":
if event.choices[0].finish_reason != "stop":

yield CompletionDelta(
content=event.choices[0].delta.content,
)
if getattr(event, "usage", None):
@albertvillanova (Member):

This condition can only happen if the condition above is False; is this assumption right?

Suggested change
if getattr(event, "usage", None):
elif getattr(event, "usage", None):

@aymeric-roucher (Collaborator, author) commented Apr 23, 2025:

Maybe some messages contain both some content and usage, so we would need to catch both using the double if instead of if/elif.
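
As a toy illustration of why two separate if branches matter here (SimpleNamespace objects stand in for real stream events; field names are illustrative):

# Toy example: a single streamed event carrying both a content delta and usage
# info would lose one of the two under an if/elif chain, hence the double `if`.
from types import SimpleNamespace

event = SimpleNamespace(
    choices=[SimpleNamespace(delta=SimpleNamespace(content="Hello"), finish_reason=None)],
    usage=SimpleNamespace(prompt_tokens=3, completion_tokens=1),
)

collected, usage_seen = [], False
if event.choices and event.choices[0].delta.content:
    collected.append(event.choices[0].delta.content)
if getattr(event, "usage", None):  # second `if`, not `elif`, so both branches can fire
    usage_seen = True

assert collected == ["Hello"] and usage_seen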

Comment on lines -533 to -537
if tools_to_call_from:
chat_message.tool_calls = [
get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_key)
]
return chat_message
@albertvillanova (Member):

Why do we no longer need to set .tool_calls attribute?

@aymeric-roucher (Collaborator, author):

Because now this will be handled directly in the ToolCallingAgent.step method by parse_tool_calls!


default_max_tokens = 5000
default_max_tokens = 4096
@albertvillanova (Member):

Any reason for this change?

@aymeric-roucher (Collaborator, author):

Powers of 2 are always better!

@@ -787,44 +825,51 @@ def __call__(
or kwargs.get("max_tokens")
or self.kwargs.get("max_new_tokens")
or self.kwargs.get("max_tokens")
or 1024
@albertvillanova (Member):

Do we want to hardcode this value?

@aymeric-roucher (Collaborator, author):

I'm actually not sure: in case it's not filled, should we leave this to the underlying model/API?

"""Sometimes APIs do not return the tool call as a specific object, so we need to parse it."""
message.role = MessageRole.ASSISTANT # Overwrite role if needed
if not message.tool_calls:
assert message.content is not None, "Message contains no content and no tool calls"
@albertvillanova (Member):

Unlike before, we can now raise an error here. Is this intended?

@aymeric-roucher (Collaborator, author):

Yes: either the model returns a tool call or it returns some text, but it should return at least one.

message.tool_calls = [
get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key)
]
assert len(message.tool_calls) > 0, "No tool call was found in the model output"
@albertvillanova (Member):

Unlike before, we can now raise an error here. Is this intended?

@aymeric-roucher (Collaborator, author):

Yes: it will help the model correct its output!

def __call__(self, *args, **kwargs):
return self.generate(*args, **kwargs)

def parse_tool_calls(self, message: ChatMessage) -> ChatMessage:
@albertvillanova (Member):

This function seems to replace the previous postprocess_message. However, this new function is only called by ToolCallingAgent.step, whereas the previous postprocess_message was called by all API models (__call__ method). Is this intended?

@aymeric-roucher (Collaborator, author):

Yes: the idea is that we now more clearly separate:

  1. Generation: the Model just generates text. Sometimes, depending on the API/Model, it can contain pre-defined tool_calls in the dedicated attribute.
  2. Parsing: parse_tool_calls (replacing postprocess_message), which, if there is no tool call so far, fills the tool_calls attribute using tool calls parsed from the text (a toy sketch of this two-phase flow follows below).
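
A toy sketch of that two-phase flow (simplified stand-ins, not the actual smolagents classes or call sites):

# Toy illustration of the generation/parsing split described above.
import json
from dataclasses import dataclass, field


@dataclass
class ToyMessage:
    content: str
    tool_calls: list = field(default_factory=list)


def generate(prompt: str) -> ToyMessage:
    # 1. Generation: the model only produces text (or API-provided tool calls).
    return ToyMessage(content='{"name": "final_answer", "arguments": {"answer": "42"}}')


def parse_tool_calls(message: ToyMessage) -> ToyMessage:
    # 2. Parsing: if no tool call object was returned, recover one from the text.
    if not message.tool_calls:
        message.tool_calls = [json.loads(message.content)]
    return message


message = parse_tool_calls(generate("What is 6 * 7?"))
assert message.tool_calls[0]["name"] == "final_answer"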

@aymeric-roucher merged commit 113d8c9 into main on Apr 24, 2025
5 checks passed
@albertvillanova (Member) left a comment:

Another batch of reviews, done before today's modifications.

prompt_templates: Optional[PromptTemplates] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
executor_type: str | None = "local",
executor_kwargs: Optional[Dict[str, Any]] = None,
max_print_outputs_length: Optional[int] = None,
stream_outputs: bool = False,
@albertvillanova (Member):

What about calling the param just stream, as in the OpenAI spec for Chat completion create?

@aymeric-roucher (Collaborator, author):

This is a difficult question: for a chat completion, stream is obviously about streaming model outputs.
For an agent, what do you stream: agent steps? (as in agent.run() with stream=True)
Since here it's about streaming outputs, I put that in the name stream_outputs, but maybe there's an even more intuitive API.

prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
executor_type (`str`, default `"local"`): Which executor type to use between `"local"`, `"e2b"`, or `"docker"`.
executor_kwargs (`dict`, *optional*): Additional arguments to pass to initialize the executor.
max_print_outputs_length (`int`, *optional*): Maximum length of the print outputs.
stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
@albertvillanova (Member):

In docstrings, optional means default None.

Suggested change
stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
stream_outputs (`bool`, default `False`): Whether to stream outputs during execution.


Comment on lines +396 to +398
def generate_stream(self, *args, **kwargs) -> Generator[CompletionDelta, None, None]:
raise NotImplementedError("This method must be implemented in child classes")

@albertvillanova (Member):

Suggested change
def generate_stream(self, *args, **kwargs) -> Generator[CompletionDelta, None, None]:
raise NotImplementedError("This method must be implemented in child classes")

Comment on lines +1229 to +1234
self.stream_outputs = stream_outputs
can_stream = has_implemented_method(self.model, Model, "generate_stream")
if self.stream_outputs and not can_stream:
raise ValueError(
"`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
)
@albertvillanova (Member):

Suggested change
self.stream_outputs = stream_outputs
can_stream = has_implemented_method(self.model, Model, "generate_stream")
if self.stream_outputs and not can_stream:
raise ValueError(
"`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
)
if stream_outputs and not hasattr(self.model, "generate_stream"):
raise ValueError(
"`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
)
self.stream_outputs = stream_outputs
