From c0b25541e38c905207ff4df5b97b1b828d341c41 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 20 Sep 2023 19:10:59 -0500
Subject: [PATCH] restructured some things with the model to remove dead weights

---
 setup.py                    |  2 +-
 vall_e/config.py            | 10 ++---
 vall_e/engines/base.py      | 38 +++++++++--------
 vall_e/engines/deepspeed.py | 16 ++++++-
 vall_e/inference.py         | 62 +++++++++++++--------------
 vall_e/models/ar.py         | 12 ++++++
 vall_e/models/ar_nar.py     | 12 ++++++
 vall_e/models/base.py       | 38 ++++++++++++-----
 vall_e/models/nar.py        | 12 ++++++
 vall_e/models/retnet.py     | 70 +----------------------------
 vall_e/utils/trainer.py     | 85 +++++++++++++------------------------
 11 files changed, 166 insertions(+), 191 deletions(-)

diff --git a/setup.py b/setup.py
index bff42ea..993c723 100755
--- a/setup.py
+++ b/setup.py
@@ -57,7 +57,7 @@ setup(
 		"auraloss[all]",
 		"vocos",
 		"h5py",
-		"torchscale @ git+https://github.com/microsoft/torchscale",
+		"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
 	],
 	url="https://git.ecker.tech/mrq/vall-e",
 )
diff --git a/vall_e/config.py b/vall_e/config.py
index 404d130..058cf24 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -157,16 +157,16 @@ class Dataset:
 @dataclass()
 class Model:
 	name: str = ""
+	version: int = 1
 	size: str | float | dict = "full"
 	resp_levels: int = 1
 	prom_levels: int = 8
-	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+	tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+	langs: int = 0 # defined languages
 	arch_type: str = "retnet"
 	training: bool = True
 	interleave: bool = False
 	frozen_params: list[str] = field(default_factory=lambda: [])
-	p_ar_nar: float = 0.5
-	version: int = 1
 
 	@property
 	def full_name(self):
@@ -240,8 +240,8 @@ class Models:
 	_prom_levels: int = 1
 
 	_models: list[Model] = field(default_factory=lambda: [
-		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
-		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
+		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
+		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
 	])
 
 	def get(self, name=None):
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index b630fb1..de29d80 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -59,10 +59,10 @@ class Engine():
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
-		self.global_steps = 0
-		self.micro_steps = 0
-		self.global_samples = 0
-		self.tokens_processed = 0
+		self.global_steps = kwargs.pop("global_steps", 0)
+		self.micro_steps = kwargs.pop("micro_steps", 0)
+		self.global_samples = kwargs.pop("global_samples", 0)
+		self.tokens_processed = kwargs.pop("tokens_processed", 0)
 
 		self._frozen_params = set()
 
@@ -117,10 +117,12 @@ class Engine():
 			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
 
-			"global_step": self.global_step,
-			"micro_step": self.micro_step,
-			"global_samples": self.global_samples,
-			"tokens_processed": self.tokens_processed,
+			"stats": {
+				"global_step": self.global_step,
+				"micro_step": self.micro_step,
+				"global_samples": self.global_samples,
+				"tokens_processed": self.tokens_processed,
+			}
 		}, save_path)
 
 		open(save_dir / "latest", 'w').write( tag )
@@ -137,10 +139,10 @@ class Engine():
 			return
 
 		state = torch.load(load_path, map_location=torch.device(cfg.device))
-		self.global_steps = state['global_step']
-		self.micro_steps = state['micro_step']
-		self.global_samples = state['global_samples']
-		self.tokens_processed = state['tokens_processed']
+		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
+		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
+		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
+		self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
 		self.module.load_state_dict(state['module'])
 
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
@@ -261,12 +263,14 @@ class Engines(dict[str, Engine]):
 			outpath = cfg.ckpt_dir / name / "fp32.pth"
 			state_dict = {
 				'module': engine.module.state_dict(),
-				"global_step": engine.global_step,
-				"micro_step": engine.micro_step,
-				"global_samples": engine.global_samples,
-				"tokens_processed": engine.tokens_processed,
+				"stats": {
+					"global_step": engine.global_step,
+					"micro_step": engine.micro_step,
+					"global_samples": engine.global_samples,
+					"tokens_processed": engine.tokens_processed,
+				},
+				"userdata": userdata
 			}
-			state_dict.update(userdata)
 			torch.save(state_dict, outpath)
 			print(f"Exported {name} to {outpath}")
 
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index c81838c..d04badc 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -39,10 +39,24 @@ class Engine(DeepSpeedEngine):
 		kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
 		kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
 
+		if "stats" in kwargs:
+			# stats COULD be = None
+			stats = kwargs.pop('stats')
+			if stats is None:
+				stats = {
+					"global_steps": 0,
+					"micro_steps": 0,
+					"global_samples": 0,
+					"tokens_processed": 0,
+				}
+
 		super().__init__(None, *args, **kwargs)
 		self._frozen_params = set()
 
-		self.tokens_processed = 0
+		self.global_steps = stats["global_steps"]
+		self.micro_steps = stats["micro_steps"]
+		self.global_samples = stats["global_samples"]
+		self.tokens_processed = stats["tokens_processed"]
 
 	def freeze(self, freeze_all=True):
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
diff --git a/vall_e/inference.py b/vall_e/inference.py
index c062b7d..1b5cb80 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -34,7 +34,7 @@ class TTS():
 		if amp is None:
 			amp = cfg.inference.amp
 		if dtype is None:
-			dtype = cfg.inference.dtype
+			dtype = cfg.inference.weight_dtype
 		if device is None:
 			device = cfg.device
 
@@ -50,43 +50,41 @@ class TTS():
 		self.amp = amp
 
 		self.symmap = None
+
+		def parse( name, model, state ):
+			if "userdata" in state and 'symmap' in state['userdata']:
+				self.symmap = state['userdata']['symmap']
+			elif "symmap" in state:
+				self.symmap = state['symmap']
+
+			if "module" in state:
+				state = state['module']
+
+			model.load_state_dict(state)
+			return model
+
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
 			self.nar_ckpt = nar_ckpt
 
 			models = get_models(cfg.models.get())
+
 			for name, model in models.items():
-				if name.startswith("ar+nar"):
-					self.ar = model
+				if name.startswith("ar"):
 					state = torch.load(self.ar_ckpt)
-					if "symmap" in state:
-						self.symmap = state['symmap']
-					if "module" in state:
-						state = state['module']
-					self.ar.load_state_dict(state)
-					self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
-					self.nar = self.ar
-				elif name.startswith("ar"):
-					self.ar = model
-					state = torch.load(self.ar_ckpt)
-					if "symmap" in state:
-						self.symmap = state['symmap']
-					if "module" in state:
-						state = state['module']
-					self.ar.load_state_dict(state)
-					self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+					self.ar = parse( name, model, state )
 				elif name.startswith("nar"):
-					self.nar = model
 					state = torch.load(self.nar_ckpt)
-					if "symmap" in state:
-						self.symmap = state['symmap']
-					if "module" in state:
-						state = state['module']
-					self.nar.load_state_dict(state)
-					self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+					self.nar = parse( name, model, state )
+
+				if name.startswith("ar+nar"):
+					self.nar = self.ar
 		else:
 			self.load_models()
 
+		self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+		self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+
 		if self.symmap is None:
 			self.symmap = get_phone_symmap()
 
@@ -98,13 +96,13 @@ class TTS():
 	def load_models( self ):
 		engines = load_engines()
 		for name, engine in engines.items():
-			if name[:6] == "ar+nar":
-				self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+			if name.startswith("ar"):
+				self.ar = engine.module
+			elif name.startswith("nar"):
+				self.nar = engine.module
+
+			if name.startswith("ar+nar"):
 				self.nar = self.ar
-			elif name[:2] == "ar":
-				self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
-			elif name[:3] == "nar":
-				self.nar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 
 	def encode_text( self, text, language="en" ):
 		# already a tensor, return it
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 3678de5..79122df 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -41,12 +41,24 @@ class AR(Base):
 	def n_tasks(self) -> int:
 		return cfg.models.tasks
 
+	@property
+	def n_langs(self) -> int:
+		return cfg.models.langs
+
 	@property
 	def recurrent_chunk_size(self) -> int:
 		if cfg.mode == "training":
 			return 0
 		return cfg.inference.recurrent_chunk_size
 
+	"""
+	@property
+	def rotary_embedding_base(self) -> float:
+		if hasattr(self, "config") and self.config:
+			return self.config.rotary_embedding_base
+		return cfg.models.ar.rotary_embedding_base
+	"""
+
 	@property
 	def interleave(self) -> bool:
 		if hasattr(self, "config") and self.config:
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 5579a13..f146934 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -47,6 +47,14 @@ class AR_NAR(Base):
 	def recurrent_chunk_size(self) -> int:
 		return 0
 
+	"""
+	@property
+	def rotary_embedding_base(self) -> float:
+		if hasattr(self, "config") and self.config:
+			return self.config.rotary_embedding_base
+		return cfg.models.ar_nar.rotary_embedding_base
+	"""
+
 	@property
 	def interleave(self) -> bool:
 		return False
@@ -293,6 +301,10 @@ def example_usage():
 	optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
 	engine = Engine(model=model, optimizer=optimizer)
 
+	torch.save( {
+		'module': model.state_dict()
+	}, "./data/test.pth" )
+
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.inference_mode()
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c535296..fb44745 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -235,11 +235,9 @@ class MultiEmbedding(nn.Module):
 
 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
 class AudioEmbedding(nn.Module):
-	def __init__(self, n_levels, n_tokens, token_dim):
+	def __init__(self, l_tokens, token_dim):
 		super().__init__()
-		self.n_levels = n_levels
-		# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
+		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 
 	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
 		res_list = []
@@ -283,6 +281,10 @@ class Base(nn.Module):
 	def n_max_levels(self) -> int:
 		raise NotImplementedError
 
+	@property
+	def n_langs(self) -> int:
+		raise NotImplementedError
+
 	@property
 	def n_tasks(self) -> int:
 		raise NotImplementedError
@@ -290,6 +292,10 @@ class Base(nn.Module):
 	@property
 	def recurrent_chunk_size(self) -> int:
 		raise NotImplementedError
+
+	@property
+	def rotary_embedding_base(self) -> float:
+		return 10000
 
 	@property
 	def interleave(self) -> bool:
@@ -341,17 +347,24 @@ class Base(nn.Module):
 		self.n_layers = n_layers
 
 		# +1 to include the stop token
-		n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
+		# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
+		n_prom_tokens = n_tokens
 		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
 
 		self.text_emb = Embedding(n_tokens, d_model)
 
 		if self.version == 1: # legacy
+			n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
 			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
 		else:
-			self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-			self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+			# [1024] * 8
+			self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model)
+			# [1025] + [1024] * 8
+			self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
+
+		# self.langs_emb = Embedding(self.n_langs, d_model)
+		# self.tasks_emb = Embedding(self.n_tasks, d_model)
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -365,7 +378,6 @@ class Base(nn.Module):
 				norm_type=self.norm_type,
 				n_levels=self.n_resp_levels,
 			) for _ in range(n_layers) ])
-
 		elif self.arch_type == "retnet":
 			self.retnet = RetNetDecoder(RetNetConfig(
 				vocab_size=n_tokens,
@@ -380,6 +392,8 @@ class Base(nn.Module):
 				recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
 				no_output_layer=True,
 				decoder_normalize_before=True,
+
+				rotary_embedding_base=self.rotary_embedding_base, # 10000
 			))
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -407,12 +421,17 @@ class Base(nn.Module):
 		resps_list: list[Tensor],
 		targ_list: list[Tensor] | None = None,
 
+		#langs_list: list[Tensor] | None = None,
+		#tasks_list: list[Tensor] | None = None,
+
 		quant_levels: Tensor | None = None,
 		state: dict | None = None,
 	):
 		x_list = self._samplewise_merge_tensors(
 			self.text_emb(text_list),
+			#self.langs_emb(langs_list),
 			self.proms_emb(proms_list),
+			#self.tasks_emb(tasks_list),
 			self.resps_emb(resps_list, quant_levels),
 			sep=self.sep,
 		)
@@ -422,7 +441,7 @@ class Base(nn.Module):
 		batch_size = len(text_list)
 		device = x.device
 
-		if state is not None:
+		if state is not None and self.arch_type == "retnet":
 			# prefill
 			if len(state) == 0:
 				prefill_size = x.shape[1]
@@ -443,7 +462,6 @@ class Base(nn.Module):
 			# pass our inputs through the transformer
 			for block in self.blocks:
 				x = block(x, m, l)
-
 		elif self.arch_type == "retnet":
 			# pass our inputs through the RetNet
 			x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 0372292..46a984c 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -39,6 +39,10 @@ class NAR(Base):
 	def n_tasks(self) -> int:
 		return cfg.models.tasks
 
+	@property
+	def n_langs(self) -> int:
+		return cfg.models.langs
+
 	@property
 	def version(self) -> int:
 		if hasattr(self, "config") and self.config:
@@ -49,6 +53,14 @@ class NAR(Base):
 	def recurrent_chunk_size(self) -> int:
 		return 0
 
+	"""
+	@property
+	def rotary_embedding_base(self) -> float:
+		if hasattr(self, "config") and self.config:
+			return self.config.rotary_embedding_base
+		return cfg.models.nar.rotary_embedding_base
+	"""
+
 	@property
 	def interleave(self) -> bool:
 		return False
diff --git a/vall_e/models/retnet.py b/vall_e/models/retnet.py
index 4ff6945..a0c4f7c 100755
--- a/vall_e/models/retnet.py
+++ b/vall_e/models/retnet.py
@@ -1,71 +1,3 @@
-"""
-# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
-# Copied directly because even having fairseq installed WILL break logging, why are corposhitters like this
-"""
-
-import uuid
-from typing import Dict, Optional
-
-from torch import Tensor
-
 from torchscale.architecture.config import RetNetConfig
 from torchscale.architecture.retnet import RetNetDecoder
-
-"""
-class FairseqIncrementalState(object):
-	def __init__(self, *args, **kwargs):
-		super().__init__(*args, **kwargs)
-		self.init_incremental_state()
-
-	def init_incremental_state(self):
-		self._incremental_state_id = str(uuid.uuid4())
-
-	def _get_full_incremental_state_key(self, key: str) -> str:
-		return "{}.{}".format(self._incremental_state_id, key)
-
-	def get_incremental_state(
-		self,
-		incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
-		key: str,
-	) -> Optional[Dict[str, Optional[Tensor]]]:
-		full_key = self._get_full_incremental_state_key(key)
-		if incremental_state is None or full_key not in incremental_state:
-			return None
-		return incremental_state[full_key]
-
-	def set_incremental_state(
-		self,
-		incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
-		key: str,
-		value: Dict[str, Optional[Tensor]],
-	) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
-		if incremental_state is not None:
-			full_key = self._get_full_incremental_state_key(key)
-			incremental_state[full_key] = value
-		return incremental_state
-
-
-def with_incremental_state(cls):
-	cls.__bases__ = (FairseqIncrementalState,) + tuple(
-		b for b in cls.__bases__ if b != FairseqIncrementalState
-	)
-	return cls
-
-
-from torchscale.architecture.config import RetNetConfig
-from torchscale.architecture.retnet import RetNetDecoder as Decoder
-
-@with_incremental_state
-class RetNetDecoder(Decoder):
-	def forward(self, src_tokens, **kwargs):
-		return super().forward(src_tokens, **kwargs)
-
-	def max_positions(self):
-		return self.args.max_token_positions
-
-	def reorder_incremental_state( self, incremental_state, new_order ):
-		for module in incremental_state:
-			for key in incremental_state[module]:
-				result = incremental_state[module][key].index_select(0, new_order)
-				incremental_state[module][key] = result
-"""
\ No newline at end of file
+# from retnet import RetNet
\ No newline at end of file
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index e16c004..ace03b3 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -64,84 +64,57 @@ def load_engines(invert=False):
 		lr_scheduler = None
 
 		if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+			optimizer_class = None
+			params = {
+				"lr": cfg.hyperparameters.learning_rate,
+			}
 			if cfg.hyperparameters.optimizer.lower() == "adamw":
-				params = {
-					"lr": cfg.hyperparameters.learning_rate,
-					"betas": (0.9, 0.96),
-					"eps": 1e-07,
-					"weight_decay": 0.01,
-				}
-				params.update(cfg.hyperparameters.optimizer_params)
-				optimizer = ml.AdamW(
-					[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-					**params,
-				)
+				params["betas"] = (0.9, 0.96)
+				params["eps"] = 1e-07
+				params["weight_decay"] = 0.01
+
+				optimizer_class = ml.AdamW
 			elif cfg.hyperparameters.optimizer.lower() == "sgd":
-				params = {
-					"lr": cfg.hyperparameters.learning_rate,
-				}
-				params.update(cfg.hyperparameters.optimizer_params)
-				optimizer = ml.SGD(
-					[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-					**params,
-				)
+				optimizer = ml.SGD
 			elif cfg.hyperparameters.optimizer.lower() == "prodigy":
-				params = {
-					"lr": cfg.hyperparameters.learning_rate,
-				}
-				params.update(cfg.hyperparameters.optimizer_params)
-				optimizer = ml.Prodigy(
-					[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-					**params,
-				)
+				optimizer_class = ml.Prodigy
+			else:
+				raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
+
+			params.update(cfg.hyperparameters.optimizer_params)
+			optimizer = optimizer_class(
+				[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+				**params,
+			)
+
+		# set up our LR scheduler here
 
 		if not model._cfg.training:
 			optimizer = None
 			lr_scheduler = None
 
+		stats = None
 		if cfg.trainer.load_state_dict or not model._cfg.training:
 			load_path = cfg.ckpt_dir / name / "fp32.pth"
 			state = torch.load(load_path, map_location=torch.device(cfg.device))
-			# exporting the model from the zero_to_fp32.py exports the actual module's dict
-			# exporting with vall_e.export exports the state dict under .module
+
+			# state dict is not just the module, extract the extra trainer details
+			if "stats" in state:
+				additionals = state["stats"]
+
 			if "module" in state:
 				state = state["module"]
-
-			# should decouple the following from this trainer script
-			# probably with passing a fun that defaults to a lambda x: x deal
-
-			"""
-			# can probably be done a lot more intelligently but oh well
-			# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
-			if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
-				o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
-
-				# copy weights from the dict into the old portion
-				model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
-				# copy the full tensors back
-				state['proms_emb.weight'] = model.proms_emb.weight
-
-			# extend the resps_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
-			if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
-				o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
-				n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape
-
-				# copy weights from the dict into the old portion
-				model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
-				# copy the full tensors back
-				state['resps_emb.weight'] = model.resps_emb.weight
-			"""
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
-		# use base engine because DeepSpeed memory leaks
+		# use base engine because DeepSpeed memory leaks if it's a non-training model
 		engines[name] = (Engine if model._cfg.training else _Engine)(
-		#engines[name] = Engine(
 			model=model,
 			optimizer=optimizer,
 			lr_scheduler=lr_scheduler,
 
 			_cfg=model._cfg,
+			stats=stats
 		)
 
 	engines = Engines(engines)
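
Illustrative note, not part of the patch: the export path in vall_e/engines/base.py now nests the trainer counters under a "stats" key and the export metadata under "userdata", while Engine.load_checkpoint() falls back to the old flat layout. A minimal sketch of reading either layout follows; the helper name load_stats is hypothetical and only mirrors the keys written above.

	# hypothetical helper, sketched against the layout written by Engines.export() in this patch
	import torch

	def load_stats(path, device="cpu"):
		state = torch.load(path, map_location=torch.device(device))
		# new layout: {"module": ..., "stats": {...}, "userdata": {...}}
		# old layout: {"module": ..., "global_step": ..., "micro_step": ..., ...}
		stats = state["stats"] if "stats" in state else state
		return {
			"global_step": stats.get("global_step", 0),
			"micro_step": stats.get("micro_step", 0),
			"global_samples": stats.get("global_samples", 0),
			"tokens_processed": stats.get("tokens_processed", 0),
		}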
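
Illustrative note, not part of the patch: the reworked AudioEmbedding in vall_e/models/base.py takes a list of per-level vocabulary sizes instead of (n_levels, n_tokens), so only the first (AR) response level has to carry the stop token. A minimal sketch, assuming 8 RVQ levels and a 1024-entry codebook plus one stop token; the standalone construction below is for illustration only.

	from torch import nn

	class AudioEmbedding(nn.Module):
		def __init__(self, l_tokens, token_dim):
			super().__init__()
			# one embedding table per RVQ level, each sized to that level's vocabulary
			self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])

	d_model = 1024
	proms_emb = AudioEmbedding([1024] * 8, d_model)           # prompt levels: no stop token anywhere
	resps_emb = AudioEmbedding([1025] + [1024] * 7, d_model)  # level 0 (AR) keeps the stop token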