Status: Closed
Labels: TF 2.4, comp:keras, stale, stat:awaiting response, type:bug
Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows, Ubuntu 18.04
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.4.0, nightly
- Python version: 3.7
- CUDA/cuDNN version: 11.0, 8.0.4
- GPU model and memory: 1060, 6GB
Describe the current behavior
When loading weights with skip_mismatch=True, if the stored weights don't match the current layer's weights, the warning appears to report misleading shapes for the stored weights.
Describe the expected behavior
I have attached minimal code that reproduces the warning. The shapes it reports are confusing: the warning should show ((3, 4, 7, 32) vs (3, 4, 7, 16)) rather than ((3, 4, 7, 32) vs (16, 7, 3, 4)). I suspect the transpose here is causing the change of shape:
tensorflow/tensorflow/python/keras/saving/hdf5_format.py
Lines 407 to 411 in 3c84fc9
if layer.__class__.__name__ in conv_layers:
  if K.int_shape(layer.weights[0]) != weights[0].shape:
    weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
    if layer.__class__.__name__ == 'ConvLSTM2D':
      weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
Standalone code to reproduce the issue
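import tensorflow as tf

# Save a 16-filter Conv2D whose kernel has shape (3, 4, 7, 16).
model = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(16, [3, 4], use_bias=False, input_shape=[10, 10, 7])])
model.save_weights('test.h5')

# Reset layer-name counters so the new layer is also named 'conv2d'
# and by_name=True matches it against the saved weights.
tf.keras.backend.clear_session()

# Same layer with 32 filters: its kernel has shape (3, 4, 7, 32).
new_model = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(32, [3, 4], use_bias=False, input_shape=[10, 10, 7])])
new_model.load_weights('test.h5', skip_mismatch=True, by_name=True)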
Output:
WARNING:tensorflow:Skipping loading of weights for layer conv2d due to mismatch in shape ((3, 4, 7, 32) vs (16, 7, 3, 4)).
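One possible direction for the expected behavior, sketched with hypothetical helper names (preprocess_conv_kernel and mismatch_message are not part of TensorFlow): record the stored shape before the transpose and build the warning from that recorded shape. This only models the shape bookkeeping in plain NumPy under that assumption; it is not the actual hdf5_format.py code path.

```python
import numpy as np


def preprocess_conv_kernel(expected_shape, stored_kernel):
    """Hypothetical stand-in for the conv branch quoted above.

    Returns the (possibly transposed) kernel plus the shape it had on disk.
    """
    original_shape = stored_kernel.shape
    if tuple(expected_shape) != stored_kernel.shape:
        # Same permutation as the quoted source.
        stored_kernel = np.transpose(stored_kernel, (3, 2, 0, 1))
    return stored_kernel, original_shape


def mismatch_message(layer_name, expected_shape, original_shape):
    """Hypothetical warning text built from the pre-transpose shape."""
    return ('Skipping loading of weights for layer {} due to mismatch in '
            'shape ({} vs {}).'.format(layer_name, tuple(expected_shape),
                                       tuple(original_shape)))


expected = (3, 4, 7, 32)          # kernel shape of new_model
stored = np.zeros((3, 4, 7, 16))  # kernel saved by the first model
kernel, on_disk = preprocess_conv_kernel(expected, stored)

if kernel.shape != expected:
    # Prints: ... mismatch in shape ((3, 4, 7, 32) vs (3, 4, 7, 16)).
    print(mismatch_message('conv2d', expected, on_disk))
```

The transpose itself is left unchanged here; only the shape used in the message differs.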