Almost made a mistake
This commit is contained in:
parent c6a38693a2
commit f50d92ba6c

@@ -511,6 +511,7 @@ class Base(nn.Module):
         self.interleave = interleave
         self.layerskip = layerskip
         self.special_tasks = [ "len", "stt" ]
+        self.inject_timestep_embedding = False # results in bad output
 
         self.text_emb = Embedding(n_text_tokens, d_model)
         self.langs_emb = None
@@ -1026,6 +1027,12 @@ class Base(nn.Module):
             # Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
             # prom /may/ include <task> tokens inside to help guide things, per SpeechX
             if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
+                # pick a random timestep
+                if "len" in self.capabilities and quant_level == 0:
+                    timestep = random.random()
+                else:
+                    timestep = 1.0
+
                 # insert the text prompt
                 if text_list is not None and text_list[i] is not None:
                     inputs[i].append( ( "text", text_list[i] ) )
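The hunk above now draws the timestep once, before any of the per-item inputs are appended, so the later "timestep" and "dropout_mask" entries can reuse the same value. A rough, hypothetical sketch of that per-item assembly, with dummy tensors standing in for the real batch and only the tuple keys that actually appear in this diff:

import random
import torch

# Dummy stand-ins for one batch item; the real method reads these from
# text_list / resps_list and friends.
device = "cpu"
quant_level = 0
capabilities = ["len"]
text_ids = torch.tensor([5, 6, 7], device=device)
resp_codes = torch.zeros(75, 8, dtype=torch.int64, device=device)  # assumed (seq_len, n_rvq_levels)

# New in this hunk: the timestep is sampled up front.
timestep = random.random() if "len" in capabilities and quant_level == 0 else 1.0

# The sequence from the comment above is assembled as tagged (name, value)
# tuples; a later step maps each name onto its embedding.
inputs_i = []
inputs_i.append(("text", text_ids))
# ... <rvq lvl>, <prom>, "tone", etc. are appended the same way (not shown in this hunk)
inputs_i.append(("resp", resp_codes))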
@@ -1041,22 +1048,18 @@ class Base(nn.Module):
                 # insert tone token if we're trained for it
                 if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
                     inputs[i].append( ( "tone", tone_list[i] ) )
+                # insert timestep token
+                if "len" in self.capabilities and quant_level == 0:
+                    # cosine schedule
+                    dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
+
+                    # store timestep information
+                    inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
+                    # store dropout mask
+                    inputs[i].append( ("dropout_mask", dropout_mask ) )
                 # insert the current output response
                 if resps_list is not None and resps_list[i] is not None:
                     inputs[i].append( ( "resp", resps_list[i] ) )
-
-                # store dropout mask
-                if "len" in self.capabilities and quant_level == 0:
-                    t = random.random()
-                    p = math.cos(t * math.pi * 0.5)
-                    dropout_mask = _dropout_mask( resps_list[i], p=p )
-
-                    inputs[i].append( ("timestep", torch.tensor([t], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
-                    inputs[i].append( ("dropout_mask", dropout_mask ) )
-                else:
-                    # in the event it's needed (it makes shit sound worse)
-                    #inputs[i].append( ("timestep", torch.tensor([1.0], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
-                    ...
 
             # Audio length prediction task
             # Sequence: <text><sep><rvq lvl><prom><sep><len>
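The cosine schedule in the new block maps the sampled timestep onto a masking probability: timestep near 0 gives p near 1, timestep near 1 gives p near 0. `_dropout_mask` itself is not part of this diff; below is a minimal sketch of what such a helper might do, assuming p is the per-position probability of being masked:

import math
import torch

def _dropout_mask_sketch(resps: torch.Tensor, p: float) -> torch.Tensor:
    # Hypothetical stand-in for _dropout_mask: a boolean mask over the
    # sequence dimension, True marking a position to drop/mask.
    # Assumes resps is shaped (seq_len, n_rvq_levels).
    return torch.rand(resps.shape[0], device=resps.device) < p

# Cosine schedule used in the hunk above:
#   timestep = 0.0 -> p = cos(0)      = 1.0  (mask nearly every position)
#   timestep = 1.0 -> p = cos(pi / 2) = 0.0  (mask nearly nothing)
timestep = 0.25
p = math.cos(timestep * math.pi * 0.5)
mask = _dropout_mask_sketch(torch.zeros(75, 8), p)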
@@ -1618,12 +1621,12 @@ class Base(nn.Module):
         position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 
         tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
-        """
-        timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
-        #timesteps = [ inputs[i][-1] if timestep is not None else None for i, timestep in enumerate(timesteps) ]
-        timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
-        """
-        timesteps = []
+        if self.inject_timestep_embedding:
+            timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
+            timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
+        else:
+            timesteps = []
 
         classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 
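`self.time_emb` is only referenced here through its call and through `self.time_emb.mlp[0].weight.dtype` in the earlier hunk; the sketch below is a hypothetical stand-in (a sinusoidal encoding of the scalar timestep followed by a small MLP) that matches those two usages, not the model's actual module:

import math
import torch
from torch import nn

class TimeEmbeddingSketch(nn.Module):
    # Hypothetical stand-in for self.time_emb. Only the .mlp attribute and the
    # call signature are implied by the diff; the rest is an assumption.
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: shape (1,), a float timestep in [0, 1]
        half = self.d_model // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=t.dtype, device=t.device) / half)
        angles = t[:, None] * freqs[None, :]
        sinusoid = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return self.mlp(sinusoid)

emb = TimeEmbeddingSketch(d_model=1024)
out = emb(torch.tensor([0.3]))  # -> shape (1, 1024)

Reading the dtype off `mlp[0]` when the timestep tensor is stored keeps it castable to whatever precision the embedding MLP runs in. The `classifier_quant_levels` line is unchanged context: the special tasks ("len", "stt") get classifier index -1 instead of their RVQ level.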