diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index ff459cb..611f69f 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -255,6 +255,7 @@ class AR_NAR(Base): prev_list = resps_list for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps): + #for noise_p, annealed_temperature, temperature, cfg_strength in zip( manual_ratios, manual_temp, manual_samp_temp, manual_cfg ): annealing = (steps_until_x0 / max_steps) # get noise level, per cosine scheduling noise_p = math.cos( timestep * math.pi * 0.5 ) @@ -264,7 +265,7 @@ class AR_NAR(Base): resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ] # boolean mask is_masked = [ resps == self.stop_token for resps in resps_list ] - + # timestep inputs time_list = [ timestep for _ in range(batch_size) ] # setup inputs @@ -314,6 +315,7 @@ class AR_NAR(Base): ) # retrieves unfiltered logits + """ unfiltered_sampled = super().sample( logits=logits, prev_list=prev_list, @@ -322,6 +324,7 @@ class AR_NAR(Base): temperature=0.0, **sampling_kwargs, ) + """ # update previous list of tokens prev_list = resps_list @@ -333,7 +336,7 @@ class AR_NAR(Base): # keep unmasked tokens resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ] # update scores (conjugated to put the worst scores at the top) - scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ] + scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ] # refinement step if refine_on_stop: @@ -374,8 +377,10 @@ class AR_NAR(Base): # greedy sample from the sequence refined_list = [ logit.argmax(dim=-1) for logit in logits ] + """ if cfg.experimental and max_steps > 0: print( timestep, steps_until_x0, noise_p, resps_list, scores ) + """ return resps_list diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 43a18b5..fd5f767 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1020,17 +1020,26 @@ class Base(nn.Module): # insert tone token if we're trained for it if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None: inputs[i].append( ( "tone", tone_list[i] ) ) + # it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence + """ # insert timestep token if timestep is not None: # store timestep information inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) ) + """ # insert the current output response if resps_list is not None and resps_list[i] is not None: inputs[i].append( ( "resp", resps_list[i] ) ) - # store dropout mask (if training) + # store dropout mask (if training, as this gets used later to mask the input embeddings if provided) if timestep is not None and self.training: - dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) ) + # a paper said to use a fixed masking ratio for training + """ + # cosine scheduled timestep => masking ratio + p = math.cos(timestep * math.pi * 0.5) + """ + p = 0.8 + dropout_mask = _dropout_mask( resps_list[i], p ) inputs[i].append( ("dropout_mask", dropout_mask ) ) # Audio length prediction task