diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index e9cefab..bcd5704 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -118,17 +118,25 @@ class AR_NAR(Base): # is training if n_levels == self.n_resp_levels: + # might be better to have this decided on the dataloader level if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None: quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) else: quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ]) targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target) - resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l) + resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding if cfg.experimental: proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds - + # append stop tokens for AR + for i in range(batch_size): + if quant_levels[i] > 0: + continue + + resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ]) + targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ]) + return super().forward( text_list=text_list, proms_list=proms_list, @@ -294,6 +302,8 @@ def example_usage(): qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) + cfg.hyperparameters.gradient_accumulation_steps = 1 + text_list = [ tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), ] @@ -323,10 +333,9 @@ def example_usage(): """ model = 
AR_NAR(**kwargs).to(device) - #steps = 500 - #optimizer = ml.Prodigy(model.parameters(), lr=1.0) - steps = 1000 - optimizer = ml.AdamW(model.parameters(), lr=1.0e-4) + steps = 250 + optimizer = ml.Prodigy(model.parameters(), lr=1.0) + #optimizer = ml.AdamW(model.parameters(), lr=1.0e-4) engine = Engine(model=model, optimizer=optimizer) torch.save( { diff --git a/vall_e/models/base.py b/vall_e/models/base.py index c385ad0..e3c6ea1 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -351,30 +351,24 @@ class Base(nn.Module): # compute loss if the target is given if targ_list is not None: - ignore_sep = torch.tensor(self.ignore_index, device=device) - # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against - prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ] - # remake input sequence - text_prom_list = self._samplewise_merge_tensors( text_list, lang_list, - prom_list, - sep=ignore_sep + [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ], # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not needed for computing the loss against + targ_list, + sep=torch.tensor(self.ignore_index, device=device) ) - # process each batch - for i in range(len(text_prom_list)): - # for the AR and NAR, shift the text/input prompt into the future by 1, and ignore the rolled back token - text_prom_list[i] = text_prom_list[i].roll(-1, dims=0) - text_prom_list[i][-1] = self.ignore_index + # modify only for the AR so it can properly behave like a transformer + for i in range(len(target_list)): + if quant_levels is not None and quant_levels[i] > 0: + continue - # for the AR, shift the target response into the future by 1, and ignore the rolled back text token - if quant_levels is None or quant_levels[i] == 0: - targ_list[i] =
targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps - targ_list[i][-1] = self.stop_token + logits[i] = logits[i][..., :-1, :] # shift the target so that token n... + target_list[i] = target_list[i][..., 1:] # predicts token n + 1 - # create the new target sequence to compute the loss against - target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) ) + target = torch.cat( target_list ) inputs = torch.cat( logits ) self.loss = dict(