I am currently running a fine-tuning run on 4x V100 (32 GB each), using an LJSpeech-style dataset and the fine-tuning training script. With a batch size of 8 I get OOM errors at some point, but when I reduce the batch size to 4, the error below appears right at the beginning. Do you have an idea why changing the batch size to 4 could lead to this error?
```
Traceback (most recent call last):
  File "/raid/me/projects/StyleTTS2/train_finetune.py", line 714, in <module>
    main()
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/train_finetune.py", line 396, in main
    y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
    output.reraise()
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/Modules/hifigan.py", line 458, in forward
    F0 = self.F0_conv(F0_curve.unsqueeze(1))
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 310, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/raid/me/projects/StyleTTS2/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 306, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [1, 1, 3], expected input[1, 284, 1] to have 1 channels, but got 284 channels instead
```
Changing the batch size to 3 instead results in this error on the last line:

```
RuntimeError: Given groups=1, weight of size [1, 1, 3], expected input[1, 221, 1] to have 1 channels, but got 221 channels instead
```

Apparently the input tensor does not have the channel layout the layer expects?
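For what it's worth, the failing layer only checks the channel dimension, so the error can be reproduced outside the repo using nothing but what the traceback shows (a Conv1d whose weight is [1, 1, 3], i.e. in_channels=1, out_channels=1, kernel_size=3; stride/padding here are placeholders):

```python
import torch
import torch.nn as nn

# Minimal reproduction of the shape error, built only from the traceback (not the repo's F0_conv).
f0_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)

ok = torch.randn(1, 1, 284)    # expected layout: [batch, channels=1, frames]
print(f0_conv(ok).shape)       # torch.Size([1, 1, 284])

bad = torch.randn(1, 284, 1)   # the frame count ended up in the channel dimension
f0_conv(bad)                   # RuntimeError: ... expected input[1, 284, 1] to have 1 channels, but got 284 channels instead
```

So the F0 curve reaching F0_conv in my run seems to have its frame axis where the channel axis should be, and the failing sizes (284 with batch 4, 221 with batch 3) scale with the batch size.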