diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 4072037..4c9c534 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -267,11 +267,31 @@ def _replace_file_extension(path, suffix): # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime # each audio backend does their "embeddings" a different way that isn't just a embedding weights @torch.inference_mode() -def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cuda"): +def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"): model = _load_model(device) codes = codes.to(device=device, dtype=torch.int32) + # yucky kludge + if sums: + if codes.dim() == 1: + codes = rearrange(codes, "t -> t 1") + + if cfg.audio_backend == "dac": + x = [] + for i in range(quant_level+1): + emb = model.quantizer.quantizers[i] + code = rearrange(codes[:, quant_level], "t -> 1 t") + + xi = emb.decode_code(code) + xi = emb.out_proj(xi) + x.append( xi[0].t() ) + + return sum(x).detach() + + raise Exception(f'Currently only DAC is supported') + + if codes.dim() == 2: codes = codes[:, quant_level] diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 2961442..e80068a 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -473,7 +473,7 @@ def example_usage(): """ model = AR_NAR(**kwargs).to(device) - steps = 100 + steps = 150 optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 4674262..f9a9743 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -194,7 +194,7 @@ class AudioEmbedding(nn.Module): input = input[:-1] # get external embedding - embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype) + embedding = encode_as_embedding( input, quant_level, sums=self.sums ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype) # resize if necessary (in case the external embeddings do not match our model dim) embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False ) diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index 6cb5aba..142a6cd 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -252,12 +252,6 @@ def example_usage(): kwargs = {} model = Model(**kwargs).to(device) steps = 100 - if cfg.model.arch_type == "mamba2": - steps = 100 - elif cfg.model.arch_type == "llama": - steps = 500 - elif cfg.model.interleave: - steps = 250 optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""