Description
Is your feature request related to a problem? Please describe.
When I train policies with a large number of weights using self-play, a lot of time is spent in the load_weights function called from here: https://github.com/Unity-Technologies/ml-agents/blob/master/ml-agents/mlagents/trainers/ghost/trainer.py#L277
I'm still familiarizing myself with the code, but intuitively it shouldn't be necessary to load weights every time advance is called, which appears to be what happens now.
Describe the solution you'd like
Here's the code where load_weights is called frequently:
    try:
        policy = internal_policy_queue.get_nowait()
        self.current_policy_snapshot[brain_name] = policy.get_weights()
    except AgentManagerQueue.Empty:
        pass
    if next_learning_team in self._team_to_name_to_policy_queue:
        name_to_policy_queue = self._team_to_name_to_policy_queue[
            next_learning_team
        ]
        if brain_name in name_to_policy_queue:
            behavior_id = create_name_behavior_id(
                brain_name, next_learning_team
            )
            policy = self.get_policy(behavior_id)
            policy.load_weights(self.current_policy_snapshot[brain_name])
            name_to_policy_queue[brain_name].put(policy)
My current impression is that name_to_policy_queue[brain_name].put(policy) only needs to be called when there is a policy update, and that only happens when the internal policy queue has a policy in it. If so, the solution may be to replace
    except AgentManagerQueue.Empty:
        pass
with
    except AgentManagerQueue.Empty:
        continue
so that load_weights and put are skipped whenever the internal queue is empty. When I do that, I get around a 30% speed increase. However, I haven't spent enough time with the mlagents code to be sure that this doesn't change the functionality.
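To illustrate the difference between the two variants, here is a minimal standalone sketch. It does not use mlagents. FakePolicy, advance, count_loads, and the skip_when_empty flag are all hypothetical names made up for this example, and the standard library's queue.Queue stands in for AgentManagerQueue. It only counts how often load_weights runs over a number of advance calls when a new policy arrives every few steps.

```python
import queue


class FakePolicy:
    """Hypothetical stand-in for a policy; counts load_weights calls."""

    def __init__(self, weights=None):
        self.weights = weights
        self.load_calls = 0

    def get_weights(self):
        return self.weights

    def load_weights(self, weights):
        self.load_calls += 1
        self.weights = weights


def advance(internal_queues, snapshots, targets, out_queues, skip_when_empty):
    """One simplified advance step (assumption: a single learning team)."""
    for brain_name, internal_queue in internal_queues.items():
        try:
            policy = internal_queue.get_nowait()
            snapshots[brain_name] = policy.get_weights()
        except queue.Empty:
            if skip_when_empty:
                continue  # proposed variant: nothing new, skip the reload
            # current variant ("pass"): fall through and reload anyway
        target = targets[brain_name]
        target.load_weights(snapshots[brain_name])
        out_queues[brain_name].put(target)


def count_loads(skip_when_empty, steps=10, update_every=5):
    """Return (number of load_weights calls, final weights) for one brain."""
    internal_queues = {"brain": queue.Queue()}
    snapshots = {"brain": None}
    target = FakePolicy()
    targets = {"brain": target}
    out_queues = {"brain": queue.Queue()}
    for step in range(steps):
        if step % update_every == 0:
            # Simulate the trainer publishing updated weights.
            internal_queues["brain"].put(FakePolicy(weights=step))
        advance(internal_queues, snapshots, targets, out_queues, skip_when_empty)
    return target.load_calls, target.weights


print(count_loads(skip_when_empty=False))  # "pass" variant
print(count_loads(skip_when_empty=True))   # "continue" variant
```

In this toy setup both variants end with the same weights, but the "continue" variant only reloads when a new policy was actually dequeued. The sketch does not model next_learning_team changing between calls, so it says nothing about whether skipping is safe in that case.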
Describe alternatives you've considered
Unfortunately, I don't think there's a way around load_weights being an expensive function for models with a large number of weights.
Thanks