From 100ca6b7d089c077bc2aaff6a967c8807adc585d Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 6 Sep 2023 18:58:35 -0500
Subject: [PATCH] added option to use SGD optimizer through the YAML, added
 option to pass in additional optimizer parameters through the YAML, added
 experimental unified AR+NAR model (does not seem fruitful in testing)

---
 vall_e/config.py          |   5 +
 vall_e/models/__init__.py |   3 +
 vall_e/models/ar.py       |  23 ++--
 vall_e/models/ar_nar.py   | 258 ++++++++++++++++++++++++++++++++++++++
 vall_e/models/base.py     |  92 +++++---------
 vall_e/models/nar.py      |   8 --
 vall_e/utils/trainer.py   |  25 ++--
 vall_e/utils/wrapper.py   |  10 +-
 8 files changed, 330 insertions(+), 94 deletions(-)
 create mode 100644 vall_e/models/ar_nar.py

diff --git a/vall_e/config.py b/vall_e/config.py
index fb441af..f90095f 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -254,6 +254,10 @@ class Models:
 	def ar(self):
 		return self.get("ar")
 
+	@property
+	def ar_nar(self):
+		return self.get("ar+nar")
+
 	@property
 	def nar(self):
 		return self.get("nar")
@@ -283,6 +287,7 @@ class Hyperparameters:
 	gradient_clipping: int = 100
 
 	optimizer: str = "Adamw"
+	optimizer_params: dict = field(default_factory=lambda: {})
 	learning_rate: float = 3.25e-4
 
 	scheduler_type: str = ""
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index b6983c1..e9728ec 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -1,11 +1,14 @@
 from .ar import AR
 from .nar import NAR
+from .ar_nar import AR_NAR
 
 def get_model(cfg):
 	if cfg.name == "ar":
 		Model = AR
 	elif cfg.name == "nar":
 		Model = NAR
+	elif cfg.name == "ar+nar":
+		Model = AR_NAR
 	else:
 		raise f"invalid model name: {cfg.name}"
 	name = cfg.name
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index b2b7baa..1c15263 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -13,10 +13,6 @@ class AR(Base):
 	def causal(self):
 		return True
 
-	@property
-	def use_stop_token(self):
-		return True
-
 	@property
 	def norm_type(self):
 		return "ln"
@@ -45,10 +41,6 @@ class AR(Base):
 	def n_tasks(self) -> int:
 		return cfg.models.tasks
 
-	@property
-	def resp_loss_only(self) -> bool:
-		return False
-
 	@property
 	def recurrent_chunk_size(self) -> int:
 		if cfg.mode == "training":
@@ -103,8 +95,6 @@ class AR(Base):
 			resps_list=self._unsqueeze_list(resps_list),
 			targ_list=resps_list,
 			quant_levels=None,
-			shift_targ_list=True,
-			return_all_resp=False,
 		)
 
 		device = text_list[0].device
@@ -122,9 +112,10 @@ class AR(Base):
 
 			# get next in sequence
 			r = super().forward(
-				text_list,
-				proms_list,
-				self._unsqueeze_list(resps_list),
+				text_list=text_list,
+				proms_list=proms_list,
+				resps_list=self._unsqueeze_list(resps_list),
+				quant_levels=None,
 				sampling_temperature=sampling_temperature,
 				state=state
 			)
@@ -188,12 +179,14 @@ def example_usage():
 		'n_heads': 16,
 		'n_layers': 24,
 	}
+
 	try:
 		kwargs['config'] = cfg.models.ar
 	except Exception as e:
-		pass
+		pass
+
 	model = AR(**kwargs).to(device)
-	engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
+	engine = Engine(model=model, optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
 
 	def sample( name, steps=400 ):
 		engine.eval()
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
new file mode 100644
index 0000000..0ce24e4
--- /dev/null
+++ b/vall_e/models/ar_nar.py
@@ -0,0 +1,258 @@
+from ..config import cfg
+from .base import Base, list_to_tensor, Categorical
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+import random
+from einops import rearrange
+from torch import Tensor
+from tqdm import trange
+
+class AR_NAR(Base):
+	@property
+	def causal(self):
+		return True
+
+	@property
+	def norm_type(self):
+		return "ln"
+
+	@property
+	def arch_type(self) -> str:
+		if hasattr(self, "config") and self.config:
+			return self.config.arch_type
+		return cfg.models.ar_nar.arch_type
+
+	@property
+	def n_prom_levels(self) -> int:
+		return cfg.models.prom_levels
+
+	@property
+	def n_resp_levels(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.resp_levels
+		return cfg.models.ar_nar.resp_levels
+
+	@property
+	def n_max_levels(self) -> int:
+		return cfg.models.max_levels
+
+	@property
+	def n_tasks(self) -> int:
+		return cfg.models.tasks
+
+	@property
+	def recurrent_chunk_size(self) -> int:
+		if cfg.mode == "training":
+			return 0
+		return cfg.inference.recurrent_chunk_size
+
+	@property
+	def interleave(self) -> bool:
+		if hasattr(self, "config") and self.config:
+			return self.config.interleave
+		return False
+
+	def _prune(self, l: Tensor):
+		indices = (l == self.stop_token).nonzero()
+		if len(indices) == 0:
+			return l
+		return l[: indices.min().item()]
+
+	def _interleave( self, codes ):
+		if not self.interleave:
+			return codes
+
+		return codes.flatten()
+
+	def _deinterleave( self, codes, length = 0 ):
+		if not self.interleave:
+			return codes
+
+		return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
+
+	@staticmethod
+	def _unsqueeze_list(x_list, axis=-1):
+		return [x.unsqueeze(dim=axis) for x in x_list]
+
+	def forward(
+		self,
+		text_list: list[Tensor],
+		proms_list: list[Tensor],
+		resps_list: list[Tensor] | None = None,
+		max_steps: int = 1000,
+		sampling_temperature: float = 0.0,
+	):
+		device = text_list[0].device
+		batch_size = len(text_list)
+
+		# is training or NAR
+		if resps_list is not None:
+			n_levels_set = {r.shape[-1] for r in resps_list}
+			n_levels = next(iter(n_levels_set))
+
+			# is training
+			if n_levels == self.n_resp_levels:
+				if random.random() < 0.5:
+					quant_levels = None
+
+					targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
+					resps_list = self._unsqueeze_list(targ_list)
+				else:
+					quant_levels = torch.randint(1, self.n_resp_levels, (batch_size,))
+
+					targ_list = [o[..., l] for o, l in zip(resps_list, quant_levels)]
+					resps_list = [o[..., : l] for o, l in zip(resps_list, quant_levels)]
+
+				if quant_levels is not None:
+					quant_levels = quant_levels.to(device=device) # Tensor.to() is not in-place, so keep the returned tensor
+
+				return super().forward(
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+					targ_list=targ_list,
+					quant_levels=quant_levels,
+				)
+			# is NAR
+			prev_list = resps_list
+
+			while True:
+				level = prev_list[0].shape[-1] - 1
+
+				if level >= self.n_resp_levels:
+					break
+
+				quant_levels = torch.full((len(text_list),), level, device=device)
+
+				resps_list = super().forward(
+					text_list,
+					proms_list,
+					prev_list,
+					quant_levels=quant_levels,
+					sampling_temperature=sampling_temperature,
+				)
+
+				prev_list = [
+					torch.cat([rs, r.unsqueeze(-1)], dim=-1)
+					for rs, r in zip(prev_list, resps_list)
+				]
+
+			return prev_list
+
+		# is AR
+		resps_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
+		stopped = torch.zeros(batch_size, device=device).bool()
+
+		state = {} if cfg.inference.recurrent_forward else None
+
+		if self.interleave:
+			max_steps *= self.n_prom_levels
+
+		for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
+			# get next in sequence
+
+			r = super().forward(
+				text_list,
+				proms_list,
+				self._unsqueeze_list(resps_list),
+				sampling_temperature=sampling_temperature,
+				state=state
+			)
+
+			# append tokens
+			for i, ri in enumerate(r):
+				if self.stop_token in ri:
+					stopped[i] = True
+				resps_list[i] = torch.cat([resps_list[i], ri])
+
+			# stop token found
+			stopped |= r == self.stop_token
+			if stopped.all().item():
+				break
+
+		return [self._prune(r) for r in resps_list]
+
+
+def example_usage():
+	cfg.trainer.backend = "local"
+	from functools import partial
+
+	from einops import repeat
+
+	from ..emb.qnt import decode_to_file
+	from ..engines import Engine
+	from tqdm import tqdm
+
+	device = "cuda"
+	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
+	symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
+	def tokenize(content, lang_marker="en"):
+		split = content.split(" ")
+		phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
+		return torch.tensor([*map(symmap.get, phones)]).to()
+
+	qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
+
+	text_list = [
+		#torch.tensor([1, 2, 3], device=device),
+		tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
+	]
+	proms_list = [
+		#x8(torch.tensor([1, 2, 3], device=device)),
+		qnt.to(device),
+	]
+	resps_list = [
+		qnt.to(device),
+	]
+
+	text_list = text_list[:1]
+	proms_list = proms_list[:1]
+	resps_list = resps_list[:1]
+
+	kwargs = {
+		'n_tokens': 1024,
+		'd_model': 1024,
+		'n_heads': 16,
+		'n_layers': 24,
+	}
+
+	"""
+	try:
+		kwargs['config'] = cfg.models.ar_nar
+	except Exception as e:
+		pass
+	"""
+
+	model = AR_NAR(**kwargs).to(device)
+	engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=0.001))
+
+	def sample( name, steps=600 ):
+		engine.eval()
+		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+
+		for i, o in enumerate(resps_list):
+			_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
+
+		resps_list = [r.unsqueeze(-1) for r in resps_list]
+		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+
+		for i, o in enumerate(resps_list):
+			_ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)
+
+	def train():
+		engine.train()
+		t = trange(5000)
+		for i in t:
+			stats = {"step": i}
+			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+
+			tqdm.write(f"{stats}")
+
+	sample("init", 75)
+	train()
+	sample("final")
+
+if __name__ == "__main__":
+	example_usage()
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index ebd7018..7c9f48d 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -94,14 +94,6 @@ class Base(nn.Module):
 	def causal(self) -> bool:
 		raise NotImplementedError
 
-	@property
-	def n_resp_levels(self) -> int:
-		raise NotImplementedError
-
-	@property
-	def use_stop_token(self) -> bool:
-		raise NotImplementedError
-
 	@property
 	def arch_type(self) -> str:
 		raise NotImplementedError
@@ -114,6 +106,10 @@ class Base(nn.Module):
 	def n_prom_levels(self) -> int:
 		raise NotImplementedError
 
+	@property
+	def n_resp_levels(self) -> int:
+		raise NotImplementedError
+
 	@property
 	def n_max_levels(self) -> int:
 		raise NotImplementedError
@@ -122,10 +118,6 @@ class Base(nn.Module):
 	def n_tasks(self) -> int:
 		raise NotImplementedError
 
-	@property
-	def resp_loss_only(self):
-		raise NotImplementedError
-
 	@property
 	def recurrent_chunk_size(self) -> int:
 		raise NotImplementedError
@@ -134,6 +126,24 @@ class Base(nn.Module):
 	def interleave(self) -> bool:
 		return False
 
+	@property
+	def stop_token(self):
+		if not self.causal:
+			raise ValueError("Not using stop token!")
+		return self.n_tokens
+
+	@property
+	def ignore_index(self):
+		return -100
+
+	@staticmethod
+	def _samplewise_merge_tensors(*l, sep: Tensor | None):
+		if sep is None:
+			cat = torch.cat
+		else:
+			cat = partial(_join, sep=sep)
+		return [*map(cat, zip(*l))]
+
 	def __init__(
 		self,
 		n_tokens: int = 1024,
@@ -155,7 +165,7 @@ class Base(nn.Module):
 
 		# +1 to include the stop token
 		n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
-		n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop
+		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
 
 		self.text_emb = Embedding(n_tokens, d_model)
 		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
@@ -208,24 +218,6 @@ class Base(nn.Module):
 			ignore_index=self.ignore_index,
 		)
 
-	@property
-	def stop_token(self):
-		if not self.use_stop_token:
-			raise ValueError("Not using stop token!")
-		return self.n_tokens
-
-	@property
-	def ignore_index(self):
-		return -100
-
-	@staticmethod
-	def _samplewise_merge_tensors(*l, sep: Tensor | None):
-		if sep is None:
-			cat = torch.cat
-		else:
-			cat = partial(_join, sep=sep)
-		return [*map(cat, zip(*l))]
-
 	@overload
 	def forward(
 		self,
 		text_list: list[Tensor],
 		proms_list: list[Tensor],
 		resps_list: list[Tensor],
 		targ_list: list[Tensor] | None = None,
 		quant_levels: Tensor | None = None,
-		shift_targ_list: bool = False,
-		return_all: Literal[False] = False,
-		return_all_resp: Literal[False] = False,
 		sampling_temperature: float = 1.0,
 	) -> Tensor:
 		...
@@ -249,9 +238,6 @@ class Base(nn.Module):
 		resps_list: list[Tensor],
 		targ_list: list[Tensor] | None = None,
 		quant_levels: Tensor | None = None,
-		shift_targ_list: bool = False,
-		return_all: Literal[True] = True,
-		return_all_resp: Literal[True] = True,
 		sampling_temperature: float = 1.0,
 	) -> list[Tensor]:
 		...
@@ -262,28 +248,12 @@ class Base(nn.Module):
 		proms_list: list[Tensor],
 		resps_list: list[Tensor],
 		targ_list: list[Tensor] | None = None,
+		quant_levels: Tensor | None = None,
-		shift_targ_list: bool = False,
-		return_all: bool = False,
-		return_all_resp: bool = False,
 		sampling_temperature: float = 1.0,
 		state: dict | None = None,
 	):
-		"""
-		Args:
-			text_list: [t] * b
-			proms_list: [t' l] * b, l quantization levels.
-			resps_list: [t'' l] * b, l quantization levels.
-			targ_list: [t''] * b, one quantization level only; when given, loss will be computed
-			quant_levels: specify which quant_levels to feed forward, used in NAR mode.
-			shift_targ_list: whether to shift target list when computing loss. True if AR.
-			return_all_resp: True if NAR.
-			sampling_temperature: a lower temperature makes the result more robust but less diverse.
-		Returns:
-			y: sampled tokens
-		"""
-
 		x_list = self._samplewise_merge_tensors(
 			self.text_emb(text_list),
 			self.proms_emb(proms_list),
@@ -334,17 +304,16 @@ class Base(nn.Module):
 
 		# process each batch
 		for i in range(len(text_prom_list)):
-			# for the NAR, ignore completely computing the loss against the text prompt
-			if self.resp_loss_only:
-				text_prom_list[i][:] = self.ignore_index
-			# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
-			else:
+			if quant_levels is None:
 				text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
 				text_prom_list[i][-1] = self.ignore_index
+			# for the NAR, ignore completely computing the loss against the text prompt
+			else:
+				text_prom_list[i][:] = self.ignore_index
 
 		# adjust the target sequence if needed for the AR
-		if shift_targ_list:
+		if quant_levels is None:
 			# creates a copy because this is aliased against input response sequence
 			targ_list = [*targ_list]
 			# shift the target response into the future by 1, and mark the rolled back token / last token as a stop token
@@ -370,10 +339,11 @@ class Base(nn.Module):
 		)
 
 		# return the entire generated token string
+		return_all = False
 		if return_all:
 			logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
 		# return the entire generated response
-		elif return_all_resp:
+		elif quant_levels is not None:
 			logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
 		# return the last chunkwise piece
 		elif self.causal and self.recurrent_chunk_size > 0:
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 43a3078..2409c7b 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -11,10 +11,6 @@ class NAR(Base):
 	def causal(self):
 		return False
 
-	@property
-	def use_stop_token(self):
-		return False
-
 	@property
 	def arch_type(self) -> str:
 		if hasattr(self, "config") and self.config:
@@ -43,10 +39,6 @@ class NAR(Base):
 	def n_tasks(self) -> int:
 		return cfg.models.tasks
 
-	@property
-	def resp_loss_only(self) -> bool:
-		return True
-
 	@property
 	def recurrent_chunk_size(self) -> int:
 		return 0
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 10fcca5..8c9a945 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -62,16 +62,27 @@ def load_engines(invert=False):
 		optimizer = None
 		lr_scheduler = None
 
-		# yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
-		# and then have ds_cfg pass in the config flag to use pytorch adamw
-		# I genuinely cannot validate if this ever actually gets used in DeepSpeed
+		# cfg.deepspeed.torch_adam
 		if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"):
+			params = {
+				"lr": cfg.hyperparameters.learning_rate,
+				"betas": (0.9, 0.96),
+				"eps": 1e-07,
+				"weight_decay": 0.01,
+			}
+			params.update(cfg.hyperparameters.optimizer_params)
 			optimizer = ml.AdamW(
 				model.parameters(),
-				lr=cfg.hyperparameters.learning_rate,
-				betas=(0.9, 0.96),
-				eps=1e-07,
-				weight_decay=0.01,
+				**params,
+			)
+		elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "sgd") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "sgd-torch"):
+			params = {
+				"lr": cfg.hyperparameters.learning_rate,
+			}
+			params.update(cfg.hyperparameters.optimizer_params)
+			optimizer = ml.SGD(
+				model.parameters(),
+				**params,
 			)
 
 		if not model._cfg.training:
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 040762d..bbdbf8a 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -25,14 +25,17 @@ if cfg.bitsandbytes.enabled:
 				self.sparse,
 			)).to(self.weight.dtype) )
 
-Adam = torch.optim.Adam
-AdamW = torch.optim.AdamW
 
 if cfg.bitsandbytes.enabled:
 	import bitsandbytes as bnb
 
 	Adam = bnb.optim.Adam
 	AdamW = bnb.optim.AdamW
+	SGD = bnb.optim.SGD
+else:
+	Adam = torch.optim.Adam
+	AdamW = torch.optim.AdamW
+	SGD = torch.optim.SGD
 
 # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
 @contextmanager
@@ -72,4 +75,5 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 	torch.nn.Embedding = Embedding
 
 	torch.optim.Adam = Adam
-	torch.optim.AdamW = AdamW
\ No newline at end of file
+	torch.optim.AdamW = AdamW
+	torch.optim.SGD = SGD
\ No newline at end of file
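
Usage sketch: with this patch the optimizer choice and its extra constructor arguments come entirely from the YAML. A minimal illustrative block, assuming the top-level hyperparameters key maps onto the Hyperparameters dataclass above and with momentum chosen purely as an example extra kwarg (anything placed under optimizer_params is forwarded verbatim to the optimizer constructor):

    hyperparameters:
      optimizer: sgd            # or "sgd-torch" when training under the deepspeed backend
      learning_rate: 1.0e-2
      optimizer_params:
        momentum: 0.9           # illustrative only; passed straight through to ml.SGD

Because load_engines() builds its defaults first and only then calls params.update(cfg.hyperparameters.optimizer_params), any key given here overrides the hard-coded defaults (lr, and for AdamW also betas/eps/weight_decay). The experimental unified model is selected the same way as the existing ones: a model entry named "ar+nar" makes get_model() return AR_NAR, and it is exposed on the config as cfg.models.ar_nar.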