Almost made a mistake
This commit is contained in:
parent
c6a38693a2
commit
f50d92ba6c
|
@ -511,6 +511,7 @@ class Base(nn.Module):
|
|||
self.interleave = interleave
|
||||
self.layerskip = layerskip
|
||||
self.special_tasks = [ "len", "stt" ]
|
||||
self.inject_timestep_embedding = False # results in bad output
|
||||
|
||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||
self.langs_emb = None
|
||||
|
@ -1026,6 +1027,12 @@ class Base(nn.Module):
|
|||
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
|
||||
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
|
||||
if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
|
||||
# pick a random timestep
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
timestep = random.random()
|
||||
else:
|
||||
timestep = 1.0
|
||||
|
||||
# insert the text prompt
|
||||
if text_list is not None and text_list[i] is not None:
|
||||
inputs[i].append( ( "text", text_list[i] ) )
|
||||
|
@ -1041,22 +1048,18 @@ class Base(nn.Module):
|
|||
# insert tone token if we're trained for it
|
||||
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
|
||||
inputs[i].append( ( "tone", tone_list[i] ) )
|
||||
# insert timestep token
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
# cosine schedule
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
|
||||
|
||||
# store timestep information
|
||||
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||
# store dropout mask
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
# insert the current output response
|
||||
if resps_list is not None and resps_list[i] is not None:
|
||||
inputs[i].append( ( "resp", resps_list[i] ) )
|
||||
|
||||
# store dropout mask
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
t = random.random()
|
||||
p = math.cos(t * math.pi * 0.5)
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=p )
|
||||
|
||||
inputs[i].append( ("timestep", torch.tensor([t], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
else:
|
||||
# in the event it's needed (it makes shit sound worse)
|
||||
#inputs[i].append( ("timestep", torch.tensor([1.0], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||
...
|
||||
|
||||
# Audio length prediction task
|
||||
# Sequence: <text><sep><rvq lvl><prom><sep><len>
|
||||
|
@ -1618,12 +1621,12 @@ class Base(nn.Module):
|
|||
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
|
||||
|
||||
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
|
||||
"""
|
||||
timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
|
||||
#timesteps = [ inputs[i][-1] if timestep is not None else None for i, timestep in enumerate(timesteps) ]
|
||||
timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
|
||||
"""
|
||||
timesteps = []
|
||||
|
||||
if self.inject_timestep_embedding:
|
||||
timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
|
||||
timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
|
||||
else:
|
||||
timesteps = []
|
||||
|
||||
classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user