[Bug] No grad in YourTTS speaker encoder #2348

@Tomiinek

Description

Describe the bug

Hello guys (CC: @Edresson @WeberJulian), while going through the YourTTS code & paper, I noticed that you compute the inputs for the speaker encoder with no grads:

def forward(self, x, l2_norm=False):
    """Forward pass of the model.

    Args:
        x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
            to compute the spectrogram on-the-fly.
        l2_norm (bool): Whether to L2-normalize the outputs.

    Shapes:
        - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
    """
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=False):
            x.squeeze_(1)
            # if you torch spec compute it otherwise use the mel spec computed by the AP
            if self.use_torch_spec:
                x = self.torch_spec(x)

            if self.log_input:
                x = (x + 1e-6).log()
            x = self.instancenorm(x).unsqueeze(1)

    x = self.conv1(x)
    x = self.relu(x)
    x = self.bn1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = x.reshape(x.size()[0], -1, x.size()[-1])

    w = self.attention(x)

    if self.encoder_type == "SAP":
        x = torch.sum(x * w, dim=2)
    elif self.encoder_type == "ASP":
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
        x = torch.cat((mu, sg), 1)

    x = x.view(x.size()[0], -1)
    x = self.fc(x)

    if l2_norm:
        x = torch.nn.functional.normalize(x, p=2, dim=1)
    return x

I suspect that the speaker encoder does not produce any gradients, and that the speaker consistency loss therefore has no effect.
It looks like the following happens (see the stand-alone sketch after this list):

  • forward gets x with grads
  • the spectrogram is extracted via torch_spec with no grads
  • the output of the speaker encoder has requires_grad set to False (and calling backward on it alone would raise an exception, since no activations were kept for gradient computation), but it is added to the total loss (which has requires_grad set to True)
  • so the call to loss.backward() works as usual, but the speaker encoder does not contribute to the gradients flowing into the generator at all
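
Here is a minimal stand-alone sketch of that pattern. This is a hypothetical toy module, not the actual YourTTS code: `spec` stands in for `torch_spec`, `fc` for the (frozen) encoder body, and the losses are made up for illustration.

import torch
import torch.nn as nn

class ToySpeakerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.spec = nn.Linear(8, 8, bias=False)  # plays the role of torch_spec
        self.fc = nn.Linear(8, 4, bias=False)    # plays the role of the ResNet + fc
        for p in self.parameters():              # the encoder is frozen during TTS training
            p.requires_grad_(False)

    def forward(self, x):
        with torch.no_grad():                    # same pattern as the snippet above
            x = self.spec(x)                     # the graph back to `x` is cut here
        return self.fc(x)

enc = ToySpeakerEncoder()
wav = torch.randn(2, 8, requires_grad=True)      # stands in for the generator output

emb = enc(wav)
print(emb.requires_grad)                         # False

scl = emb.pow(2).mean()                          # "speaker consistency loss"
other = wav.pow(2).mean()                        # some other loss term that does require grad
total = scl + other
total.backward()                                 # no exception

print(wav.grad.abs().sum())                      # non-zero, but it all comes from `other`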

Could you please check on that?

To Reproduce

import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)

with torch.no_grad():
    c = a + b

d = c + 1
e = a + d
e.backward()

print(a.grad)  # tensor(1.) -- only via the direct `a + d` path
print(b.grad)  # None -- the no_grad block cut the only path to `b`

d.backward()   # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
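
For contrast, the same computation without the torch.no_grad() block behaves the way I would expect the SCL path to behave (gradients reach both inputs):

import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)

c = a + b      # no torch.no_grad() this time
d = c + 1
e = a + d
e.backward()

print(a.grad)  # tensor(2.) -- both paths to `a` contribute now
print(b.grad)  # tensor(1.) -- the gradient reaches `b` as well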

Expected behavior

No response

Logs

No response

Environment

Just reading the code and asking :)

Additional context

No response
