diff --git a/vall_e/config.py b/vall_e/config.py
index d41930f..68ca9c4 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -9,8 +9,7 @@ import time
import torch
-from dataclasses import asdict, dataclass
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
from functools import cached_property
from pathlib import Path
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index e9c965a..4786fcf 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -161,13 +161,17 @@ class AR_NAR(Base):
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
- return super().forward(
+ inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
targ_list=targ_list,
lang_list=lang_list,
- tone_list=tone_list,
+ tone_list=tone_list,
+ )
+
+ return super().forward(
+ inputs=inputs,
quant_levels=quant_levels,
)
# is NAR
@@ -183,12 +187,16 @@ class AR_NAR(Base):
quant_levels = torch.full((len(text_list),), level)
- logits = super().forward(
+ inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
+ )
+
+ logits = super().forward(
+ inputs=inputs,
quant_levels=quant_levels,
)
@@ -235,23 +243,23 @@ class AR_NAR(Base):
else:
resps_list = self._unsqueeze_list(sequence_list)
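+ # pack the batch once so the stateful and stateless forward calls below share the same inputs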
+ inputs = self.inputs(
+ text_list=text_list,
+ proms_list=proms_list,
+ resps_list=resps_list,
+ lang_list=lang_list,
+ tone_list=tone_list,
+ )
+
if recurrent_state is not None:
logits, recurrent_state = super().forward(
- text_list=text_list,
- proms_list=proms_list,
- resps_list=resps_list,
- lang_list=lang_list,
- tone_list=tone_list,
- state=recurrent_state
+ inputs=inputs,
+ state=recurrent_state,
)
else:
logits = super().forward(
- text_list=text_list,
- proms_list=proms_list,
- resps_list=resps_list,
- lang_list=lang_list,
- tone_list=tone_list,
- state=recurrent_state
+ inputs=inputs,
+ state=recurrent_state,
)
r = super().sample(
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 3255c0a..a953b53 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -10,6 +10,7 @@ from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
+from torch.nn import Embedding
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
@@ -165,11 +166,13 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
return x, m
# automagically parses a batch-list and returns it as a list
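+# NOTE: the list-based Embedding below is disabled in favor of the stock torch.nn.Embedding imported above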
+"""
class Embedding(nn.Embedding):
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0:
return []
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
+"""
class MultiEmbedding(nn.Module):
"""
@@ -218,22 +221,18 @@ class AudioEmbedding(nn.Module):
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
- def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
- res_list = []
-
- for i, xi in enumerate(x_list):
- # prom
- if quant_levels is None and xi.shape[-1] > 1:
- x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
- # AR resp
- elif quant_levels is None or quant_levels[i] == 0:
- x = self.embeddings[0]( xi[:, 0] )
- # NAR resp
- else:
- x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
- res_list.append(x)
-
- return res_list
+ def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
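+ # operates on a single sample's codes now; quant_levels here is that sample's scalar level (or None)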
+ # prom
+ if quant_levels is None and xi.shape[-1] > 1:
+ x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+ # AR resp
+ elif quant_levels is None or quant_levels == 0:
+ x = self.embeddings[0]( xi[:, 0] )
+ # NAR resp
+ else:
+ x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
+
+ return x
class Base(nn.Module):
@property
@@ -302,17 +301,6 @@ class Base(nn.Module):
def ignore_index(self):
return -100
- @staticmethod
- def _samplewise_merge_tensors(*l, sep: Tensor | None):
- if sep is None:
- cat = torch.cat
- else:
- cat = partial(_join, sep=sep)
-
- l = [ x for x in l if x is not None ]
-
- return [*map(cat, zip(*l))]
-
def __init__(
self,
n_tokens: int = 1024,
@@ -638,51 +626,104 @@ class Base(nn.Module):
return x, state, aux_loss
- def forward(
+ def inputs(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
-
+
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
-
- quant_levels: Tensor | None = None,
- state: dict | list | None = None,
):
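+ # packs each sample into an ordered list of ( name, tensor ) entries consumed by inputs_to_embeddings / training_targets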
device = text_list[0].device
batch_size = len(text_list)
- # silently ignore languages if model does not have it
- if self.langs_emb is None:
- lang_list = None
- # inject default language
- elif lang_list is None:
- lang_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
-
- # silently ignore tones if model does not have it
- if self.tones_emb is None:
- tone_list = None
- # inject default tone
- elif tone_list is None:
- tone_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
+ inputs = [ [] for _ in range(batch_size) ]
+ for i in range(batch_size):
+ if text_list is not None:
+ inputs[i].append( ( "text", text_list[i] ) )
+ # pack language/tone tokens only when provided and the model actually has those embeddings
+ if lang_list is not None and self.langs_emb is not None:
+ inputs[i].append( ( "lang", lang_list[i] ) )
+ if proms_list is not None:
+ inputs[i].append( ( "prom", proms_list[i] ) )
+ if tone_list is not None and self.tones_emb is not None:
+ inputs[i].append( ( "tone", tone_list[i] ) )
+ if resps_list is not None:
+ inputs[i].append( ( "resp", resps_list[i] ) )
+ if targ_list is not None:
+ inputs[i].append( ( "targ", targ_list[i] ) )
- """
- # Typical sequence format
- # To-do: integrate tasks again
-
- """
- x_list = self._samplewise_merge_tensors(
- self.text_emb(text_list),
- self.langs_emb(lang_list) if lang_list is not None else None,
- self.proms_emb(proms_list),
- self.tones_emb(tone_list) if tone_list is not None else None,
- self.resps_emb(resps_list, quant_levels),
- sep=self.sep,
- )
+ return inputs
+
+ def inputs_to_embeddings(
+ self,
+ inputs: list,
+ quant_levels: Tensor | None = None
+ ):
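+ # embeds each named input and joins the pieces with the separator token into one sequence per sample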
+ x_list = []
+ for b_i in range(len(inputs)):
+ batch = []
+ for i in range(len(inputs[b_i])):
+ name, input = inputs[b_i][i]
+ embedding = None
+ if name == "text":
+ embedding = self.text_emb( input )
+ elif name == "lang":
+ embedding = self.langs_emb( input )
+ elif name == "prom":
+ embedding = self.proms_emb( input )
+ elif name == "tone":
+ embedding = self.tones_emb( input )
+ elif name == "resp":
+ embedding = self.resps_emb( input, quant_levels[b_i] if quant_levels is not None else None )
+ else:
+ continue
+
+ batch.append(embedding)
+
+ x_list.append( _join( batch, self.sep ) )
+
+ return x_list
+
+ def training_targets(
+ self,
+ inputs: list,
+ ):
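+ # builds the loss targets: prompt positions are filled with ignore_index, text/lang/tone/targ tokens are kept as-is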
+ x_list = []
+ for bi in range(len(inputs)):
+ batch = []
+ for i in range(len(inputs[bi])):
+ name, input = inputs[bi][i]
+ device = input.device
+
+ if name == "prom":
+ batch.append( torch.full_like(input[..., 0], self.ignore_index) )
+ elif name in ["text", "lang", "tone", "targ"]:
+ batch.append( input )
+
+ x_list.append( _join( batch, torch.tensor(self.ignore_index, device=device) ) )
+
+ return x_list
+
+ def forward(
+ self,
+ inputs: list,
+
+ quant_levels: Tensor | None = None,
+ state: dict | list | None = None,
+ ):
+
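+ # embed the packed inputs, then collapse the batch into a padded tensor and its mask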
+ x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list)
+
+ # we're training when any batch sample carries a target to compute the loss against
+ training = any( name == "targ" for batch in inputs for name, _ in batch )
+
+ device = x.device
+ batch_size = len(x_list)
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
@@ -709,15 +750,9 @@ class Base(nn.Module):
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
# compute loss if the target is given
- if targ_list is not None:
- target_list = self._samplewise_merge_tensors(
- text_list,
- lang_list,
- [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ], # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
- targ_list,
- sep=torch.tensor(self.ignore_index, device=device)
- )
-
+ if training:
+ target_list = self.training_targets( inputs )
+
# modify only for the AR so it can properly behave like a transformer
for i in range(len(target_list)):
if quant_levels is not None and quant_levels[i] > 0: