diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 604d3eb..55cc61b 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -143,13 +143,11 @@ def _load_vocos_model(device="cuda", levels=cfg.model.max_levels): @cache def _load_dac_model(device="cuda", levels=cfg.model.max_levels): - kwargs = dict(model_type="24khz",model_bitrate="8kbps",tag="latest") + kwargs = dict(model_type="44khz",model_bitrate="8kbps",tag="latest") if not cfg.variable_sample_rate: # yes there's a better way, something like f'{cfg.sample.rate//1000}hz' if cfg.sample_rate == 44_000: kwargs["model_type"] = "44khz" - elif cfg.sample_rate == 24_000: - kwargs["model_type"] = "24khz" elif cfg.sample_rate == 16_000: kwargs["model_type"] = "16khz" else: @@ -279,11 +277,6 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels) - # trim to 8 codebooks if 24Khz - # probably redundant with levels, should rewrite logic eventuall - if model.model_type == "24khz": - artifact.codes = artifact.codes[:, :8, :] - return artifact.codes if not return_metadata else artifact # vocos does not encode wavs to encodecs, so just use normal encodec