do not pass the timestep token/embedding since it doesn't seem to matter after all; fixed the training masking rate to 80% because a paper said so
This commit is contained in:
parent caf721c67b
commit 8286aa54c8
@@ -255,6 +255,7 @@ class AR_NAR(Base):
prev_list = resps_list

for timestep, steps_until_x0 in tqdm(zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))), desc="NAR Masked", disable=disable_tqdm, total=max_steps):
    #for noise_p, annealed_temperature, temperature, cfg_strength in zip( manual_ratios, manual_temp, manual_samp_temp, manual_cfg ):
    annealing = (steps_until_x0 / max_steps)
    # get noise level, per cosine scheduling
    noise_p = math.cos( timestep * math.pi * 0.5 )
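For context, a minimal standalone sketch of the cosine schedule this loop walks (the step count and noise bounds below are assumed values for illustration, not the repo's defaults): each timestep from the linspace maps to a masking ratio noise_p that decays from ~1.0 to ~0.0 as decoding proceeds, while annealing simply counts down the remaining steps.

import math
import torch

start_noise, end_noise, max_steps = 0.0, 1.0, 8  # assumed values for illustration

for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
    annealing = steps_until_x0 / max_steps                 # 0.875 -> 0.0 over the decode
    noise_p = math.cos(float(timestep) * math.pi * 0.5)    # cosine schedule: fraction of tokens to mask
    print(f"t={float(timestep):.2f}  mask {noise_p:.0%} of the sequence  (annealing={annealing:.2f})")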
@@ -264,7 +265,7 @@ class AR_NAR(Base):
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask
is_masked = [ resps == self.stop_token for resps in resps_list ]

# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]

# setup inputs
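To illustrate the masking step above in isolation, here is a small sketch with made-up tensors and a stand-in mask/stop token id (the real loop takes masked_indices from the confidence scores rather than at random):

import torch

stop_token = 1024                     # stand-in mask/stop id, assumed for the sketch
resp = torch.arange(10)               # one response sequence out of the batch
noise_p = 0.5                         # fraction of positions to mask at this step

# choose positions to mask (randomly here; by lowest confidence in the real loop)
num_masked = int(noise_p * resp.shape[0])
masked_indices = torch.randperm(resp.shape[0])[:num_masked]

# overwrite the chosen positions with the mask token, then derive the boolean mask
resp = resp.scatter(0, masked_indices, stop_token)
is_masked = resp == stop_token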
@@ -314,6 +315,7 @@ class AR_NAR(Base):
)

# retrieves unfiltered logits
"""
unfiltered_sampled = super().sample(
    logits=logits,
    prev_list=prev_list,
@@ -322,6 +324,7 @@ class AR_NAR(Base):
    temperature=0.0,
    **sampling_kwargs,
)
"""
# update previous list of tokens
prev_list = resps_list
@@ -333,7 +336,7 @@ class AR_NAR(Base):
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# update scores (conjugated to put the worst scores at the top)
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in unfiltered_sampled.scores ]
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]

# refinement step
if refine_on_stop:
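A rough sketch of what this update amounts to for a single batch item, assuming the sampler yields token ids plus per-token confidences in [0, 1]: new samples are only accepted at masked positions, and the scores are inverted so that a top-k over them picks the least confident tokens to re-mask on the next step.

import torch

seq_len = 10
sampled_ids = torch.randint(0, 1024, (seq_len,))   # ids sampled this step
confidences = torch.rand(seq_len)                  # per-token confidence from the sampler
is_masked   = torch.rand(seq_len) < 0.5            # positions that were masked going in
resp        = torch.randint(0, 1024, (seq_len,))   # current running sequence

# keep previously unmasked tokens; only accept new samples where the input was masked
resp = torch.where(is_masked, sampled_ids, resp)

# "conjugate" the scores so the worst (least confident) tokens float to the top of a top-k
scores = 1.0 - confidences
num_to_remask = 4
worst = scores.topk(num_to_remask).indices          # candidates to re-mask next step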
@@ -374,8 +377,10 @@ class AR_NAR(Base):
# greedy sample from the sequence
refined_list = [ logit.argmax(dim=-1) for logit in logits ]

"""
if cfg.experimental and max_steps > 0:
    print( timestep, steps_until_x0, noise_p, resps_list, scores )
"""

return resps_list
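The refinement pass itself is just greedy decoding over the final logits; a minimal sketch with dummy shapes (sequence length and vocab size are assumed):

import torch

# assumed: one set of logits per batch item, shaped (seq_len, vocab)
logits = [torch.randn(75, 1025) for _ in range(2)]

# greedy sample from the sequence: take the highest-scoring token at every position
refined_list = [logit.argmax(dim=-1) for logit in logits]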
@@ -1020,17 +1020,26 @@ class Base(nn.Module):
# insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
    inputs[i].append( ( "tone", tone_list[i] ) )
# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
"""
# insert timestep token
if timestep is not None:
    # store timestep information
    inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
"""
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
    inputs[i].append( ( "resp", resps_list[i] ) )

# store dropout mask (if training)
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
if timestep is not None and self.training:
    dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
    # a paper said to use a fixed masking ratio for training
    """
    # cosine scheduled timestep => masking ratio
    p = math.cos(timestep * math.pi * 0.5)
    """
    p = 0.8
    dropout_mask = _dropout_mask( resps_list[i], p )
    inputs[i].append( ("dropout_mask", dropout_mask ) )

# Audio length prediction task
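The training-side change reduces to swapping the cosine-scheduled masking ratio for a constant 80% (while the timestep token itself is no longer appended to the inputs, since the model appears to infer the noise level from how many tokens are masked). A sketch of the difference, assuming _dropout_mask simply draws an i.i.d. Bernoulli mask over the response positions (the actual helper may differ):

import math
import torch

def _dropout_mask(resp, p):
    # assumed behavior: True where a token gets masked out, each position drawn with probability p
    return torch.rand(resp.shape[0], device=resp.device) < p

resp = torch.randint(0, 1024, (75,))   # dummy response sequence
timestep = 0.25                        # sampled training timestep in [0, 1]

# before: masking ratio followed the cosine schedule of the sampled timestep
p_cosine = math.cos(timestep * math.pi * 0.5)

# after this commit: a fixed 80% of the response tokens are masked during training
p_fixed = 0.8
dropout_mask = _dropout_mask(resp, p_fixed)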