Streaming model outputs #1236
Conversation
Force-pushed from aaba9d4 to fc4546a.
```diff
@@ -184,7 +185,7 @@ class MultiStepAgent(ABC):
     def __init__(
         self,
         tools: List[Tool],
-        model: Callable[[List[Dict[str, str]]], ChatMessage],
+        model: Model,
```
If we add the option to stream model outputs, `model` can no longer be just a `Callable` returning a `ChatMessage`. This means we'll have to edit the parts of the doc that show how to create a `Model`, to explain how to inherit from the base `Model` class instead of directly creating a `Callable`.
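For illustration, a minimal sketch of what the updated docs might show (class and method names follow this PR; `call_my_backend` is a hypothetical stand-in for the user's own inference code):

```python
from typing import Generator, List, Optional

from smolagents.models import ChatMessage, Model  # import path assumed


class MyCustomModel(Model):
    """Inherit from the base Model class instead of passing a bare Callable."""

    def generate(self, messages: List[dict], stop_sequences: Optional[List[str]] = None, **kwargs) -> ChatMessage:
        # Call your own backend and wrap the reply in a ChatMessage.
        text = call_my_backend(messages, stop=stop_sequences)  # hypothetical backend call
        return ChatMessage(role="assistant", content=text)

    def generate_stream(self, messages: List[dict], **kwargs) -> Generator:
        # Optional: implement only if your backend can stream partial outputs.
        for chunk in call_my_backend(messages, stream=True):  # hypothetical
            yield chunk
```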
src/smolagents/agents.py (Outdated)
```diff
@@ -340,7 +342,7 @@ def run(
     def _run(
         self, task: str, max_steps: int, images: List["PIL.Image.Image"] | None = None
-    ) -> Generator[ActionStep | AgentType, None, None]:
+    ) -> Generator[ActionStep | FinalAnswerStep, None, None]:
```
Fixing this type hint.
```diff
@@ -350,12 +352,14 @@ def _run(
     if self.planning_interval is not None and (
         self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0
     ):
-        planning_step = self._create_planning_step(
+        planning_step = self._generate_planning_step(
```
"Generate" is better IMO since an LLM output is generated in this function: it's not simply about creating an empty object.
```diff
@@ -375,9 +379,6 @@ def _run(
     yield action_step
     yield FinalAnswerStep(handle_agent_output_types(final_answer))

-    def _create_action_step(self, step_start_time: float, images: List["PIL.Image.Image"] | None) -> ActionStep:
```
This method is not useful anymore and obscures the workflow.
src/smolagents/agents.py (Outdated)
```python
except Exception as e:
    raise AgentParsingError(f"Error while generating or parsing output:\n{e}", self.logger) from e
if self.stream_outputs:
    raise NotImplementedError("Stream outputs are not yet implemented for ToolCallingAgent")
```
Streaming output with `ToolCallingAgent` implies streaming `ChoiceDeltaToolCallFunction` objects from various APIs, which is worth another PR.
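For context, a rough sketch of why this is non-trivial: OpenAI-style APIs split a single tool call across many stream deltas, and the fragments must be accumulated before they can be parsed (field names follow the OpenAI chat-completions streaming spec; everything else here is an assumption):

```python
def accumulate_tool_call_deltas(stream):
    """Merge streamed tool-call fragments back into complete calls."""
    tool_calls = {}  # index -> {"name": ..., "arguments": ...}
    for event in stream:
        delta = event.choices[0].delta
        for tc in delta.tool_calls or []:
            slot = tool_calls.setdefault(tc.index, {"name": "", "arguments": ""})
            if tc.function.name:
                slot["name"] = tc.function.name
            if tc.function.arguments:  # JSON arguments arrive in pieces
                slot["arguments"] += tc.function.arguments
    return tool_calls
```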
```diff
@@ -44,7 +44,7 @@ class MemoryStep:
     def dict(self):
         return asdict(self)

-    def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
+    def to_messages(self, summary_mode: bool = False) -> List[Message]:
```
Harmonize the API for all `to_messages` methods.
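A sketch of what the harmonized signature looks like for one step type (simplified from the PR; `Message` is replaced by a plain dict to keep the example self-contained):

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class TaskStep:
    task: str

    def to_messages(self, summary_mode: bool = False) -> List[Dict]:
        # Every memory step now exposes the same explicit flag instead of opaque **kwargs.
        return [{"role": "user", "content": [{"type": "text", "text": f"New task:\n{self.task}"}]}]
```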
```python
    response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}),
    raw=response,
)
return self.postprocess_message(first_message, tools_to_call_from)

def generate_stream(
```
New `generate_stream` methods. Once we've set up streaming for `ToolCallingAgent`, the `generate` method will simply be able to call `generate_stream` and return the final completion (see the sketch below).
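A sketch of that planned follow-up, assuming `Model` and `ChatMessage` as defined in this PR and deltas that expose a `.content` string:

```python
from smolagents.models import ChatMessage, Model  # import path assumed


class SomeApiModel(Model):  # hypothetical subclass
    def generate(self, messages, **kwargs) -> ChatMessage:
        # Once streaming is universal, generate just drains generate_stream.
        content = "".join(delta.content or "" for delta in self.generate_stream(messages, **kwargs))
        return ChatMessage(role="assistant", content=content)
```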
Thanks for the contributions! There's a lot of great work here. Having so many changes bundled into a single PR does make it a bit challenging to review thoroughly, but I appreciate the effort.
These are just some initial comments, I’ll continue reviewing the rest of the PR shortly, once you tell me no more changes are coming in...
```diff
@@ -377,7 +391,26 @@ def __call__(
         Returns:
             `ChatMessage`: A chat message object containing the model's response.
         """
-        pass  # To be implemented in child classes!
+        raise NotImplementedError("This method must be implemented in child classes")
```
What about defining `Model` as an abstract class and decorating this method as `abstractmethod`?
For `generate` we could! For `generate_stream` however, it will sometimes be implemented by child classes, sometimes not, so making it an abstract method would prevent proper initialization. Do we prefer to make only `generate` an abstract method, or keep the common implementation of only raising `NotImplementedError` in both methods? (A sketch of the first option follows.)
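The hybrid option spelled out as a sketch (not the final design): `generate` abstract, `generate_stream` concrete but raising, so subclasses without streaming can still be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Generator


class Model(ABC):
    @abstractmethod
    def generate(self, messages, **kwargs):
        """Every subclass must implement plain generation."""

    def generate_stream(self, messages, **kwargs) -> Generator:
        # Not abstract: subclasses lacking streaming still initialize fine,
        # and callers only get an error when they actually try to stream.
        raise NotImplementedError("This model does not implement generate_stream.")
```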
I think we can delete `generate_stream` here (see comment above).
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@albertvillanova it's only minor changes now, you can review.
Some comments to maintain backward compatibility: users may pass a `Callable` as `model`.
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Another batch of comments...
Sorry, it's difficult to go through more than 1,000 modified lines...
```python
def has_implemented_method(instance, parent_class, method_name: str) -> bool:
    instance_method = getattr(instance.__class__, method_name, None)
    parent_method = getattr(parent_class, method_name, None)
    return instance_method is not parent_method
```
No longer used:
```diff
-def has_implemented_method(instance, parent_class, method_name: str) -> bool:
-    instance_method = getattr(instance.__class__, method_name, None)
-    parent_method = getattr(parent_class, method_name, None)
-    return instance_method is not parent_method
```
It was a mistake to not be using it, we do need it as a check in the init: just reintroduced it!
I think this is a very hacky way to check if the model has the `generate_stream` method.
My suggestion:
- as this method is optional, the parent `Model` should not have it (see the discussion about setting `generate` as abstractmethod, but not `generate_stream`: Streaming model outputs #1236 (comment))
- we can remove this hacky method
- we can just check if the model has the `generate_stream` method: `hasattr(self.model, "generate_stream")`
```python
    **completion_kwargs, stream=True, stream_options={"include_usage": True}
):
    if event.choices:
        if event.choices[0].delta is None:
```
Have you tested this? I'm wondering if `event.choices[0].delta` can be None or if it is always a class instance.
Anyway, maybe we could add some tests for `generate_stream`.
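For illustration, a minimal sketch of such a test, mocking the `self.client.completion` call used in this PR and assuming `generate_stream` yields content deltas in order (event shapes and the delta contract are assumptions):

```python
from unittest.mock import MagicMock

from smolagents import LiteLLMModel  # import path assumed


def test_generate_stream_yields_deltas():
    model = LiteLLMModel(model_id="gpt-4")
    events = [
        MagicMock(choices=[MagicMock(delta=MagicMock(content="Hel"), finish_reason=None)], usage=None),
        MagicMock(choices=[MagicMock(delta=MagicMock(content="lo"), finish_reason="stop")], usage=None),
    ]
    model.client = MagicMock()
    model.client.completion.return_value = iter(events)

    deltas = list(model.generate_stream([{"role": "user", "content": "Hi"}]))
    assert "".join(d.content or "" for d in deltas) == "Hello"
```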
Just added tests for `generate_stream` in `LiteLLMModel`, `InferenceClientModel`, and `TransformersModel`.
I have manually checked your tests for Transformers and InferenceClient, and the condition `event.choices[0].delta is None` is never fulfilled.
Have you checked it with `LiteLLMModel`, gpt-4?
```python
def parse_tool_calls(self, message: ChatMessage) -> ChatMessage:
    """Sometimes APIs do not return the tool call as a specific object, so we need to parse it."""
    message.role = MessageRole.ASSISTANT  # Overwrite role if needed
    if not message.tool_calls:
        assert message.content is not None, "Message contains no content and no tool calls"
        message.tool_calls = [
            get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key)
        ]
    assert len(message.tool_calls) > 0, "No tool call was found in the model output"
    for tool_call in message.tool_calls:
        tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments)
    return message
```
Why do we need this for streaming when we didn't need it before? Maybe I'm missing something...
It's more of a simplification: cf. this comment.
```diff
 from vllm import LLM  # type: ignore
 from vllm.transformers_utils.tokenizer import get_tokenizer  # type: ignore

-self.model_kwargs = {
-    **(model_kwargs or {}),
-    "model": model_id,
-}
+self.model_kwargs = model_kwargs or {}
 super().__init__(**kwargs)
 self.model_id = model_id
-self.model = LLM(**self.model_kwargs)
+self.model = LLM(model=model_id, **self.model_kwargs)
 assert self.model is not None
```
I think these changes are not related to streaming. But is the assert necessary here? I mean, any model is prone to receiving a `None` as `model_id`...
```diff
 self.tokenizer = get_tokenizer(model_id)
 self._is_vlm = False  # VLLMModel does not support vision models yet.

 def cleanup(self):
     import gc

     import torch
-    from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+    from vllm.distributed.parallel_state import (  # type: ignore
```
Why this `# type: ignore`? We are not enforcing static type checking...
It will improve readability for everyone using static type checking. If you're against that we can also remove it!
src/smolagents/models.py (Outdated)
```python
for message in messages:
    if not isinstance(message["content"], str):
        message["content"] = message["content"][0]["text"]
```
Why do we need this now and not before?
It was a dirty fix for an error that I missed: just fixed it.
Another batch of reviews.
Thanks for your contribution.
src/smolagents/models.py (Outdated)
```python
for event in self.client.completion(**completion_kwargs, stream=True, stream_options={"include_usage": True}):
    if event.choices:
        if event.choices[0].delta is None:
            if not event.choices[0].finish_reason == "stop":
```
Simplify the logic:
```diff
-if not event.choices[0].finish_reason == "stop":
+if event.choices[0].finish_reason != "stop":
```
```python
yield CompletionDelta(
    content=event.choices[0].delta.content,
)
if getattr(event, "usage", None):
```
This condition can only happen if the condition above is False; is this assumption right?
```diff
-if getattr(event, "usage", None):
+elif getattr(event, "usage", None):
```
Maybe some messages contain both some content and usage, so we would need to catch both using the double `if` instead of `if`/`elif`.
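A sketch of that double-`if` version (the `token_usage` field on `CompletionDelta` is assumed for illustration):

```python
def iter_stream_deltas(stream):
    # CompletionDelta as defined in this PR is assumed to be in scope.
    for event in stream:
        if event.choices and event.choices[0].delta is not None:
            yield CompletionDelta(content=event.choices[0].delta.content)
        if getattr(event, "usage", None):  # deliberately not elif: a chunk may carry both
            yield CompletionDelta(content="", token_usage=event.usage)  # token_usage assumed
```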
```python
if tools_to_call_from:
    chat_message.tool_calls = [
        get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_key)
    ]
return chat_message
```
Why do we no longer need to set the `.tool_calls` attribute?
Because now this will be handled directly in the `ToolCallingAgent.step` method by `parse_tool_calls`!
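A sketch of the resulting control flow inside `ToolCallingAgent.step` (surrounding class and error handling elided; `write_memory_to_messages` is a name assumed from the codebase):

```python
def step(self, memory_step):
    memory_messages = self.write_memory_to_messages()  # method name assumed
    # Generation: the model may or may not return structured tool calls.
    chat_message = self.model(memory_messages, tools_to_call_from=list(self.tools.values()))
    # Parsing: fill in tool_calls from the text if the API returned none.
    chat_message = self.model.parse_tool_calls(chat_message)
    for tool_call in chat_message.tool_calls:
        ...  # execute each tool call
```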
```diff
-default_max_tokens = 5000
+default_max_tokens = 4096
```
Any reason for this change?
Powers of 2 are always better!
```
@@ -787,44 +825,51 @@ def __call__(
    or kwargs.get("max_tokens")
    or self.kwargs.get("max_new_tokens")
    or self.kwargs.get("max_tokens")
    or 1024
```
Do we want to hardcode this value?
I'm actually not sure: in case it's not filled, should we leave this to the underlying model/API?
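One possible shape of that alternative, as a sketch: resolve a limit only when the user provided one, and otherwise return `None` so the caller can omit the parameter and let the API default apply.

```python
def resolve_max_tokens(call_kwargs: dict, model_kwargs: dict):
    # Returns None when unset, so the caller can omit max_tokens
    # instead of forcing a hardcoded 1024 fallback.
    return (
        call_kwargs.get("max_new_tokens")
        or call_kwargs.get("max_tokens")
        or model_kwargs.get("max_new_tokens")
        or model_kwargs.get("max_tokens")
    )
```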
"""Sometimes APIs do not return the tool call as a specific object, so we need to parse it.""" | ||
message.role = MessageRole.ASSISTANT # Overwrite role if needed | ||
if not message.tool_calls: | ||
assert message.content is not None, "Message contains no content and no tool calls" |
Unlike before, we can now raise an error here. Is this intended?
Yes: either the model returns a tool call or it returns some text, but it should at least return one of the two.
```python
    message.tool_calls = [
        get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key)
    ]
assert len(message.tool_calls) > 0, "No tool call was found in the model output"
```
Unlike before, we can now raise an error here. Is this intended?
Yes: it will help the model correct its output!
```python
def __call__(self, *args, **kwargs):
    return self.generate(*args, **kwargs)

def parse_tool_calls(self, message: ChatMessage) -> ChatMessage:
```
This function seems to replace the previous `postprocess_message`. However, this new function is only called by `ToolCallingAgent.step`, whereas the previous `postprocess_message` was called by all API models (in the `__call__` method). Is this intended?
Yes: the idea is that we now more clearly separate (see the sketch below):
- Generation: the `Model` just generates text. Sometimes, depending on the API/model, it can contain pre-defined tool calls in the dedicated attribute.
- Parsing: `postprocess_message`, which, if there's no tool call so far, fills the `tool_calls` attribute using tool calls parsed from the text.
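In practice the two phases chain like this (a sketch; method and argument names follow the code shown elsewhere in this PR):

```python
def run_tool_step(model, messages, tools):
    # Phase 1, generation: the Model just generates; some APIs already set tool_calls.
    chat_message = model.generate(messages, tools_to_call_from=tools)
    # Phase 2, parsing: derive tool_calls from the text when the API returned none.
    return model.parse_tool_calls(chat_message)
```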
Another batch of reviews done before your modifications from today.
```python
prompt_templates: Optional[PromptTemplates] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
executor_type: str | None = "local",
executor_kwargs: Optional[Dict[str, Any]] = None,
max_print_outputs_length: Optional[int] = None,
stream_outputs: bool = False,
```
What about calling the param just `stream`, as in the OpenAI spec for Chat Completion create?
This is a difficult question: for a chat completion, `stream` is obviously about streaming model outputs. For an agent, what do you stream: agent steps (as in `agent.run()` with `stream=True`)? Since here it's about streaming outputs, I put that in the name `stream_outputs`. But maybe there's an even more intuitive API.
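The two levels side by side, using this PR's naming (model class and task are placeholders):

```python
from smolagents import CodeAgent, InferenceClientModel  # import path assumed

model = InferenceClientModel()  # any model implementing generate_stream
agent = CodeAgent(tools=[], model=model, stream_outputs=True)  # stream model *tokens* live

for step in agent.run("What is 2 + 2?", stream=True):  # stream agent *steps*
    print(step)
```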
```
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
executor_type (`str`, default `"local"`): Which executor type to use between `"local"`, `"e2b"`, or `"docker"`.
executor_kwargs (`dict`, *optional*): Additional arguments to pass to initialize the executor.
max_print_outputs_length (`int`, *optional*): Maximum length of the print outputs.
stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
```
In docstrings, *optional* means default `None`.
```diff
-stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution.
+stream_outputs (`bool`, default `False`): Whether to stream outputs during execution.
```
```python
def generate_stream(self, *args, **kwargs) -> Generator[CompletionDelta, None, None]:
    raise NotImplementedError("This method must be implemented in child classes")
```
Suggested change:
```diff
-def generate_stream(self, *args, **kwargs) -> Generator[CompletionDelta, None, None]:
-    raise NotImplementedError("This method must be implemented in child classes")
```
```python
self.stream_outputs = stream_outputs
can_stream = has_implemented_method(self.model, Model, "generate_stream")
if self.stream_outputs and not can_stream:
    raise ValueError(
        "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
    )
```
Suggested change:
```diff
-self.stream_outputs = stream_outputs
-can_stream = has_implemented_method(self.model, Model, "generate_stream")
-if self.stream_outputs and not can_stream:
-    raise ValueError(
-        "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
-    )
+if stream_outputs and not hasattr(self.model, "generate_stream"):
+    raise ValueError(
+        "`stream_outputs` is set to True, but the model class implements no `generate_stream` method."
+    )
+self.stream_outputs = stream_outputs
```
Implement streaming model outputs, to let users see their model's thoughts displayed live.
Tested for:
Streaming was not implemented, left for future PRs:
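A minimal usage sketch of the new capability at the model level (import path and delta attributes assumed):

```python
from smolagents import InferenceClientModel  # import path assumed

model = InferenceClientModel()
for delta in model.generate_stream([{"role": "user", "content": "Hello!"}]):
    print(delta.content or "", end="", flush=True)
```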