From 10df2ef5f3eed5bbf59f56bf875fe5bfbc82246c Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 27 Sep 2024 20:27:53 -0500 Subject: [PATCH] fixed oversight where input audio does not resample (lol...) --- vall_e/emb/qnt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index ffa9f16..a9d7d14 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -495,7 +495,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat if wav.dim() < 3: wav = wav.unsqueeze(0) # skip unnecessary resample - if sr != model.sample_rate and wav.shape[1] != 1: + if sr != model.sample_rate or wav.shape[1] != 1: wav = convert_audio(wav, sr, model.sample_rate, 1) wav = wav.to(device) @@ -510,8 +510,9 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat if wav.dim() < 3: wav = wav.unsqueeze(0) # skip unnecessary resample - if sr != model.sample_rate and wav.shape[1] != model.channels: + if sr != model.sample_rate or wav.shape[1] != model.channels: wav = convert_audio(wav, sr, model.sample_rate, model.channels) + wav = wav.to(device) with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):