
Reduce the number of times load_weights is called in GhostTrainer's advance function #4931

@nolan-dev

Description


Is your feature request related to a problem? Please describe.
When I train policies that have a large number of weights using self-play, a lot of time is spent in the load_weights function called from here: https://github.com/Unity-Technologies/ml-agents/blob/master/ml-agents/mlagents/trainers/ghost/trainer.py#L277
I'm still familiarizing myself with the code, but intuitively it doesn't seem necessary to load weights every time advance is called, which appears to be what happens now.

Describe the solution you'd like
Here's the code where load_weights is called frequently:

            try:
                policy = internal_policy_queue.get_nowait()
                self.current_policy_snapshot[brain_name] = policy.get_weights()
            except AgentManagerQueue.Empty:
                pass
            if next_learning_team in self._team_to_name_to_policy_queue:
                name_to_policy_queue = self._team_to_name_to_policy_queue[
                    next_learning_team
                ]
                if brain_name in name_to_policy_queue:
                    behavior_id = create_name_behavior_id(
                        brain_name, next_learning_team
                    )
                    policy = self.get_policy(behavior_id)
                    policy.load_weights(self.current_policy_snapshot[brain_name])
                    name_to_policy_queue[brain_name].put(policy)

My current impression is that name_to_policy_queue[brain_name].put(policy) only needs to be called when there's a policy update, which only happens when the internal policy queue has a policy in it. If so, the solution may be to replace

except AgentManagerQueue.Empty:
    pass

with

except AgentManagerQueue.Empty:
    continue

When I make that change, I get roughly a 30% speedup. However, I haven't spent enough time with the ml-agents code to be sure it doesn't change the behavior at all.
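To make the difference concrete, here is a minimal toy model of the quoted loop body for a single brain and a fixed learning team. It is a sketch, not the ml-agents code: `FakePolicy`, `advance_once`, and `run_ghost_loop` are hypothetical names, and Python's standard `queue.Queue` / `queue.Empty` stand in for `AgentManagerQueue` / `AgentManagerQueue.Empty`. It counts how often load_weights runs with `pass` versus `continue` when a new policy arrives only every few advance calls.

```python
import queue
from typing import List


class FakePolicy:
    """Hypothetical stand-in for an ml-agents policy; counts load_weights calls."""

    def __init__(self) -> None:
        self.weights: List[float] = []
        self.load_calls = 0

    def get_weights(self) -> List[float]:
        return list(self.weights)

    def load_weights(self, weights: List[float]) -> None:
        # In the real trainer, this copy is the expensive step for large models.
        self.load_calls += 1
        self.weights = list(weights)


def advance_once(internal_queue, snapshot, ghost_policy, out_queue, skip_when_empty):
    """One pass of the quoted loop body, simplified to one brain and one team."""
    try:
        policy = internal_queue.get_nowait()
        snapshot["brain"] = policy.get_weights()
    except queue.Empty:
        if skip_when_empty:
            return  # proposed behavior: `continue` to the next brain
        # current behavior: `pass`, then fall through and reload anyway
    ghost_policy.load_weights(snapshot["brain"])
    out_queue.put(ghost_policy)


def run_ghost_loop(num_advances, update_every, skip_when_empty):
    """Call advance_once repeatedly; the learner publishes a policy every update_every steps."""
    internal_queue = queue.Queue()  # stands in for the internal policy queue
    out_queue = queue.Queue()  # stands in for name_to_policy_queue[brain_name]
    learner = FakePolicy()
    ghost = FakePolicy()
    snapshot = {"brain": learner.get_weights()}
    for step in range(num_advances):
        if step % update_every == 0:
            learner.weights = [float(step)]  # pretend the learner just trained
            internal_queue.put(learner)
        advance_once(internal_queue, snapshot, ghost, out_queue, skip_when_empty)
    return ghost


current = run_ghost_loop(1000, update_every=100, skip_when_empty=False)
proposed = run_ghost_loop(1000, update_every=100, skip_when_empty=True)
print(current.load_calls, proposed.load_calls)  # 1000 10
print(current.weights, proposed.weights)  # [900.0] [900.0]
```

In this toy model both variants end with the same weights, but `continue` reloads only when an update actually arrives. One case it deliberately does not model: reading the quoted code, if next_learning_team changes while the internal queue is empty, `pass` still pushes a policy with the current snapshot into the new learning team's queue, whereas `continue` skips that push until the next policy update. That is the case worth checking before making the change.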

Describe alternatives you've considered
Unfortunately, I don't think there's a way around load_weights being expensive for models with a large number of weights.

Thanks

Labels: request (Issue contains a feature request.)
