Description
I am opening this issue because, apparently, the training result differs depending on which version of PyTorch you use. Here are the 3px-error evaluation curves for a minimal example of overfitting the network on a single image for 300 epochs:
The purple line is trained with PyTorch 1.7.0 and the orange line with PyTorch 1.5.1. As you can see, with version 1.7.0 the error rate stays flat at 100%, while with version 1.5.1 the error rate drops. The reason is that the BatchNorm behaviour changed between version 1.5.1 and version 1.7.0. In version 1.5.1, if I disable `track_running_stats` here, both evaluation and training use batch statistics. In PyTorch 1.7.0, however, evaluation mode is forced to use `running_mean` and `running_var`, while training still uses batch statistics. With `track_running_stats` disabled, `running_mean` stays 0 and `running_var` stays 1, which is clearly different from the batch statistics.
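A minimal sketch of that mismatch (tensor shapes and magnitudes are my own, for illustration): normalizing with batch statistics standardizes the activations, while normalizing with the untouched defaults `running_mean = 0` and `running_var = 1` leaves them essentially unchanged.

```python
import torch

# Illustrative activations with per-channel mean ~2 and std ~5
# (shapes and values are assumptions, not from the repo).
x = torch.randn(8, 3, 16, 16) * 5 + 2

# Training-mode behaviour: normalize with the batch statistics.
batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)
batch_var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_batch = (x - batch_mean) / torch.sqrt(batch_var + 1e-5)

# Eval-mode behaviour in PyTorch 1.7.0 with track_running_stats disabled:
# the defaults running_mean = 0 and running_var = 1 are used, so the
# activations pass through almost unchanged.
x_running = (x - 0.0) / torch.sqrt(torch.tensor(1.0) + 1e-5)
```

With batch statistics the output is standardized (mean ≈ 0, std ≈ 1 per channel); with the default running statistics the output keeps the raw scale of the input, which is why the network evaluated in 1.7.0 sees activations it was never trained on.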
Therefore, instead of working against torch's implementation, I recommend using PyTorch 1.5.1 if you want to retrain from scratch. Otherwise, if you want to use another PyTorch version, you can replace all BatchNorm layers with InstanceNorm and port the learnt values (i.e. `weight` and `bias`) over from BatchNorm. This is a `wontfix` problem because it is quite hard to accommodate all torch versions.
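The suggested port can be sketched as follows; the helper name `batchnorm_to_instancenorm` is my own, not from the repo. It copies the learnt affine parameters from a `BatchNorm2d` layer into an `InstanceNorm2d` layer, which normalizes with per-sample statistics in both training and evaluation mode, sidestepping the `running_mean` / `running_var` issue.

```python
import torch
import torch.nn as nn

# Hypothetical helper sketching the workaround: swap a BatchNorm2d layer
# for an InstanceNorm2d layer and carry the learnt weight and bias across.
# InstanceNorm2d uses instance statistics in eval mode by default, so the
# untouched running_mean = 0 / running_var = 1 defaults never come into play.
def batchnorm_to_instancenorm(bn: nn.BatchNorm2d) -> nn.InstanceNorm2d:
    inorm = nn.InstanceNorm2d(bn.num_features, affine=True, eps=bn.eps)
    with torch.no_grad():
        inorm.weight.copy_(bn.weight)
        inorm.bias.copy_(bn.bias)
    return inorm
```

Note that for batch size 1 the instance statistics coincide with the batch statistics, so the swap is exact there; for larger batches InstanceNorm normalizes each sample independently, which is a slightly different operation than batch-statistics BatchNorm.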