fixed oversight where input audio does not resample (lol...)

This commit is contained in:
mrq 2024-09-27 20:27:53 -05:00
parent 039482a48e
commit 10df2ef5f3

View File

@ -495,7 +495,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
if wav.dim() < 3:
wav = wav.unsqueeze(0)
# skip unnecessary resample
if sr != model.sample_rate and wav.shape[1] != 1:
if sr != model.sample_rate or wav.shape[1] != 1:
wav = convert_audio(wav, sr, model.sample_rate, 1)
wav = wav.to(device)
@ -510,8 +510,9 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
if wav.dim() < 3:
wav = wav.unsqueeze(0)
# skip unnecessary resample
if sr != model.sample_rate and wav.shape[1] != model.channels:
if sr != model.sample_rate or wav.shape[1] != model.channels:
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.to(device)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):