do not pass the timestep token/embedding since it doesn't seem to matter after all; fixed the training masking rate to 80% because a paper said so

This commit is contained in:
mrq 2024-11-13 09:07:10 -06:00
parent caf721c67b
commit 8286aa54c8
2 changed files with 18 additions and 4 deletions
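For scale: assuming the training timestep is drawn uniformly in [0, 1] (an assumption, not shown in this diff), the old cosine-scheduled ratio `cos(t * pi / 2)` masks about 64% of the response tokens on average, while the new fixed ratio always masks 80%. A quick sanity check:

```python
import math

# old scheme: masking ratio p = cos(t * pi / 2), with t assumed uniform in [0, 1]
# expected ratio = integral of cos(t * pi / 2) dt over [0, 1] = 2 / pi
expected_old = 2.0 / math.pi
print(f"mean masking ratio, cosine-scheduled timestep: {expected_old:.2f}")  # ~0.64

# new scheme: a constant ratio, so training always sees heavily masked sequences
p_fixed = 0.8
print(f"masking ratio, fixed: {p_fixed:.2f}")  # 0.80
```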


@@ -255,6 +255,7 @@ class AR_NAR(Base):
prev_list = resps_list
for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
#for noise_p, annealed_temperature, temperature, cfg_strength in zip( manual_ratios, manual_temp, manual_samp_temp, manual_cfg ):
annealing = (steps_until_x0 / max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
@@ -264,7 +265,7 @@ class AR_NAR(Base):
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask
is_masked = [ resps == self.stop_token for resps in resps_list ]
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
# setup inputs
@@ -314,6 +315,7 @@ class AR_NAR(Base):
)
# retrieves unfiltered logits
"""
unfiltered_sampled = super().sample(
logits=logits,
prev_list=prev_list,
@@ -322,6 +324,7 @@ class AR_NAR(Base):
temperature=0.0,
**sampling_kwargs,
)
"""
# update previous list of tokens
prev_list = resps_list
@@ -333,7 +336,7 @@ class AR_NAR(Base):
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# update scores (inverted so the worst scores sort to the top)
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ]
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
# refinement step
if refine_on_stop:
@@ -374,8 +377,10 @@ class AR_NAR(Base):
# greedy sample from the sequence
refined_list = [ logit.argmax(dim=-1) for logit in logits ]
"""
if cfg.experimental and max_steps > 0:
print( timestep, steps_until_x0, noise_p, resps_list, scores )
"""
return resps_list
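The hunks above tweak the NAR masked-decoding loop: each step derives a noise level from a cosine schedule over the timestep, re-masks the lowest-confidence positions with the stop token, resamples them, keeps already-unmasked tokens, and inverts the new confidences so the worst tokens are re-masked on the next step (the timestep itself is no longer fed to the model). A condensed sketch of that loop; the `model` callable and its return shape are assumptions, not the repo's API:

```python
import math
import torch

def masked_decode(resps, scores, model, max_steps=25, stop_token=1024):
    """Iteratively re-mask and resample the lowest-confidence response tokens.

    `resps` is a list of 1D token tensors, `scores` their (1 - confidence)
    values; `model` is assumed to return (sampled_ids, confidences) for a
    partially masked batch. This mirrors the loop above in simplified form.
    """
    for timestep, steps_until_x0 in zip(
        torch.linspace(0.0, 1.0, max_steps), reversed(range(max_steps))
    ):
        # get noise level, per cosine scheduling: high early, low late
        noise_p = math.cos(timestep.item() * math.pi * 0.5)

        # re-mask the worst-scoring noise_p fraction of each sequence
        masked = []
        for resp, score in zip(resps, scores):
            k = max(1, int(noise_p * resp.shape[0]))
            indices = score.topk(k, largest=True).indices
            masked.append(resp.scatter(0, indices, stop_token))
        is_masked = [resp == stop_token for resp in masked]

        # no timestep is passed to the model anymore, per this commit
        sampled_ids, confidences = model(masked)

        # keep unmasked tokens, fill the masked slots with fresh samples
        resps = [
            torch.where(m, new, old)
            for m, new, old in zip(is_masked, sampled_ids, masked)
        ]
        # invert scores so the worst tokens sort to the top next iteration
        scores = [1.0 - c for c in confidences]
    return resps
```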


@@ -1020,17 +1020,26 @@ class Base(nn.Module):
# insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
# it does not seem to matter whether this is provided or not; I assume the model attends more to the number of masked tokens in the sequence
"""
# insert timestep token
if timestep is not None:
# store timestep information
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
"""
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
# store dropout mask (if training)
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
if timestep is not None and self.training:
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
# a paper said to use a fixed masking ratio for training
"""
# cosine scheduled timestep => masking ratio
p = math.cos(timestep * math.pi * 0.5)
"""
p = 0.8
dropout_mask = _dropout_mask( resps_list[i], p )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# Audio length prediction task
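`_dropout_mask` itself isn't shown in this diff; assuming it draws a Bernoulli mask at rate `p` (an assumption about its behavior), the training-side change reduces to the following:

```python
import torch

def _dropout_mask(resp: torch.Tensor, p: float) -> torch.Tensor:
    # assumed behavior: True marks positions whose tokens will be masked out
    return torch.rand(resp.shape[0], device=resp.device) < p

# before: ratio tied to a cosine-scheduled timestep
#   p = math.cos(timestep * math.pi * 0.5)
# after this commit: a fixed 80% masking ratio during training
p = 0.8
resp = torch.randint(0, 1024, (75,))   # hypothetical response token sequence
dropout_mask = _dropout_mask(resp, p)  # roughly 80% of positions come back True
```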