diff --git a/vall_e/config.py b/vall_e/config.py
index 2afac34..c3b69b0 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -162,7 +162,7 @@ class Model:
 	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
 	arch_type: str = "transformer"
 	training: bool = True
-	interleave_pattern: str | None = None
+	interleave: bool = False
 
 	@property
 	def full_name(self):
@@ -174,6 +174,9 @@ class Model:
 		if self.arch_type != "transformer":
 			name.append(self.arch_type.replace("/", "-"))
 
+		if self.interleave:
+			name.append("interleaved")
+
 		name.append(f'{cfg.models.prom_levels}')
 
 		return "-".join(name)
@@ -228,8 +231,8 @@ class Models:
 	_prom_levels: int = 1
 
 	_models: list[Model] = field(default_factory=lambda: [
-		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True),
-		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True),
+		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
+		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
 	])
 
 	def get(self, name=None):
diff --git a/vall_e/ext/interleaver.py b/vall_e/ext/interleaver.py
deleted file mode 100644
index f7d9590..0000000
--- a/vall_e/ext/interleaver.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# From: https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py
-# audiocraft has heavy dependencies, so it doesn't make sense to depend on it just for this file.
\ No newline at end of file
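Note: `interleave_pattern` previously selected one of audiocraft's codebook-pattern providers; with the vendored stub removed, the only supported pattern is a full flatten, so the field collapses into a boolean. A minimal sketch of the resulting config surface — the `Model` fields are trimmed to the ones touched here, and the `full_name` output is inferred from the hunk above:

    from dataclasses import dataclass

    @dataclass
    class Model:
        name: str = "ar"
        resp_levels: int = 1
        prom_levels: int = 8
        interleave: bool = False  # flatten every codebook level into one AR token stream

    # With the flag set, full_name picks up an "interleaved" component, so
    # interleaved and non-interleaved checkpoints no longer collide:
    #   Model(name="ar", interleave=True) -> "ar-interleaved-8"
    # (the trailing 8 being cfg.models.prom_levels)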
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 363c94e..b2b7baa 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -23,8 +23,8 @@ class AR(Base):
 
 	@property
 	def arch_type(self) -> str:
-		if hasattr(self, "_cfg") and self._cfg:
-			return self._cfg.arch_type
+		if hasattr(self, "config") and self.config:
+			return self.config.arch_type
 		return cfg.models.ar.arch_type
 
 	@property
@@ -33,8 +33,8 @@
 
 	@property
 	def n_resp_levels(self) -> int:
-		if hasattr(self, "_cfg") and self._cfg:
-			return self._cfg.resp_levels
+		if hasattr(self, "config") and self.config:
+			return self.config.resp_levels
 		return cfg.models.ar.resp_levels
 
 	@property
@@ -55,12 +55,30 @@ class AR(Base):
 			return 0
 		return cfg.inference.recurrent_chunk_size
 
+	@property
+	def interleave(self) -> bool:
+		if hasattr(self, "config") and self.config:
+			return self.config.interleave
+		return False
+
 	def _prune(self, l: Tensor):
 		indices = (l == self.stop_token).nonzero()
 		if len(indices) == 0:
 			return l
 		return l[: indices.min().item()]
 
+	def _interleave( self, codes ):
+		if not self.interleave:
+			return codes
+
+		return codes.flatten()
+
+	def _deinterleave( self, codes, length = 0 ):
+		if not self.interleave:
+			return codes
+
+		return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
+
 	@staticmethod
 	def _unsqueeze_list(x_list, axis=-1):
 		return [x.unsqueeze(dim=axis) for x in x_list]
@@ -74,7 +92,10 @@
 		sampling_temperature: float = 1.0,
 	):
 		if resps_list is not None:
-			resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
+			if self.interleave:
+				resps_list = [self._interleave(r) for r in resps_list]
+			else:
+				resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
 
 		return super().forward(
 			text_list=text_list,
@@ -94,6 +115,9 @@
 
 		state = {} if cfg.inference.recurrent_forward else None
 
+		if self.interleave:
+			max_steps *= self.n_prom_levels
+
 		for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
 			# get next in sequence
 
@@ -116,9 +140,10 @@
 
 			if stopped.all().item():
 				break
-
-		pruned = [self._prune(r) for r in resps_list]
-		return pruned
+		res = [self._prune(r) for r in resps_list]
+		if self.interleave:
+			res = [self._deinterleave(r) for r in res]
+		return res
 
 
 def example_usage():
@@ -163,6 +188,10 @@ def example_usage():
 		'n_heads': 16,
 		'n_layers': 24,
 	}
+	try:
+		kwargs['config'] = cfg.models.ar
+	except Exception as e:
+		pass
 
 	model = AR(**kwargs).to(device)
 	engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
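The new `_interleave`/`_deinterleave` pair is the entire pattern: a row-major flatten from [T, L] (frames by codebook levels) to a single [T * L] token stream, and the inverse, which first trims any trailing partial frame. A standalone round-trip sketch, with `levels` standing in for `self.n_prom_levels`:

    import torch

    levels = 8  # stands in for self.n_prom_levels

    def interleave(codes: torch.Tensor) -> torch.Tensor:
        # [T, L] -> [T * L]; row-major, so each frame's L levels stay adjacent
        return codes.flatten()

    def deinterleave(codes: torch.Tensor) -> torch.Tensor:
        # drop any trailing partial frame, then [T * L] -> [T, L]
        trimmed = codes[: codes.shape[0] // levels * levels]
        return torch.unflatten(trimmed, 0, (trimmed.shape[0] // levels, levels))

    codes = torch.arange(4 * levels).reshape(4, levels)  # 4 frames, 8 levels
    assert torch.equal(deinterleave(interleave(codes)), codes)

Because each frame now costs `levels` autoregressive steps, the inference loop above scales `max_steps` by `n_prom_levels` so the effective frame budget stays unchanged.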
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c74bace..7eefaeb 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -129,6 +129,10 @@ class Base(nn.Module):
 	@property
 	def recurrent_chunk_size(self) -> int:
 		raise NotImplementedError
+
+	@property
+	def interleave(self) -> bool:
+		return False
 
 	def __init__(
 		self,
@@ -137,8 +141,11 @@
 		n_heads: int = 8,
 		n_layers: int = 12,
 		p_dropout: float = 0.1,
+		config = None,
 	):
 		super().__init__()
+		self.config = config
+
 		self.n_tokens = n_tokens
 		self.d_model = d_model
 		self.n_heads = n_heads
diff --git a/vall_e/models/interleaved_ar.py b/vall_e/models/interleaved_ar.py
index 3d78576..d2644cf 100644
--- a/vall_e/models/interleaved_ar.py
+++ b/vall_e/models/interleaved_ar.py
@@ -135,7 +135,7 @@
 
 	@property
 	def n_resp_levels(self) -> int:
-		return 4
+		return 1
 
 	@property
 	def n_max_levels(self) -> int:
@@ -155,7 +155,7 @@
 
 	@property
 	def interleave_pattern(self) -> str | None:
-		return "musiclm"
+		return "flatten"
 
 	@property
 	def stop_token(self):
@@ -192,27 +192,12 @@ class Base(nn.Module):
 			return codes
 		return codes.flatten()
 
-		"""
-		pattern_provider = _get_pattern_provider( self.interleave_pattern )( self.n_resp_levels )
-		pattern = pattern_provider.get_pattern( codes.shape[0] )
-		res, _, _ = pattern.build_pattern_sequence( codes.t()[None, :, :], self.interleaved_token, keep_only_valid_steps=True )
-		return res[0].t().flatten()
-		"""
-	def _deinterleave( self, codes ):
+	def _deinterleave( self, codes, length = 0 ):
 		if not self.interleave_pattern:
 			return codes
-		return torch.unflatten( codes[:codes.shape[0] // self.n_resp_levels * self.n_resp_levels], 0, ( codes.shape[0] // self.n_resp_levels, self.n_resp_levels ) )
-		"""
-		if codes.dim() == 1:
-			codes = torch.unflatten( codes[:codes.shape[0] // self.n_resp_levels * self.n_resp_levels], 0, ( codes.shape[0] // self.n_resp_levels, self.n_resp_levels ) )
-
-		pattern_provider = _get_pattern_provider( self.interleave_pattern )( self.n_resp_levels )
-		pattern = pattern_provider.get_pattern( codes.shape[0] )
-		res, _, _ = pattern.revert_pattern_sequence( codes, special_token=self.interleaved_token)
-		return res[0].t()
-		"""
+
+		return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
 
 	def __init__(
 		self,
@@ -232,13 +217,13 @@
 		self.n_layers = n_layers
 
 		# + tasks for each token they represent in the prom
-		n_prom_tokens = n_tokens + (self.n_tasks - 1) + (1 if self.interleave_pattern else 0) # - 1 because tts is an inherent task
+		n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
 		# +1 to include the stop token + 1 to include interleave token
-		n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) + (1 if self.interleave_pattern else 0) # AR requires a stop token to... know when to stop
+		n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop
 
 		self.text_emb = Embedding(n_tokens, d_model)
 		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-		self.resps_emb = MultiEmbedding(1, n_resp_tokens, d_model)
+		self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -270,7 +255,6 @@
 		))
 
 		# I imagine because each step returns `resp_level`s tokens at once, so we need to have a classifier for each level
-		#self.classifier = nn.ModuleList([ nn.Linear(d_model, n_resp_tokens) for _ in range(self.n_resp_levels) ]) if self.interleave_pattern else nn.Linear(d_model, n_resp_tokens)
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
 
 		self.accuracy_metric = MulticlassAccuracy(
@@ -385,11 +369,6 @@
 		# Remove padding
 		h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]
 
-		if True:
-			logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
-			ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
-			print( [ r for r in ret ] )
-
 		# compute loss if the target is given
 		if targ_list is not None:
 			if any([l == 0 for l in map(len, targ_list)]):
@@ -487,6 +466,8 @@
 
 		state = {} if cfg.inference.recurrent_forward else None
 
+		max_steps *= self.n_prom_levels
+
 		for n in range(max_steps // max(1, self.recurrent_chunk_size)):
 			# get next in sequence
 
@@ -502,6 +483,7 @@
 
 			for i, ri in enumerate(r):
 				if self.stop_token in ri:
 					stopped[i] = True
+				resps_list[i] = torch.cat([resps_list[i], ri])
 
 			# stop token found
@@ -509,12 +491,7 @@
 
 			if stopped.all().item():
 				break
-
-		pruned = [self._prune(r) for r in resps_list]
-		print( [ r for r in pruned ] )
-		deinterleaved = [ self._deinterleave(r) for r in pruned ]
-		print( [ r for r in deinterleaved ] )
-		return deinterleaved
+		return [self._deinterleave(self._prune(r)) for r in resps_list]
 
 def example_usage():
 	from ..config import cfg
@@ -548,7 +525,7 @@ def example_usage():
 	for name, model in models.items():
 		print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
-	engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
+	engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=5e-5)) for name, model in models.items() })
 
 	train = True
 
@@ -565,7 +542,7 @@ def example_usage():
 		qnt.to(device),
 	]
 
-	def sample( filename, steps=450 * 4 ):
+	def sample( filename, steps=450 ):
 		AR = None
 
 		engines.eval()
@@ -578,10 +555,10 @@ def example_usage():
 		decode_to_file(resps_list[0].cpu(), f"./data/{filename}.wav", device="cpu")
 
 	if train:
-		sample("init", 15)
+		sample("init", 75 )
 
 		engines.train()
-		t = trange(100)
+		t = trange(500)
 		for i in t:
 			stats = {"step": i}
 			"""
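In the flattened stream the stop token can land mid-frame, so the decode loop prunes at the first stop token and `_deinterleave` then trims to a whole number of frames before unflattening. A toy sketch of that tail handling; the `stop` and `levels` values are illustrative:

    import torch

    stop, levels = 1024, 8

    def prune(l: torch.Tensor) -> torch.Tensor:
        # cut at the first stop token, as _prune does
        idx = (l == stop).nonzero()
        return l if len(idx) == 0 else l[: idx.min().item()]

    # two full frames, then a stop token part-way through the third
    stream = torch.cat([torch.arange(2 * levels), torch.tensor([3, stop, 7])])
    out = prune(stream)                       # length 17: 2 frames + 1 stray token
    whole = out[: out.shape[0] // levels * levels]
    frames = torch.unflatten(whole, 0, (-1, levels))
    print(frames.shape)                       # torch.Size([2, 8])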
"_cfg") and self._cfg: - return self._cfg.resp_levels + if hasattr(self, "config") and self.config: + return self.config.resp_levels return cfg.models.nar.resp_levels @property @@ -51,6 +51,10 @@ class NAR(Base): def recurrent_chunk_size(self) -> int: return 0 + @property + def interleave(self) -> bool: + return False + def forward( self, text_list: list[Tensor],