[tests] refactor vae tests #9808
```diff
@@ -1467,7 +1467,7 @@ def forward(
         z = posterior.sample(generator=generator)
     else:
         z = posterior.mode()
-    dec = self.decode(z)
+    dec = self.decode(z).sample
```
Otherwise, when `return_dict=False`, we return a tuple containing a `DecoderOutput` instead of the decoded tensor.
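To make the failure mode concrete, here is a minimal pure-Python sketch. `DecoderOutput` is a stand-in dataclass (the real diffusers class wraps the decoded tensor in a `sample` field), and `decode`, `forward_buggy`, and `forward_fixed` are hypothetical, simplified versions of the model methods, not the diffusers code. Lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, Tuple, Union


@dataclass
class DecoderOutput:
    """Stand-in for diffusers' DecoderOutput: wraps the decoded sample."""
    sample: Any


def decode(z: Any) -> DecoderOutput:
    # Hypothetical decoder: always wraps its result in a DecoderOutput.
    return DecoderOutput(sample=[v * 2 for v in z])


def forward_buggy(z: Any, return_dict: bool = True) -> Union[DecoderOutput, Tuple[Any]]:
    dec = decode(z)  # dec is a DecoderOutput, not the raw sample
    if not return_dict:
        return (dec,)  # tuple holding a DecoderOutput
    return DecoderOutput(sample=dec)  # a DecoderOutput nested inside another


def forward_fixed(z: Any, return_dict: bool = True) -> Union[DecoderOutput, Tuple[Any]]:
    dec = decode(z).sample  # unwrap the raw sample first
    if not return_dict:
        return (dec,)
    return DecoderOutput(sample=dec)


buggy = forward_buggy([1, 2], return_dict=False)
fixed = forward_fixed([1, 2], return_dict=False)
print(type(buggy[0]).__name__)  # DecoderOutput
print(fixed[0])  # [2, 4]
```

Unwrapping with `.sample` at the call site keeps both return paths consistent: the tuple and the dataclass each carry the raw sample.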
```python
sample_size = (
    self.config.sample_size[0]
    if isinstance(self.config.sample_size, (list, tuple))
    else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
```
Unused.
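For reference, the unused lines compute the minimum tile size in latent space: the pixel-space `sample_size` divided by the total downsampling factor `2 ** (len(block_out_channels) - 1)`, which assumes every block except the last halves the resolution. The standalone sketch below uses assumed config values (`sample_size=512` and four entries in `block_out_channels`), not values from a specific checkpoint.

```python
from typing import Sequence, Union


def tile_latent_min_size(sample_size: Union[int, Sequence[int]], block_out_channels: Sequence[int]) -> int:
    # Mirror the config handling above: a list/tuple sample_size uses its first entry.
    if isinstance(sample_size, (list, tuple)):
        sample_size = sample_size[0]
    # Each block except the last halves the resolution, so the factor is 2 ** (n_blocks - 1).
    downsample_factor = 2 ** (len(block_out_channels) - 1)
    return int(sample_size / downsample_factor)


print(tile_latent_min_size(512, (128, 256, 512, 512)))  # 64
print(tile_latent_min_size([256, 256], (128, 256, 512, 512)))  # 32
```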
```diff
-output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
+output = [
+    self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
+]
```
Should use `x_slice` and not `x`.
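A quick way to see the effect of the bug: in the pure-Python sketch below, lists stand in for batched tensors, `split` is a rough analogue of `x.split(1)` along the batch dimension, and `decoder` is a hypothetical per-sample function. With `decoder(x)`, every slice re-decodes the whole batch, so the concatenated output has `batch_size ** 2` entries instead of `batch_size`.

```python
from typing import List


def decoder(batch: List[float]) -> List[float]:
    # Hypothetical decoder: maps each sample in the batch independently.
    return [v * 10 for v in batch]


def split(batch: List[float], size: int = 1) -> List[List[float]]:
    # Rough analogue of tensor.split(1) along the batch dimension.
    return [batch[i : i + size] for i in range(0, len(batch), size)]


def decode_sliced(x: List[float], use_slice: bool) -> List[float]:
    output: List[float] = []
    for x_slice in split(x):
        # With use_slice=False this reproduces the bug: the full batch is decoded once per slice.
        output.extend(decoder(x_slice) if use_slice else decoder(x))
    # Extending the list plays the role of concatenating along the batch dimension.
    return output


x = [1.0, 2.0, 3.0]
print(len(decode_sliced(x, use_slice=False)))  # 9, i.e. batch_size ** 2
print(decode_sliced(x, use_slice=True))  # [10.0, 20.0, 30.0]
```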
Could maybe refactor this further to follow the current Cog/Mochi implementations, which use a `_decode` method. The code flow is a bit easier to understand that way.
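Roughly, the suggested pattern keeps the single-batch logic (tiled vs. plain decoding) in a private `_decode`, and lets the public `decode` handle only batch slicing. The pure-Python sketch below illustrates that split under these assumptions; `ToyVAE`, its flags, and the toy decoders are hypothetical stand-ins, not the actual CogVideoX or Mochi code.

```python
from typing import List


class ToyVAE:
    """Hypothetical VAE showing the decode/_decode split."""

    def __init__(self, use_slicing: bool = False, use_tiling: bool = False) -> None:
        self.use_slicing = use_slicing
        self.use_tiling = use_tiling

    def decoder(self, z: List[float]) -> List[float]:
        # Stand-in for the real decoder network.
        return [v + 1 for v in z]

    def _tiled_decode(self, z: List[float]) -> List[float]:
        # Stand-in for tiled decoding; here it gives the same result as decoder().
        return self.decoder(z)

    def _decode(self, z: List[float]) -> List[float]:
        # Single-batch path: only chooses between tiled and plain decoding.
        if self.use_tiling:
            return self._tiled_decode(z)
        return self.decoder(z)

    def decode(self, z: List[float]) -> List[float]:
        # Public entry point: only decides whether to slice the batch.
        if self.use_slicing and len(z) > 1:
            decoded: List[float] = []
            for i in range(len(z)):
                decoded.extend(self._decode(z[i : i + 1]))
            return decoded
        return self._decode(z)


z = [0, 1, 2]
print(ToyVAE(use_slicing=True).decode(z))  # [1, 2, 3]
```

Because each method handles one concern, the slicing loop can only ever pass a slice to `_decode`, which makes a mix-up like `decoder(x)` vs. `decoder(x_slice)` harder to write.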
Yeah, sure, feel free to combine those into your PR.
tests/models/autoencoders/test_models_autoencoder_kl_allegro.py
@DN6 a gentle ping.
```diff
@@ -433,7 +433,7 @@ def create_forward(*inputs):
                     hidden_states,
                     temb,
                     zq,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
```
Because `torch.utils.checkpoint.checkpoint()` doesn't have a `conv_cache` argument, so `conv_cache` has to be passed positionally.
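The hunk header shows that the checkpointed function is `create_forward(*inputs)`, which accepts positional arguments only. The pure-Python sketch below illustrates why the positional form works. `block` is a hypothetical decoder block, and `run_checkpointed` is a hypothetical stand-in that simply forwards its arguments to the wrapped function; it is not `torch.utils.checkpoint.checkpoint`, and it assumes the checkpoint call passes extra arguments through unchanged.

```python
from typing import Any, Callable, Dict, Optional


def block(hidden_states: float, temb: float, zq: float, conv_cache: Optional[Dict[str, Any]] = None) -> float:
    # Hypothetical decoder block that optionally consumes a conv cache.
    offset = conv_cache["offset"] if conv_cache else 0.0
    return hidden_states + temb + zq + offset


def create_custom_forward(module: Callable[..., float]) -> Callable[..., float]:
    # Same shape as the wrapper in the hunk header: positional inputs only.
    def create_forward(*inputs: Any) -> float:
        return module(*inputs)

    return create_forward


def run_checkpointed(function: Callable[..., float], *args: Any, **kwargs: Any) -> float:
    # Hypothetical stand-in for a checkpoint call that forwards arguments to `function`.
    return function(*args, **kwargs)


cache = {"offset": 0.5}
wrapped = create_custom_forward(block)

# Positional: the cache lands in the fourth parameter of `block`.
print(run_checkpointed(wrapped, 1.0, 2.0, 3.0, cache))  # 6.5

# Keyword: create_forward(*inputs) has no conv_cache parameter, so this raises TypeError.
try:
    run_checkpointed(wrapped, 1.0, 2.0, 3.0, conv_cache=cache)
except TypeError as err:
    print("TypeError:", err)
```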
```python
if self.model_class.__name__ in [
    "UNetSpatioTemporalConditionModel",
    "AutoencoderKLTemporalDecoder",
]:
    return
```
Because these are supported.
@a-r-r-o-w @DN6 a gentle ping.
@a-r-r-o-w merging this to unblock you; I'll let you add any leftover tests. Hopefully that is okay.
* add: autoencoderkl tests
* autoencodertiny.
* fix
* asymmetric autoencoder.
* more
* integration tests for stable audio decoder.
* consistency decoder vae tests
* remove grad check from consistency decoder.
* cog
* bye test_models_vae.py
* fix
* fix
* remove allegro
* fixes
* fixes
* fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
What does this PR do?
Internal thread: https://huggingface.slack.com/archives/C065E480NN9/p1730203711189419.
Tears apart `test_models_vae.py` and splits its tests according to the autoencoder model classes we have under `src/diffusers/models/autoencoders`.
Didn't include Allegro as it's undergoing some refactoring love from Aryan. Discussed internally.
Some comments inline.