From b2907ae7e0f987cf7ff3895c2cc968f531628345 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 8 Sep 2023 01:03:24 -0500 Subject: [PATCH] seems that my PromEmbedding/RespEmbedding doesn't actually work all that well, naively using dedicated MultiEmbeddings for AR/NAR in the monolithic model is the best way to go --- vall_e/inference.py | 17 +++++++++++++++-- vall_e/models/ar_nar.py | 2 +- vall_e/models/base.py | 36 +++++++++++++++--------------------- vall_e/train.py | 11 +++++++++-- 4 files changed, 40 insertions(+), 26 deletions(-) diff --git a/vall_e/inference.py b/vall_e/inference.py index b7a683a..fb10d8c 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -42,7 +42,17 @@ class TTS(): models = get_models(cfg.models.get()) for name, model in models.items(): - if name.startswith("ar"): + if name.startswith("ar+nar"): + self.ar = model + state = torch.load(self.ar_ckpt) + if "symmap" in state: + self.symmap = state['symmap'] + if "module" in state: + state = state['module'] + self.ar.load_state_dict(state) + self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32) + self.nar = self.ar + elif name.startswith("ar"): self.ar = model state = torch.load(self.ar_ckpt) if "symmap" in state: @@ -74,7 +84,10 @@ class TTS(): def load_models( self ): engines = load_engines() for name, engine in engines.items(): - if name[:2] == "ar": + if name[:6] == "ar+nar": + self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32) + self.nar = self.ar + elif name[:2] == "ar": self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32) elif name[:3] == "nar": self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index c9288da..5d93c2b 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -82,7 +82,7 @@ class AR_NAR(Base): # is training if n_levels == self.n_resp_levels: - if random.random() < 0.5: + if random.random() < 0.25: quant_levels = None targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 164fc46..4645eaa 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -62,11 +62,15 @@ class MultiEmbedding(nn.Embedding): This embedding sums embeddings on different levels. """ - def __init__(self, max_n_levels, n_tokens, token_dim): + def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False): super().__init__(max_n_levels, token_dim) self.max_n_levels = max_n_levels self.n_tokens = n_tokens - self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim)) + self.monolithic = monolithic + if self.monolithic: + self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim)) + else: + self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim)) # to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb) # I imagine this is an oversight in the NAR. @@ -74,7 +78,10 @@ class MultiEmbedding(nn.Embedding): if len(x_list) == 0: return [] - w = self.weight + if self.monolithic: + w = self.weights[0 if quant_levels is None else 1] + else: + w = self.weight padded_x_list = [] @@ -91,6 +98,7 @@ class MultiEmbedding(nn.Embedding): return x_list +""" # Embedding that sums each RVQ-bin level within a given input acoustic prompt class PromEmbedding(nn.Module): def __init__(self, n_levels, n_tokens, token_dim): @@ -110,6 +118,7 @@ class RespEmbedding(nn.Module): def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]: return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ] +""" class Base(nn.Module): @property @@ -154,7 +163,7 @@ class Base(nn.Module): @property def n_embeddings(self) -> int: - return self.n_resp_levels if self.monolithic else 1 + return 2 if self.monolithic else 1 @property def stop_token(self): @@ -183,15 +192,11 @@ class Base(nn.Module): p_dropout: float = 0.1, config = None, - use_multiembedding = False, ): super().__init__() self.config = config self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True - if self.config is not None and hasattr(self.config, "use_multiembedding"): - use_multiembedding = self.config.use_multiembedding - self.n_tokens = n_tokens self.d_model = d_model self.n_heads = n_heads @@ -203,19 +208,8 @@ class Base(nn.Module): self.text_emb = Embedding(n_tokens, d_model) - # use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested - # n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt - if self.n_embeddings == self.n_prom_levels or not use_multiembedding: - self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model) - else: - self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) - - # use dedicated embeddings for each RVQ-bin level in the output response / target if requested - # n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs. - if self.n_embeddings > 1 or not use_multiembedding: - self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model) - else: - self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) + self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic) + self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic) self.sep = nn.Parameter(torch.randn(d_model)) diff --git a/vall_e/train.py b/vall_e/train.py index 83f8f97..c874874 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -52,10 +52,13 @@ def run_eval(engines, disabled_engines, eval_name, dl): AR = None NAR = None + AR_NAR = None names = [] for name, engine in engines.items(): - if name[:2] == "ar": + if name[:6] == "ar+nar": + AR_NAR = engine + elif name[:2] == "ar": AR = engine elif name[:3] == "nar": NAR = engine @@ -127,7 +130,11 @@ def run_eval(engines, disabled_engines, eval_name, dl): for name in engines: model = engines[name] - if name.startswith("ar"): + if name.startswith("ar+nar"): + resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) + resps_list = [ r.unsqueeze(-1) for r in resps_list ] + resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature) + elif name.startswith("ar"): resps_list = model( text_list=batch["text"], proms_list=batch["proms"],