Torch version affects the network's training performance #8

@mli0603

I am opening this issue because the training result apparently depends on which version of PyTorch you use. Here are the 3px error evaluation curves for a minimal example that overfits the network on a single image for 300 epochs:

Screenshot from 2020-12-22 20-06-19

The purple line is trained with PyTorch 1.7.0 and the orange line with PyTorch 1.5.1. As you can see, with version 1.7.0 the error rate stays flat at 100%, while with version 1.5.1 it drops. The reason is that the behavior of BatchNorm changed between PyTorch 1.5.1 and 1.7.0. In version 1.5.1, if I disable track_running_stats here, both evaluation and training use batch stats. In PyTorch 1.7.0, however, evaluation mode is forced to use running_mean and running_var, while training still uses batch stats. With track_running_stats disabled, the running stats are never updated, so running_mean stays at 0 and running_var at 1, which is clearly different from the batch stats.
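To illustrate why the two code paths disagree, here is a minimal sketch that calls `torch.nn.functional.batch_norm` directly, once with batch stats and once with untouched running stats (mean 0, variance 1). It does not depend on the module behavior of any particular PyTorch version; the input tensor and its shape are made up for the example and are not taken from the network.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical feature map: batch of 1, 4 channels, deliberately far from
# zero mean / unit variance so the difference is easy to see.
x = 5.0 + 3.0 * torch.randn(1, 4, 8, 8)

# Buffers as they look when the running stats are never updated.
running_mean = torch.zeros(4)
running_var = torch.ones(4)

# Path 1: batch stats. Each channel is normalised with its own mean and
# variance. Passing None for the running stats means nothing is tracked.
out_batch = F.batch_norm(x, None, None, training=True)

# Path 2: running stats. With mean 0 and var 1 this is almost the identity,
# so the input passes through essentially unnormalised.
out_running = F.batch_norm(x, running_mean, running_var, training=False)

print("batch stats   -> mean %.3f, std %.3f" % (out_batch.mean(), out_batch.std()))
print("running stats -> mean %.3f, std %.3f" % (out_running.mean(), out_running.std()))
```

A network trained on the first kind of activations and evaluated on the second sees inputs on a completely different scale in every normalised layer, which would be consistent with the flat 100% error curve.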

Therefore, instead of trying to work against torch's implementation, I recommend using PyTorch 1.5.1 if you want to retrain from scratch. Otherwise, if you want to use another PyTorch version, you can replace all BatchNorm layers with InstanceNorm and port the learnt values from BatchNorm (i.e. weight and bias). This is a wontfix problem because it is quite hard to accommodate all torch versions.
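The BatchNorm to InstanceNorm swap could look roughly like the sketch below. The helper name `batchnorm_to_instancenorm` is hypothetical and not part of this repository, and the sketch assumes the layers are `nn.BatchNorm2d`; the small `nn.Sequential` model is only a stand-in to check the conversion.

```python
import torch
import torch.nn as nn


def batchnorm_to_instancenorm(module: nn.Module) -> nn.Module:
    """Recursively replace nn.BatchNorm2d with nn.InstanceNorm2d (in place),
    copying the learnt affine parameters (weight and bias)."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            inorm = nn.InstanceNorm2d(
                child.num_features,
                eps=child.eps,
                affine=child.affine,
                track_running_stats=False,  # always use per-sample stats
            )
            if child.affine:
                inorm.to(child.weight.device)
                with torch.no_grad():
                    inorm.weight.copy_(child.weight)
                    inorm.bias.copy_(child.bias)
            setattr(module, name, inorm)
        else:
            batchnorm_to_instancenorm(child)
    return module


# Stand-in model (an assumption for illustration only).
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
with torch.no_grad():
    model[1].weight.uniform_(0.5, 1.5)  # pretend these were learnt
    model[1].bias.uniform_(-0.5, 0.5)

x = torch.randn(1, 3, 16, 16)

model.train()  # BatchNorm uses batch stats in training mode
with torch.no_grad():
    ref = model(x)

converted = batchnorm_to_instancenorm(model).eval()
with torch.no_grad():
    out = converted(x)

print(torch.allclose(ref, out, atol=1e-5))
```

Note that the two layers only match exactly at batch size 1, where the batch stats and the per-sample stats coincide; with larger batches InstanceNorm normalises each sample separately, so the outputs may differ from what the network saw during training.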


Labels: wontfix (This will not be worked on)
