From a6ae344e5bcaf1ac9bffdc0805115475b4060a85 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 12 Sep 2023 16:04:45 -0500
Subject: [PATCH] some comments

---
 vall_e/models/ar_nar.py |  6 +++---
 vall_e/models/base.py   | 13 +++++++++----
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 1d45694..f8d2521 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -94,9 +94,9 @@ class AR_NAR(Base):
 
 		# is training
 		if n_levels == self.n_resp_levels:
-			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,))
-			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
-			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]
+			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
+			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l)
 			quant_levels.to(device=device)
 
 			return super().forward(
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 1bd13d6..2e9dad4 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -375,15 +375,20 @@ class Base(nn.Module):
 			x = x[:, -1, :].unsqueeze(1)
 
 		if self.arch_type == "transformer":
-			x = self.sin_emb.add_pe(x)
+			# ensures we specify a quant_level for the transformer implementation's AdaLN
 			l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
 			l = l.to(device)
+			# inject position information
+			x = self.sin_emb.add_pe(x)
+			# pass our inputs through the transformer
 			for block in self.blocks:
 				x = block(x, m, l)
 		elif self.arch_type == "retnet":
+			# pass our inputs through the RetNet
 			x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
-
+
+		# output projection layer with masking
 		x = self.classifier(x) * m
 
 		# Remove padding
 
@@ -399,10 +404,10 @@ class Base(nn.Module):
 
 			# process each batch
 			for i in range(len(text_prom_list)):
-				# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
+				# for the AR, shift the text/input prompt and target prompt into the future by 1, and ignore the rolled back text token
 				if quant_levels is None or quant_levels[i] == 0:
 					text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
-					targ_list[i] = targ_list[i].clone().roll(-1, dims=0)
+					targ_list[i] = targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps
 					text_prom_list[i][-1] = self.ignore_index
 					targ_list[i][-1] = self.stop_token
 
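
As a sanity check of the ar_nar.py comments above, here is a minimal, self-contained sketch of the same selection logic. The shapes, sequence lengths, and codebook size are invented placeholders; only the three commented lines mirror the patch. Note in passing that the bare `quant_levels.to(device=device)` in the surrounding context is a no-op, since Tensor.to() is not in-place; moving the tensor would require assigning the result back.

import torch

n_resp_levels = 4  # RVQ bins per codec frame (placeholder)
batch_size = 3

# each response is [seq_len, n_resp_levels] of codebook indices (invented data)
resps_list = [torch.randint(0, 1024, (t, n_resp_levels)) for t in (5, 7, 6)]

# randomly select a target RVQ-bin level per sample (0 being AR, 1+ being NAR)
quant_levels = torch.randint(0, n_resp_levels, (batch_size,))

# target: exactly one RVQ bin, the one the model must predict at that level
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]

# input: the AR (level 0) keeps the full response, a NAR at level l sees only the bins below it
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]

for r, t, l in zip(resps_list, targ_list, quant_levels):
	print(int(l), tuple(r.shape), tuple(t.shape))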
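
The base.py hunk only shows that each transformer block is now called as block(x, m, l) so AdaLN can condition on the quant level; the patch does not show AdaLN itself. A rough sketch of what a level-conditioned AdaLN could look like, with the module body entirely assumed (embedding-based scale/shift; dimensions are placeholders):

import torch
from torch import nn

class AdaLN(nn.Module):
	def __init__(self, d_model: int, n_levels: int):
		super().__init__()
		self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
		# one learned (scale, shift) pair per RVQ level (assumed design)
		self.emb = nn.Embedding(n_levels, d_model * 2)

	def forward(self, x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
		# x: [batch, seq, d_model]; l: [batch] integer quant levels (0 = AR, 1+ = NAR)
		scale, shift = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
		return self.norm(x) * (1 + scale) + shift

x = torch.randn(2, 8, 16)
l = torch.tensor([0, 3])
print(AdaLN(d_model=16, n_levels=4)(x, l).shape)  # torch.Size([2, 8, 16])

The design point is that a single shared trunk can behave differently per RVQ level: the level index selects a learned scale and shift applied after normalization, so the AR (level 0) and each NAR level get distinct activation statistics without separate models.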
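
Finally, the roll-by-one in the loss-preparation hunk is standard next-token alignment: rolling by -1 makes position i hold token i+1, and the element that wraps around to the end is overwritten (the ignore index for the text prompt, the stop token for the target). The clone matters because targ_list[i] starts life as r[..., l], a view into resps, so cloning makes explicit that nothing written here can alias the original (roll itself also returns a new tensor). A toy demonstration, with all values and the [seq, levels] shape invented:

import torch

ignore_index = -100  # placeholder values for illustration
stop_token = 1024

resps = torch.tensor([[11, 21], [12, 22], [13, 23], [14, 24]])  # [seq, levels]
targ = resps[..., 0]               # a view into resps, like targ_list[i]
text = torch.tensor([3, 4, 5, 6])  # pretend text prompt tokens

# shift one step into the future: position i now holds token i+1
targ = targ.clone().roll(-1, dims=0)  # clone first, never write through the view
text = text.roll(-1, dims=0)

# the element that wrapped around to the end is garbage: overwrite it
text[-1] = ignore_index  # excluded from the loss
targ[-1] = stop_token    # the AR learns to emit a stop here

print(text.tolist())         # [4, 5, 6, -100]
print(targ.tolist())         # [12, 13, 14, 1024]
print(resps[:, 0].tolist())  # [11, 12, 13, 14] -- original is untouched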