From b0bd88833c43deaa9398368ecd18e17dd5dc4d8f Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 16 Apr 2024 21:04:48 -0500
Subject: [PATCH] refactor cleanup, had a revelation on how I can handle a
 batch of varying tasks

---
 vall_e/config.py        |   3 +-
 vall_e/models/ar_nar.py |  38 +++++----
 vall_e/models/base.py   | 169 ++++++++++++++++++++++++----------------
 3 files changed, 126 insertions(+), 84 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index d41930f..68ca9c4 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -9,8 +9,7 @@ import time
 
 import torch
 
-from dataclasses import asdict, dataclass
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from functools import cached_property
 from pathlib import Path
 
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index e9c965a..4786fcf 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -161,13 +161,17 @@ class AR_NAR(Base):
 				resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
 				targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
 
-			return super().forward(
+			inputs = self.inputs(
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
 				targ_list=targ_list,
 				lang_list=lang_list,
-				tone_list=tone_list,
+				tone_list=tone_list
+			)
+
+			return super().forward(
+				inputs=inputs,
 				quant_levels=quant_levels,
 			)
 		# is NAR
@@ -183,12 +187,16 @@ class AR_NAR(Base):
 
 			quant_levels = torch.full((len(text_list),), level)
 
-			logits = super().forward(
+			inputs = self.inputs(
 				text_list=text_list,
 				proms_list=proms_list,
 				resps_list=prev_list,
 				lang_list=lang_list,
 				tone_list=tone_list,
+			)
+
+			logits = super().forward(
+				inputs=inputs,
 				quant_levels=quant_levels,
 			)
 
@@ -235,23 +243,23 @@ class AR_NAR(Base):
 			else:
 				resps_list = self._unsqueeze_list(sequence_list)
 
+			inputs = self.inputs(
+				text_list=text_list,
+				proms_list=proms_list,
+				resps_list=resps_list,
+				lang_list=lang_list,
+				tone_list=tone_list,
+			)
+
 			if recurrent_state is not None:
 				logits, recurrent_state = super().forward(
-					text_list=text_list,
-					proms_list=proms_list,
-					resps_list=resps_list,
-					lang_list=lang_list,
-					tone_list=tone_list,
-					state=recurrent_state
+					inputs=inputs,
+					state=recurrent_state,
 				)
 			else:
 				logits = super().forward(
-					text_list=text_list,
-					proms_list=proms_list,
-					resps_list=resps_list,
-					lang_list=lang_list,
-					tone_list=tone_list,
-					state=recurrent_state
+					inputs=inputs,
+					state=recurrent_state,
 				)
 
 			r = super().sample(
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 3255c0a..a953b53 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -10,6 +10,7 @@ from functools import partial
 
 from einops import rearrange
 from torch import Tensor, einsum, nn
+from torch.nn import Embedding
 from torch.distributions import Categorical
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.checkpoint import checkpoint
@@ -165,11 +166,13 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
 	return x, m
 
 # automagically parses a batch-list and returns it as a list
+"""
 class Embedding(nn.Embedding):
 	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
 		return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
+"""
 
 class MultiEmbedding(nn.Module):
 	"""
@@ -218,22 +221,18 @@ class AudioEmbedding(nn.Module):
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 		self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
 
-	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
-		res_list = []
-
-		for i, xi in enumerate(x_list):
-			# prom
-			if quant_levels is None and xi.shape[-1] > 1:
-				x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
-			# AR resp
-			elif quant_levels is None or quant_levels[i] == 0:
-				x = self.embeddings[0]( xi[:, 0] )
-			# NAR resp
-			else:
-				x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
-			res_list.append(x)
-
-		return res_list
+	def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
+		# prom
+		if quant_levels is None and xi.shape[-1] > 1:
+			x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+		# AR resp
+		elif quant_levels is None or quant_levels == 0:
+			x = self.embeddings[0]( xi[:, 0] )
+		# NAR resp
+		else:
+			x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+
+		return x
 
 class Base(nn.Module):
 	@property
@@ -302,17 +301,6 @@ class Base(nn.Module):
 	def ignore_index(self):
 		return -100
 
-	@staticmethod
-	def _samplewise_merge_tensors(*l, sep: Tensor | None):
-		if sep is None:
-			cat = torch.cat
-		else:
-			cat = partial(_join, sep=sep)
-
-		l = [ x for x in l if x is not None ]
-
-		return [*map(cat, zip(*l))]
-
 	def __init__(
 		self,
 		n_tokens: int = 1024,
@@ -638,51 +626,104 @@ class Base(nn.Module):
 		return x, state, aux_loss
 
-	def forward(
+	def inputs(
 		self,
 		text_list: list[Tensor],
 		proms_list: list[Tensor],
 		resps_list: list[Tensor],
 		targ_list: list[Tensor] | None = None,
-
+
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
-
-		quant_levels: Tensor | None = None,
-		state: dict | list | None = None,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)
 
-		# silently ignore languages if model does not have it
-		if self.langs_emb is None:
-			lang_list = None
-		# inject default language
-		elif lang_list is None:
-			lang_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
-
-		# silently ignore tones if model does not have it
-		if self.tones_emb is None:
-			tone_list = None
-		# inject default tone
-		elif tone_list is None:
-			tone_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
+		inputs = [ [] for _ in range(batch_size) ]
+		for i in range(batch_size):
+			if text_list is not None:
+				inputs[i].append( ( "text", text_list[i] ) )
+			if proms_list is not None:
+				inputs[i].append( ( "prom", proms_list[i] ) )
+			if resps_list is not None:
+				inputs[i].append( ( "resp", resps_list[i] ) )
+			if targ_list is not None:
+				inputs[i].append( ( "targ", targ_list[i] ) )
 
-		"""
-		# Typical sequence format
-		# To-do: integrate tasks again
-
-		"""
-		x_list = self._samplewise_merge_tensors(
-			self.text_emb(text_list),
-			self.langs_emb(lang_list) if lang_list is not None else None,
-			self.proms_emb(proms_list),
-			self.tones_emb(tone_list) if tone_list is not None else None,
-			self.resps_emb(resps_list, quant_levels),
-			sep=self.sep,
-		)
+		return inputs
 
+	def inputs_to_embeddings(
+		self,
+		inputs: list,
+		quant_levels: Tensor | None = None
+	):
+		x_list = []
+		for b_i in range(len(inputs)):
+			batch = []
+			for i in range(len(inputs[b_i])):
+				name, input = inputs[b_i][i]
+				embedding = None
+				if name == "text":
+					embedding = self.text_emb( input )
+				elif name == "lang":
+					embedding = self.langs_emb( input )
+				elif name == "prom":
+					embedding = self.proms_emb( input )
+				elif name == "tone":
+					embedding = self.tones_emb( input )
+				elif name == "resp":
+					embedding = self.resps_emb( input, quant_levels[b_i] if quant_levels is not None else None )
+				else:
+					continue
+
+				batch.append(embedding)
+
+			x_list.append( _join( batch, self.sep ) )
+
+		return x_list
+
+	def training_targets(
+		self,
+		inputs: list,
+	):
+		x_list = []
+		for bi in range(len(inputs)):
+			batch = []
+			for i in range(len(inputs[bi])):
+				name, input = inputs[bi][i]
+				device = input.device
+
+				if name == "prom":
+					batch.append( torch.full_like(input[..., 0], self.ignore_index) )
+				elif name in ["text", "lang", "tone", "targ"]:
+					batch.append( input )
+
+			x_list.append( _join( batch, torch.tensor(self.ignore_index, device=device) ) )
+
+		return x_list
+
+	def forward(
+		self,
+		inputs: list,
+
+		quant_levels: Tensor | None = None,
+		state: dict | list | None = None,
+	):
+
+		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)
+
+		# yes, there's a better way.
+		training = False
+		for b_i in range(len(inputs)):
+			for i in range(len(inputs[b_i])):
+				name, input = inputs[b_i][i]
+				if name == "targ":
+					training = True
+
+
+		device = x.device
+		batch_size = len(x_list)
 
 		# pad our input and mask, but retain the original length by doing it after
 		if self.l_padding and x.shape[1] % self.l_padding != 0:
@@ -709,15 +750,9 @@ class Base(nn.Module):
 		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
 
 		# compute loss if the target is given
-		if targ_list is not None:
-			target_list = self._samplewise_merge_tensors(
-				text_list,
-				lang_list,
-				[ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ], # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
-				targ_list,
-				sep=torch.tensor(self.ignore_index, device=device)
-			)
-
+		if training:
+			target_list = self.training_targets( inputs )
+
			# modify only for the AR so it can properly behave like a transformer
 			for i in range(len(target_list)):
 				if quant_levels is not None and quant_levels[i] > 0:
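
The gist of the refactor, as a standalone sketch rather than code from this patch: each sample becomes an ordered list of named segments, one pass joins their embeddings with a separator (like inputs_to_embeddings + _join), and another builds the matching target sequence with the prompt masked out (like training_targets). Every name and dimension below (sequence_inputs, to_embeddings, to_targets, embs, d_model, 1D token segments) is made up for illustration and is not the repo's API.

	import torch
	from torch import Tensor, nn

	IGNORE_INDEX = -100  # same sentinel Base.ignore_index uses in the patch

	def sequence_inputs(text: Tensor, prom: Tensor, resp: Tensor, targ: Tensor | None = None):
		# one sample: an ordered list of ("name", tensor) segments, analogous to Base.inputs()
		segments = [("text", text), ("prom", prom), ("resp", resp)]
		if targ is not None:
			segments.append(("targ", targ))
		return segments

	def to_embeddings(segments, embs: dict[str, nn.Embedding], sep: Tensor) -> Tensor:
		# embed the segments it knows about and join them with a separator vector;
		# "targ" has no embedding table, so it is skipped, as in inputs_to_embeddings()
		pieces = [embs[name](seg) for name, seg in segments if name in embs]
		out = pieces[0]
		for piece in pieces[1:]:
			out = torch.cat([out, sep.unsqueeze(0), piece], dim=0)
		return out

	def to_targets(segments) -> Tensor:
		# matching targets: prompt positions become IGNORE_INDEX, text stays,
		# targ stands in for resp; separators are also ignored, as in training_targets()
		pieces = []
		for name, seg in segments:
			if name == "prom":
				pieces.append(torch.full_like(seg, IGNORE_INDEX))
			elif name in ("text", "targ"):
				pieces.append(seg)
		sep = torch.tensor([IGNORE_INDEX])
		out = pieces[0]
		for piece in pieces[1:]:
			out = torch.cat([out, sep, piece], dim=0)
		return out

	# toy usage: resp and targ share a length so inputs and targets line up position-for-position
	d_model = 4
	embs = {name: nn.Embedding(32, d_model) for name in ("text", "prom", "resp")}
	sep = torch.zeros(d_model)

	text = torch.tensor([1, 2, 3])
	prom = torch.tensor([4, 5])
	resp = torch.tensor([6, 7, 8])
	targ = torch.tensor([7, 8, 9])

	segments = sequence_inputs(text, prom, resp, targ)
	x = to_embeddings(segments, embs, sep)  # (3 + 1 + 2 + 1 + 3, d_model)
	y = to_targets(segments)                # (10,) with the prompt span masked

Because resp and targ have the same length here, every row of x has a target in y, and the masked positions (prompt and separators) drop out of the cross-entropy loss; the per-sample segment list is what lets a batch mix differently shaped tasks.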