diff --git a/vall_e/inference.py b/vall_e/inference.py
index f8a0b98..b7a683a 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -33,6 +33,7 @@ class TTS():
 			pass
 
 		cfg.mode = "inferencing"
+		cfg.trainer.load_module_only = True
 
 		self.symmap = None
 		if ar_ckpt and nar_ckpt:
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 7861253..850909c 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -150,6 +150,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file
 	from ..engines import Engine
 	from tqdm import tqdm
+	from ..utils import wrapper as ml
 
 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -183,16 +184,20 @@ def example_usage():
 		'n_heads': 16,
 		'n_layers': 24,
 	}
-	
+
+	"""
 	try:
 		kwargs['config'] = cfg.models.ar
 	except Exception as e:
 		pass
+	"""
 
 	model = AR(**kwargs).to(device)
-	engine = Engine(model=model, optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
+	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+	engine = Engine(model=model, optimizer=optimizer)
+	steps = 500
 
-	def sample( name, steps=400 ):
+	def sample( name, steps=600 ):
 		engine.eval()
 		out = engine(text_list, proms_list, max_steps=steps)
 		for i, o in enumerate(out):
@@ -200,7 +205,7 @@ def example_usage():
 
 	def train():
 		engine.train()
-		t = trange(60)
+		t = trange(steps)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 49ea8b9..c9288da 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -1,5 +1,4 @@
 from .base import Base, list_to_tensor, Categorical
-from ..utils import wrapper as ml
 from ..config import cfg
 
 import torch
@@ -173,6 +172,7 @@ def example_usage():
	from ..emb.qnt import decode_to_file, unload_model
 	from ..engines import Engine
 	from tqdm import tqdm
+	from ..utils import wrapper as ml
 
 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -215,8 +215,8 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
+	steps = 500
 	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-	#optimizer = ml.AdamW(model.parameters(), lr=0.0001)
 	engine = Engine(model=model, optimizer=optimizer)
 
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@@ -238,7 +238,7 @@ def example_usage():
 
 	def train():
 		engine.train()
-		t = trange(500)
+		t = trange(steps)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
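A note on the optimizer changes above: each example_usage() now builds its optimizer through the ..utils wrapper module (imported as ml) rather than torch.optim directly, and the toy training loops switch to Prodigy. A minimal usage sketch follows, assuming ml.Prodigy simply re-exports the prodigyopt package's optimizer; the toy model here is illustrative only, not repo code.

import torch
from prodigyopt import Prodigy  # pip install prodigyopt

model = torch.nn.Linear(16, 16)
# Prodigy estimates its own step size; `lr` acts as a multiplier on that
# estimate, and 1.0 is the recommended default -- which is presumably why
# the diff drops the hand-tuned SGD/AdamW learning rates in favour of lr=1.0.
optimizer = Prodigy(model.parameters(), lr=1.0)

# drop-in replacement for any torch.optim optimizer
loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()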
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 8a98426..e0aab7d 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -68,7 +68,9 @@ class MultiEmbedding(nn.Embedding):
 		self.n_tokens = n_tokens
 		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
 
-	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
+	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
+	# I imagine this is an oversight in the NAR.
+	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
 
@@ -151,9 +153,13 @@ class Base(nn.Module):
 		return False
 
 	@property
-	def n_embeddings(self):
+	def n_embeddings(self) -> int:
 		return self.n_resp_levels if self.monolithic else 1
 
+	@property
+	def use_old_embeddings(self) -> bool:
+		return True
+
 	@property
 	def stop_token(self):
 		if not self.causal:
@@ -199,14 +205,14 @@ class Base(nn.Module):
 
 		# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
 		# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
-		if self.n_embeddings == self.n_prom_levels:
+		if self.n_embeddings == self.n_prom_levels or not self.use_old_embeddings:
 			self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 		else:
 			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 
 		# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
 		# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
-		if self.n_embeddings > 1:
+		if self.n_embeddings > 1 or not self.use_old_embeddings:
 			self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
 		else:
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
@@ -409,6 +415,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file
 	from ..engines import Engine, Engines
 	from tqdm import tqdm, trange
+	from ..utils import wrapper as ml
 
 	from .ar import AR
 	from .nar import NAR
@@ -432,7 +439,7 @@ def example_usage():
 	for name, model in models.items():
 		print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
-	engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
+	engines = Engines({ name: Engine(model=model, optimizer=ml.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
 
 	train = True
 
@@ -449,7 +456,7 @@ def example_usage():
 		qnt.to(device),
 	]
 
-	def sample( name, steps=400 ):
+	def sample( name, steps=600 ):
 		AR = None
 		NAR = None
 
@@ -471,7 +478,7 @@ def example_usage():
 		sample("init", 15)
 
 	engines.train()
-	t = trange(60)
+	t = trange(500)
 	for i in t:
 		stats = {"step": i}
 		"""
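On the MultiEmbedding to-do above: the new quant_levels argument is accepted but not yet used. One possible reading of the to-do, sketched below with plain torch; the helper and its assumed shapes are hypothetical, not code from the repo. When a per-sample quant_levels tensor is supplied (the NAR path through resps_emb), each sequence would be embedded with just its own level's table instead of being summed across every level.

import torch
from torch import Tensor

# hypothetical helper: each x is assumed to be a 1-D token sequence
# already at its level l, so it is looked up in that level's table only
def embed_by_level(weight: Tensor, x_list: list[Tensor], quant_levels: Tensor) -> list[Tensor]:
	# weight: (max_n_levels, n_tokens, token_dim), as in MultiEmbedding;
	# quant_levels holds one RVQ-bin level index per sequence
	return [
		torch.nn.functional.embedding(x, weight[l])  # (t, token_dim)
		for x, l in zip(x_list, quant_levels)
	]

weight = torch.randn(4, 1024, 512)  # 4 levels, 1024 tokens, d_model 512
x_list = [torch.randint(0, 1024, (75,)), torch.randint(0, 1024, (60,))]
outs = embed_by_level(weight, x_list, torch.tensor([1, 3]))
assert outs[0].shape == (75, 512) and outs[1].shape == (60, 512)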
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index a1fe500..194d8cd 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -130,6 +130,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file
 	from ..engines import Engine
 	from tqdm import tqdm
+	from ..utils import wrapper as ml
 
 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -162,7 +163,9 @@ def example_usage():
 		'n_layers': 12,
 	}
 	model = NAR(**kwargs).to(device)
-	engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
+	steps = 500
+	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+	engine = Engine(model=model, optimizer=optimizer)
 
 	def sample( name ):
 		engine.eval()
@@ -171,7 +174,7 @@ def example_usage():
 
 	def train():
 		engine.train()
-		t = trange(60)
+		t = trange(steps)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index ad29bd0..c8995a4 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -43,21 +43,22 @@ def load_engines(invert=False):
 	engines = dict()
 
 	for name, model in models.items():
-		# load only the models for training initially
-		# loads disabled models at evaluation time (to load updated weights if training separately)
-		# I'm sure there's a more elegant solution to this
-		if cfg.evaluation.load_disabled_engines:
-			if not invert and not model._cfg.training:
-				continue
-			if invert and model._cfg.training:
-				continue
-		# load only the models for training initially
-		# if load_disabled_engines, then models not marked for training will be loaded but ignored
-		# DeepSpeed has some weird quirks where loading an engine and moving it to CPU will have a memory leak or something
-		# I recommend not using this pathway
-		elif not cfg.trainer.load_disabled_engines:
-			if model._cfg.training:
-				continue
+		if cfg.mode != "inferencing":
+			# load only the models for training initially
+			# loads disabled models at evaluation time (to load updated weights if training separately)
+			# I'm sure there's a more elegant solution to this
+			if cfg.evaluation.load_disabled_engines:
+				if not invert and not model._cfg.training:
+					continue
+				if invert and model._cfg.training:
+					continue
+			# load only the models for training initially
+			# if load_disabled_engines, then models not marked for training will be loaded but ignored
+			# DeepSpeed has some weird quirks where loading an engine and moving it to CPU will have a memory leak or something
+			# I recommend not using this pathway
+			elif not cfg.trainer.load_disabled_engines:
+				if model._cfg.training:
+					continue
 
 		optimizer = None
 		lr_scheduler = None
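Read together with the inference.py hunk at the top of this diff: TTS.__init__ now sets both cfg.mode = "inferencing" and cfg.trainer.load_module_only = True, and the new guard above bypasses all of the training-only filtering when inferencing, so every defined model is loaded (module weights only, no optimizer state). A condensed sketch of the resulting guard follows, using a stand-in cfg object; the names are hypothetical stand-ins, the real checks live in load_engines against vall_e's config.

from types import SimpleNamespace

# stand-in for vall_e.config.cfg after TTS.__init__ has run
cfg = SimpleNamespace(
	mode="inferencing",
	trainer=SimpleNamespace(load_module_only=True, load_disabled_engines=False),
)

def skip_engine(model_training: bool) -> bool:
	# mirrors the new outer guard: the filters above only apply
	# when the run is not inferencing
	if cfg.mode != "inferencing":
		if not cfg.trainer.load_disabled_engines and model_training:
			return True  # filtered out, as in the block above
	return False  # inferencing: no engine is filtered out

assert not skip_engine(True) and not skip_engine(False)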