diff --git a/vall_e/config.py b/vall_e/config.py index f90095f..7bcdaab 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -168,7 +168,7 @@ class Model: def full_name(self): name = [ self.name ] - if self.size != "full": + if self.size != "full" and isinstance(self.size, str): name.append(self.size) if self.arch_type != "transformer": @@ -287,6 +287,7 @@ class Hyperparameters: gradient_clipping: int = 100 optimizer: str = "Adamw" + torch_optimizer: bool = False optimizer_params: dict = field(default_factory=lambda: {}) learning_rate: float = 3.25e-4 @@ -328,7 +329,7 @@ class DeepSpeed: "params": { "lr": cfg.hyperparameters.learning_rate, } - } if not cfg.hyperparameters.optimizer.endswith("-torch") else None, + } if not cfg.hyperparameters.torch_optimizer else None, "scheduler": { "type": cfg.hyperparameters.scheduler_type, "params": scheduler_params, diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 98afea4..1bbad85 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -17,7 +17,7 @@ class AR_NAR(Base): @property def norm_type(self): - return "ln" if self.n_resp_levels == 1 else "adaln" + return "ln" # if self.n_resp_levels == 1 else "adaln" @property def arch_type(self) -> str: @@ -202,9 +202,9 @@ def example_usage(): kwargs = { 'n_tokens': 1024, - 'd_model': 1536, # 1536 - 'n_heads': 24, # 24 - 'n_layers': 24, # 32 + 'd_model': 1024, # 1536 + 'n_heads': 16, # 24 + 'n_layers': 12, # 32 } """ @@ -218,6 +218,8 @@ def example_usage(): optimizer = ml.Prodigy(model.parameters(), lr=1.0) #optimizer = ml.AdamW(model.parameters(), lr=0.0001) engine = Engine(model=model, optimizer=optimizer) + + print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") def sample( name, steps=600 ): engine.eval() @@ -243,7 +245,7 @@ def example_usage(): tqdm.write(f"{stats}") - sample("init", 75) + #sample("init", 75) train() sample("final") diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 13331b1..8ca5644 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -87,8 +87,31 @@ class MultiEmbedding(nn.Embedding): x_list = x.split([*map(len, x_list)]) return x_list +""" +class PromEmbedding(nn.Module): + def __init__(self, n_levels, n_tokens, token_dim): + super().__init__() + self.n_levels = n_levels + self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)]) + + def forward(self, x_list: list[Tensor] ) -> list[Tensor]: + if len(x_list) == 0: + return [] + return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ] +class RespEmbedding(nn.Module): + def __init__(self, n_levels, n_tokens, token_dim): + super().__init__() + self.n_levels = n_levels + self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)]) + + def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]: + if len(x_list) == 0: + return [] + res = [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ] + return res +""" class Base(nn.Module): @property def causal(self) -> bool: @@ -130,6 +153,10 @@ class Base(nn.Module): def dual(self) -> bool: return False + @property + def n_embeddings(self): + return self.n_resp_levels if self.dual else 1 + @property def stop_token(self): if not self.causal: @@ -172,12 +199,18 @@ class Base(nn.Module): n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop self.text_emb = Embedding(n_tokens, d_model) - self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) - if self.dual: - self.resps_emb = nn.ModuleList([MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(2)]) - else: + self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) + if self.n_embeddings == 1: self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) + else: + self.resps_emb = nn.ModuleList([ MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(self.n_embeddings) ]) + """ + if self.n_embeddings == 1: + self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) + else: + self.resps_emb = RespEmbedding(self.n_resp_levels, n_resp_tokens, d_model) + """ self.sep = nn.Parameter(torch.randn(d_model)) @@ -262,18 +295,19 @@ class Base(nn.Module): state: dict | None = None, ): - if self.dual: + if self.n_embeddings == 1: x_list = self._samplewise_merge_tensors( self.text_emb(text_list), self.proms_emb(proms_list), - self.resps_emb[0 if quant_levels is None else 1](resps_list), + self.resps_emb(resps_list), sep=self.sep, ) else: x_list = self._samplewise_merge_tensors( self.text_emb(text_list), self.proms_emb(proms_list), - self.resps_emb(resps_list), + self.resps_emb[0 if quant_levels is None else 1](resps_list), + #self.resps_emb(resps_list, quant_levels), sep=self.sep, ) @@ -296,8 +330,11 @@ class Base(nn.Module): if self.arch_type == "transformer": x = self.sin_emb.add_pe(x) + l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels + l = l.to(device) for block in self.blocks: - x = block(x, m, quant_levels) + x = block(x, m, l) + elif self.arch_type == "retnet": x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True) diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 5f3f234..878ff15 100755 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -91,8 +91,6 @@ class NAR(Base): proms_list, prev_list, targ_list, - return_all_resp=True, - shift_targ_list=False, quant_levels=quant_levels, ) @@ -112,8 +110,6 @@ class NAR(Base): text_list, proms_list, prev_list, - return_all_resp=True, - shift_targ_list=False, quant_levels=quant_levels, sampling_temperature=sampling_temperature, ) diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index e729cc1..ad29bd0 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -62,37 +62,37 @@ def load_engines(invert=False): optimizer = None lr_scheduler = None - # cfg.deepspeed.torch_adam - if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"): - params = { - "lr": cfg.hyperparameters.learning_rate, - "betas": (0.9, 0.96), - "eps": 1e-07, - "weight_decay": 0.01, - } - params.update(cfg.hyperparameters.optimizer_params) - optimizer = ml.AdamW( - model.parameters(), - **params, - ) - elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "sgd") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "sgd-torch"): - params = { - "lr": cfg.hyperparameters.learning_rate, - } - params.update(cfg.hyperparameters.optimizer_params) - optimizer = ml.SGD( - model.parameters(), - **params, - ) - elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "prodigy") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "prodigy-torch"): - params = { - "lr": cfg.hyperparameters.learning_rate, - } - params.update(cfg.hyperparameters.optimizer_params) - optimizer = ml.Prodigy( - model.parameters(), - **params, - ) + if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer): + if cfg.hyperparameters.optimizer.lower() == "adamw": + params = { + "lr": cfg.hyperparameters.learning_rate, + "betas": (0.9, 0.96), + "eps": 1e-07, + "weight_decay": 0.01, + } + params.update(cfg.hyperparameters.optimizer_params) + optimizer = ml.AdamW( + model.parameters(), + **params, + ) + elif cfg.hyperparameters.optimizer.lower() == "sgd": + params = { + "lr": cfg.hyperparameters.learning_rate, + } + params.update(cfg.hyperparameters.optimizer_params) + optimizer = ml.SGD( + model.parameters(), + **params, + ) + elif cfg.hyperparameters.optimizer.lower() == "prodigy": + params = { + "lr": cfg.hyperparameters.learning_rate, + } + params.update(cfg.hyperparameters.optimizer_params) + optimizer = ml.Prodigy( + model.parameters(), + **params, + ) if not model._cfg.training: optimizer = None