diff --git a/vall_e/config.py b/vall_e/config.py
index 3e77416..e2d4e04 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -160,11 +160,12 @@ class Model:
 	resp_levels: int = 1
 	prom_levels: int = 8
 	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
-	arch_type: str = "transformer"
+	arch_type: str = "retnet"
 	training: bool = True
 	interleave: bool = False
 	frozen_params: list[str] = field(default_factory=lambda: [])
 	p_ar_nar: float = 0.5
+	version: int = 1
 
 	@property
 	def full_name(self):
diff --git a/vall_e/data.py b/vall_e/data.py
index c2d8a48..e66b352 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -279,7 +279,7 @@ class Dataset(_Dataset):
 
 		# shuffle it up a bit
 		prom_length = 0
-		trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
+		trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
 
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 8e56345..455197b 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -57,6 +57,12 @@ class AR(Base):
 	def monolithic(self) -> bool:
 		return False
 
+	@property
+	def version(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.version
+		return cfg.models.ar.version
+
 	def _prune(self, l: Tensor):
 		indices = (l == self.stop_token).nonzero()
 		if len(indices) == 0:
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 415df55..6257256 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -54,6 +54,12 @@ class AR_NAR(Base):
 	def monolithic(self) -> bool:
 		return True
 
+	@property
+	def version(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.version
+		return cfg.models.ar_nar.version
+
 	def _prune(self, l: Tensor):
 		indices = (l == self.stop_token).nonzero()
 		if len(indices) == 0:
@@ -208,7 +214,7 @@ def example_usage():
 	]
 	proms_list = [
 		#x8(torch.tensor([1, 2, 3], device=device)),
-		qnt.to(device),
+		qnt[:75*3, :].to(device),
 	]
 	resps_list = [
 		qnt.to(device),
 	]
@@ -233,11 +239,15 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
+	#steps = 500
+	#optimizer = ml.Prodigy(model.parameters(), lr=1.0)
 	steps = 500
-	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+	optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
 	engine = Engine(model=model, optimizer=optimizer)
 
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+
+	print([ name for name, _ in model.named_parameters()])
 
 	@torch.inference_mode()
 	def sample( name, steps=600 ):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 01025f1..36ec6bd 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -151,64 +151,45 @@ class MultiEmbedding(nn.Embedding):
 		else:
 			w = self.weight
 
-			padded_x_list = []
+		padded_x_list = []
 
-			for i, xi in enumerate(x_list):
-				xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
-				wi = w.shape[0] - xi.shape[1]
-				xi = F.pad(xi, (0, 0, 0, wi)) # t l k
-				padded_x_list.append(xi.to(w))
+		for i, xi in enumerate(x_list):
+			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
+			wi = w.shape[0] - xi.shape[1]
+			xi = F.pad(xi, (0, 0, 0, wi)) # t l k
+			padded_x_list.append(xi.to(w))
 
-			x = torch.cat(padded_x_list) # n l k
-			x = einsum("l k d, n l k -> n d", w, x)
+		x = torch.cat(padded_x_list) # n l k
+		x = einsum("l k d, n l k -> n d", w, x)
-			x_list = x.split([*map(len, x_list)])
+		x_list = x.split([*map(len, x_list)])
 
 		return x_list
 
-		"""
-		w_ar, w_nar = self.weight[:1], self.weight[1:]
-		p_ar_list, p_nar_list = [], []
+# Embedding that sums each RVQ-bin level within a given input acoustic prompt
+class AudioEmbedding(nn.Module):
+	def __init__(self, n_levels, n_tokens, token_dim):
+		super().__init__()
+		self.n_levels = n_levels
+		# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
+		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
+
+	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
+		res_list = []
 
 		for i, xi in enumerate(x_list):
-			if quant_levels is None or quant_levels[i] == 0:
-				w, padded_x_list = w_ar, p_ar_list
+			# prom
+			if quant_levels is None and xi.shape[-1] > 1:
+				x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
+			# AR resp
+			elif quant_levels is None or quant_levels[i] == 0:
+				x = self.embeddings[0]( xi[:, 0] )
+			# NAR resp
 			else:
-				w, padded_x_list = w_nar, p_nar_list
+				x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
+			res_list.append(x)
 
-			# pad resp/prom tensor to fit weight
-			xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
-			xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
-			padded_x_list.append(xi.to(w))
-
-		# batch list => batch tensor
-		x_ar_list = einsum("l k d, n l k -> n d", w_ar, torch.cat(p_ar_list)) if len(p_ar_list) > 0 else []
-		x_nar_list = einsum("l k d, n l k -> n d", w_nar, torch.cat(p_nar_list)) if len(p_nar_list) > 0 else []
-
-		x_list = x.split([*map(len, x_list)])
-		"""
-
-"""
-# Embedding that sums each RVQ-bin level within a given input acoustic prompt
-class PromEmbedding(nn.Module):
-	def __init__(self, n_levels, n_tokens, token_dim):
-		super().__init__()
-		self.n_levels = n_levels
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
-
-	def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
-		return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
-
-# Embedding that selects which embedding based on a quant_level tensor for a given batch
-class RespEmbedding(nn.Module):
-	def __init__(self, n_levels, n_tokens, token_dim):
-		super().__init__()
-		self.n_levels = n_levels
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
-
-	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
-		return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
-"""
+		return res_list
 
 class Base(nn.Module):
	@property
@@ -252,8 +233,8 @@ class Base(nn.Module):
 		return False
 
 	@property
-	def n_embeddings(self) -> int:
-		return 2 if self.monolithic else 1
+	def version(self) -> int:
+		return 1
 
 	@property
 	def stop_token(self):
@@ -298,8 +279,12 @@ class Base(nn.Module):
 
 		self.text_emb = Embedding(n_tokens, d_model)
 
-		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic)
-		self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
+		if self.version == 1: # legacy
+			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
+			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
+		else:
+			self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
+			self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 88aeede..cd394e2 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -39,6 +39,12 @@ class NAR(Base):
 	def n_tasks(self) -> int:
 		return cfg.models.tasks
 
+	@property
+	def version(self) -> int:
+		if hasattr(self, "config") and self.config:
+			return self.config.version
+		return cfg.models.nar.version
+
 	@property
 	def recurrent_chunk_size(self) -> int:
 		return 0
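Note on the vall_e/data.py change: EnCodec at 24kHz yields 75 code frames per second, which is why cfg.dataset.prompt_duration (in seconds) is scaled by 75. Widening the jitter from randint(-16, 16) to randint(-75, 75) lets the trimmed acoustic prompt vary by up to a full second instead of roughly a fifth of one. A minimal sketch of the arithmetic, using a hypothetical prompt_duration of 3 seconds:

	import random

	FRAMES_PER_SECOND = 75  # EnCodec @ 24kHz emits 75 RVQ frames per second
	prompt_duration = 3.0   # hypothetical cfg.dataset.prompt_duration, in seconds

	# 225 frames, then +/- up to 75 frames (one second) of random jitter
	trim_length = int(prompt_duration * FRAMES_PER_SECOND) + random.randint(-75, 75)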
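Note on the new AudioEmbedding in vall_e/models/base.py: unlike the legacy MultiEmbedding, which one-hots every level and contracts against a single stacked weight via einsum, it keeps a separate nn.Embedding per RVQ level and simply sums them, with embeddings[0] reserved for the AR response and levels shifted up by one for NAR responses; this difference in parameter layout is what the version == 1 gate in Base.__init__ exists for. A self-contained sketch of the behavior follows; the class name, the 1024-entry codebook, and token_dim=512 are illustrative assumptions, and each xi is taken to be a [t, k] tensor of RVQ codes:

	import torch
	import torch.nn as nn
	from torch import Tensor

	class AudioEmbeddingSketch(nn.Module):
		def __init__(self, n_levels: int, n_tokens: int, token_dim: int):
			super().__init__()
			# one embedding table per RVQ level
			self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])

		def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
			res_list = []
			for i, xi in enumerate(x_list):
				if quant_levels is None and xi.shape[-1] > 1:
					# prom: sum the embedding of every RVQ level provided
					x = sum(self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]))
				elif quant_levels is None or quant_levels[i] == 0:
					# AR resp: level 0 only, through embeddings[0]
					x = self.embeddings[0](xi[:, 0])
				else:
					# NAR resp: levels offset by one so embeddings[0] stays AR-only
					x = sum(self.embeddings[k + 1](xi[:, k]) for k in range(xi.shape[-1]))
				res_list.append(x)
			return res_list

	# e.g. a 3-second acoustic prompt at 75 frames/sec with 8 RVQ levels
	emb = AudioEmbeddingSketch(n_levels=8, n_tokens=1024, token_dim=512)
	prom = torch.randint(0, 1024, (75 * 3, 8))
	out, = emb([prom])
	print(out.shape)  # torch.Size([225, 512])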