Bug during finetuning training #104

@asusdisciple

Description

I am currently running finetuning on 4×V100 GPUs (32 GB each), using a dataset in the LJSpeech style and the finetuning training script. With a batch size of 8 I get OOM errors at some point, but when I reduce the batch size to 4, the error below appears right at the start of training. Do you have an idea why changing the batch size to 4 could lead to this error?
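I wonder whether this is related to how nn.DataParallel scatters the batch: with 4 GPUs, a batch of 4 leaves each replica with a single sample (and a batch of 3 leaves three replicas with one sample each). A shapes-only sketch of that arithmetic, with chunk standing in for the actual scatter and 284 as a made-up sequence length:

```python
import torch

# Shapes-only sketch: nn.DataParallel scatters along dim 0, so a batch of 4
# on 4 GPUs becomes a batch of 1 per replica (chunk stands in for scatter).
F0_real = torch.randn(4, 284)              # [batch, time]; 284 is made up
chunks = F0_real.chunk(4, dim=0)
print([tuple(c.shape) for c in chunks])    # [(1, 284), (1, 284), (1, 284), (1, 284)]

# With a per-replica batch of 1, any squeeze() silently drops the batch dim:
print(chunks[0].squeeze().shape)           # torch.Size([284])
```

The full traceback: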

```
Traceback (most recent call last):
  File "/raid/me/projects/StyleTTS2/train_finetune.py", line 714, in <module>
    main()
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/train_finetune.py", line 396, in main
    y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
    output.reraise()
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/Modules/hifigan.py", line 458, in forward
    F0 = self.F0_conv(F0_curve.unsqueeze(1))
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 310, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 306, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [1, 1, 3], expected input[1, 284, 1] to have 1 channels, but got 284 channels instead
```

Changing the batch size to 3 yields the same traceback, with the last line now:

```
RuntimeError: Given groups=1, weight of size [1, 1, 3], expected input[1, 221, 1] to have 1 channels, but got 221 channels instead
```

Apparently there is an issue with the shape of the input that F0_conv expects?
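The numbers would fit that reading: the weight shape [1, 1, 3] says F0_conv expects a single input channel, and 284/221 look like sequence lengths rather than channel counts. If F0_curve reaches F0_conv as a 1-D tensor of length T (batch dimension already gone), then F0_curve.unsqueeze(1) produces [T, 1], which conv1d treats as an unbatched [channels, length] input. A minimal repro with a Conv1d matching the weight shape from the error (stride/padding are guesses, not the actual StyleTTS2 layer):

```python
import torch
import torch.nn as nn

# Conv layer matching weight shape [1, 1, 3] from the error message
# (in_channels=1, out_channels=1, kernel_size=3); stride/padding are guesses.
f0_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)

# Expected case: F0_curve is [batch, time]; unsqueeze(1) -> [batch, 1, time].
f0_ok = torch.randn(4, 284)
print(f0_conv(f0_ok.unsqueeze(1)).shape)   # torch.Size([4, 1, 284])

# Failing case: F0_curve arrives as 1-D [time] (batch dim already lost);
# unsqueeze(1) -> [284, 1], which conv1d reads as an unbatched
# [channels, length] input, reported as [1, 284, 1] in the error.
f0_bad = torch.randn(284)
f0_conv(f0_bad.unsqueeze(1))
# RuntimeError: Given groups=1, weight of size [1, 1, 3], expected
# input[1, 284, 1] to have 1 channels, but got 284 channels instead
```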
