diff --git a/data/test/test.ar.recon.wav b/data/test/test.ar.recon.wav
index dcbc959..3fd7f77 100644
Binary files a/data/test/test.ar.recon.wav and b/data/test/test.ar.recon.wav differ
diff --git a/data/test/test.nar.init.wav b/data/test/test.nar.init.wav
index 85d2d60..7709797 100644
Binary files a/data/test/test.nar.init.wav and b/data/test/test.nar.init.wav differ
diff --git a/data/test/test.nar.recon.wav b/data/test/test.nar.recon.wav
index b54416f..7de0c4b 100644
Binary files a/data/test/test.nar.recon.wav and b/data/test/test.nar.recon.wav differ
diff --git a/vall_e/vall_e/__init__.py b/vall_e/vall_e/__init__.py
index e69de29..ab095f0 100644
--- a/vall_e/vall_e/__init__.py
+++ b/vall_e/vall_e/__init__.py
@@ -0,0 +1,2 @@
+from .ar import AR
+from .nar import NAR
diff --git a/vall_e/vall_e/ar.py b/vall_e/vall_e/ar.py
index 5fe28b6..bda1bd4 100644
--- a/vall_e/vall_e/ar.py
+++ b/vall_e/vall_e/ar.py
@@ -28,12 +28,12 @@ class AR(Base):
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         resp_list: list[Tensor],
     ):
         return super().forward(
             text_list,
-            prom_list,
+            proms_list,
             resp_list,
             resp_list,
             quant_level=0,
@@ -44,7 +44,7 @@ class AR(Base):
     def generate(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         max_steps: int = 1000,
     ):
         device = text_list[0].device
@@ -53,7 +53,7 @@
         ]
         stopped = [False] * len(text_list)
         for _ in trange(max_steps):
-            r = super().forward(text_list, prom_list, resp_list)
+            r = super().forward(text_list, proms_list, resp_list)
             for i, ri in enumerate(r):
                 if ri.item() == self.stop_token:
                     stopped[i] = True
@@ -65,7 +65,10 @@ class AR(Base):


 def example_usage():
+    from functools import partial
+
     import soundfile
+    from einops import repeat

     device = "cuda"

@@ -79,9 +82,10 @@ def example_usage():
         torch.tensor([2, 3], device=device),
     ]

-    prom_list = [
-        torch.tensor([1, 2, 3], device=device),
-        torch.tensor([2, 3], device=device),
+    x8 = partial(repeat, pattern="t -> t q", q=8)
+    proms_list = [
+        x8(torch.tensor([1, 2, 3], device=device)),
+        x8(torch.tensor([2, 3], device=device)),
     ]

     resp_list = [
@@ -91,7 +95,7 @@ def example_usage():

     out = model.generate(
         text_list,
-        prom_list,
+        proms_list,
         max_steps=200,
     )

@@ -101,7 +105,7 @@ def example_usage():

     for i in range(100):
         optimizer.zero_grad()
-        _ = model(text_list, prom_list, resp_list)
+        _ = model(text_list, proms_list, resp_list)

         losses = model.loss
         sum(losses.values()).backward()
@@ -110,7 +114,7 @@ def example_usage():
         if i % 20 == 0:
             print(f"iter={i}, {losses}.")

-    out = model.generate(text_list, prom_list, max_steps=200)
+    out = model.generate(text_list, proms_list, max_steps=200)

     print(qnt)
     print(out)
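Aside (illustrative, not part of the patch): the `x8` helper above just tiles a single-level token sequence across all 8 quantizer levels, giving the toy prompts the `[t' k]` shape that the new `MultiEmbedding` in base.py below expects. A minimal sketch of what it produces:

    import torch
    from einops import repeat

    prom = torch.tensor([1, 2, 3])         # [t]: codes at a single level
    proms = repeat(prom, "t -> t q", q=8)  # [t 8]: the same codes tiled to all levels
    assert proms.shape == (3, 8)

A real prompt would carry different EnCodec codes per level; tiling is only a cheap stand-in to get the shape right in these examples.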
diff --git a/vall_e/vall_e/base.py b/vall_e/vall_e/base.py
index 34870b7..7566f67 100644
--- a/vall_e/vall_e/base.py
+++ b/vall_e/vall_e/base.py
@@ -178,10 +178,28 @@ class Block(nn.Sequential):


 class Embedding(nn.Embedding):
-    def forward(self, x: list[Tensor]) -> list[Tensor]:
-        if len(x) == 0:
+    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
+        if len(x_list) == 0:
             return []
-        return super().forward(torch.cat(x)).split([*map(len, x)])
+        return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
+
+
+class MultiEmbedding(nn.Module):
+    def __init__(self, num_embeddings, embedding_dim, n_levels):
+        super().__init__()
+        self.n_levels = n_levels
+        self.num_embeddings = num_embeddings
+        self.emb = nn.Embedding(n_levels * num_embeddings, embedding_dim)
+
+    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
+        if len(x_list) == 0:
+            return []
+        x = torch.cat(x_list)
+        assert x.shape[1] == self.n_levels
+        w = rearrange(self.emb.weight, "(q k) d -> q k d", q=self.n_levels)
+        x = F.one_hot(x, num_classes=self.num_embeddings).float()  # n q -> n q k
+        x = einsum("q k d, n q k -> n d", w, x)
+        return x.split([*map(len, x_list)])


 def _join(x: tuple[Tensor], sep: Tensor):
@@ -216,6 +234,8 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
+        n_prom_levels: int = 8,
+        resp_loss_only: bool = False,
     ):
         super().__init__()
         self.n_tokens = n_tokens
@@ -227,7 +247,10 @@
         n_resp_tokens = n_tokens + n_stop_tokens

         self.text_emb = Embedding(n_tokens, d_model)
-        self.prom_emb = Embedding(n_tokens, d_model)
+
+        # It's not clear whether the whole prom is used or only its first quantization level.
+        # Use all levels here: they carry more information and we don't need to sample one. Or do we?
+        self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)

         # +1 to include the stop token
         self.resp_embs = nn.ModuleList(
@@ -243,6 +266,8 @@

         self.classifier = nn.Linear(d_model, n_resp_tokens)

+        self.resp_loss_only = resp_loss_only
+
     @property
     def stop_token(self):
         if not self.use_stop_token:
@@ -265,7 +290,7 @@
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
         quant_level: int = 0,
@@ -278,7 +303,7 @@
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
         quant_level: int = 0,
@@ -290,7 +315,7 @@
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
         quant_level: int = 0,
@@ -300,7 +325,7 @@
         """
         Args:
             text_list: [t] * b
-            prom_list: [t'] * b
+            proms_list: [t' k] * b
             resp_list: [t''] * b, one quantization level only
             targ_list: [t''] * b, one quantization level only, when given, loss will be computed
             quant_level: specify which quant_level to feed forward, used in NAR mode.
@@ -311,7 +336,7 @@ class Base(nn.Module):
         """
         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
-            self.prom_emb(prom_list),
+            self.prom_emb(proms_list),
             self.resp_embs[quant_level](resp_list),
             sep=self.sep,
         )
@@ -334,14 +359,21 @@
             device = h.device

             ignore_sep = torch.tensor(self.ignore_index, device=device)
+
+            # Predict the first-level prom only
+            prom_list = [t[..., 0] for t in proms_list]
+
             text_prom_list = self._samplewise_merge_tensors(
                 text_list, prom_list, sep=ignore_sep
             )

             # Make every token earlier as it is future that is unknown
+            # If we don't want to compute this loss, set everything to the ignore index
             for i in range(len(text_prom_list)):
-                text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
-                text_prom_list[i][-1] = self.ignore_index
+                if self.resp_loss_only:
+                    text_prom_list[i][:] = self.ignore_index
+                else:
+                    text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
+                    text_prom_list[i][-1] = self.ignore_index

             if shift_targ_list:
                 # Also make target earlier if in autoregressive mode
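Aside (illustrative, not part of the patch): the one-hot einsum in `MultiEmbedding.forward` is a dense spelling of "one embedding lookup per quantization level, summed over levels". A quick equivalence check under that reading:

    import torch
    import torch.nn.functional as F
    from einops import rearrange
    from torch import einsum

    n_levels, n_tokens, d_model = 8, 1024, 16
    emb = torch.nn.Embedding(n_levels * n_tokens, d_model)
    x = torch.randint(0, n_tokens, (5, n_levels))  # [n q]: codes for 5 frames

    # Dense path, as in MultiEmbedding.forward:
    w = rearrange(emb.weight, "(q k) d -> q k d", q=n_levels)
    h = einsum("q k d, n q k -> n d", w, F.one_hot(x, n_tokens).float())

    # Sparse path: offset each level into its slice of the big table, look up, sum.
    offsets = torch.arange(n_levels) * n_tokens
    h2 = emb(x + offsets).sum(dim=1)

    assert torch.allclose(h, h2, atol=1e-5)

Both paths compute the same embedding; the dense form materializes an [n q k] one-hot tensor, while the sparse form does q plain table lookups.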
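Aside (illustrative, not part of the patch): the `resp_loss_only` flag added above controls whether the merged text+prom span contributes to the loss at all. A toy run of the two branches, assuming `self.ignore_index` is the usual cross-entropy ignore value of -100 (its actual value is not shown in this diff):

    import torch

    ignore_index = -100  # assumption: F.cross_entropy's default ignore value
    targets = torch.tensor([5, 6, 7, 9])  # merged text+prom ids for one sample

    resp_loss_only = False
    if resp_loss_only:
        # New branch: mask the whole span, so no text/prom loss is computed.
        targets[:] = ignore_index
    else:
        # Default branch: shift one step earlier so each position predicts the
        # next token; the last position has no successor and is ignored.
        targets = targets.roll(-1, dims=0)
        targets[-1] = ignore_index

    print(targets)  # tensor([   6,    7,    9, -100])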
diff --git a/vall_e/vall_e/nar.py b/vall_e/vall_e/nar.py
index 39efde4..b69d8cb 100644
--- a/vall_e/vall_e/nar.py
+++ b/vall_e/vall_e/nar.py
@@ -21,7 +21,7 @@ class NAR(Base):
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         *,
         resp_list: list[Tensor] | None = None,
         resps_list: list[Tensor] | None = None,
@@ -29,7 +29,7 @@ class NAR(Base):
         """
         Args:
             text_list: [t] * b
-            prom_list: [t'] * b
+            proms_list: [t' k] * b
             resp_list: [t'] * b, quants at level 0.
             resps_list: [t''] * b, 8 quantization levels for training.
         Returns:
@@ -52,7 +52,7 @@ class NAR(Base):
             for i in range(self.n_levels):
                 hyp_resp_list = super().forward(
                     text_list,
-                    prom_list,
+                    proms_list,
                     hyp_resp_lists[-1],
                     return_all_resp=True,
                     shift_targ_list=False,
@@ -70,7 +70,7 @@ class NAR(Base):
                 next_resp_list = [o[..., i + 1] for o in resps_list]
                 hyp_resp_list = super().forward(
                     text_list,
-                    prom_list,
+                    proms_list,
                     resp_list,
                     next_resp_list,
                     return_all_resp=True,
@@ -90,7 +90,10 @@ class NAR(Base):


 def example_usage():
+    from functools import partial
+
     import soundfile
+    from einops import repeat

     from ..emb.qnt import decode
     from ..utils import gather_attribute
@@ -107,9 +110,10 @@ def example_usage():
         torch.tensor([2, 3], device=device),
     ]

-    prom_list = [
-        torch.tensor([1, 2, 3], device=device),
-        torch.tensor([2, 3], device=device),
+    x8 = partial(repeat, pattern="t -> t q", q=8)
+    proms_list = [
+        x8(torch.tensor([1, 2, 3], device=device)),
+        x8(torch.tensor([2, 3], device=device)),
     ]

     resp_list = [
@@ -118,13 +122,11 @@ def example_usage():
     ]

     resps_list = [
-        torch.tensor([1, 2, 3], device=device)
-        .unsqueeze(-1)
-        .repeat_interleave(8, dim=-1),
+        x8(torch.tensor([1, 2, 3], device=device)),
         resps.t().to(device),
     ]

-    out = model(text_list, prom_list, resp_list=resp_list)
+    out = model(text_list, proms_list, resp_list=resp_list)
     codes = rearrange(out[1], "t k -> 1 k t")
     print(codes)
     wavs, sr = decode(codes)
@@ -134,7 +136,7 @@ def example_usage():

     for i in range(100):
         optimizer.zero_grad()
-        _ = model(text_list, prom_list, resps_list=resps_list)
+        _ = model(text_list, proms_list, resps_list=resps_list)

         losses = gather_attribute(model, "loss")
         loss = sum(losses.values())
@@ -146,7 +148,7 @@ def example_usage():
             stats["loss"] = loss.item()
             print(f"iter={i}, {stats}.")

-    out = model(text_list, prom_list, resp_list=resp_list)
+    out = model(text_list, proms_list, resp_list=resp_list)
     codes = rearrange(out[1], "t k -> 1 k t")
     wavs, sr = decode(codes)
     soundfile.write("data/test/test.nar.recon.wav", wavs.cpu()[0, 0], sr)
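Aside (illustrative, not part of the patch): the `.unsqueeze(-1).repeat_interleave(8, dim=-1)` chain that `example_usage` in nar.py used to build `resps_list` produces exactly the same tensor as the new `x8` spelling, so that change is purely cosmetic:

    import torch
    from einops import repeat

    t = torch.tensor([1, 2, 3])
    old = t.unsqueeze(-1).repeat_interleave(8, dim=-1)  # the spelling being removed
    new = repeat(t, "t -> t q", q=8)                    # the x8 replacement
    assert torch.equal(old, new)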