Almost made a mistake

This commit is contained in:
mrq 2024-11-09 18:12:54 -06:00
parent c6a38693a2
commit f50d92ba6c

View File

@ -511,6 +511,7 @@ class Base(nn.Module):
self.interleave = interleave self.interleave = interleave
self.layerskip = layerskip self.layerskip = layerskip
self.special_tasks = [ "len", "stt" ] self.special_tasks = [ "len", "stt" ]
self.inject_timestep_embedding = False # results in bad output
self.text_emb = Embedding(n_text_tokens, d_model) self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None self.langs_emb = None
@ -1026,6 +1027,12 @@ class Base(nn.Module):
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp> # Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
# prom /may/ include <task> tokens inside to help guide things, per SpeechX # prom /may/ include <task> tokens inside to help guide things, per SpeechX
if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks: if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
# pick a random timestep
if "len" in self.capabilities and quant_level == 0:
timestep = random.random()
else:
timestep = 1.0
# insert the text prompt # insert the text prompt
if text_list is not None and text_list[i] is not None: if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) ) inputs[i].append( ( "text", text_list[i] ) )
@ -1041,22 +1048,18 @@ class Base(nn.Module):
# insert tone token if we're trained for it # insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None: if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) ) inputs[i].append( ( "tone", tone_list[i] ) )
# insert timestep token
if "len" in self.capabilities and quant_level == 0:
# cosine schedule
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
# store timestep information
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
# store dropout mask
inputs[i].append( ("dropout_mask", dropout_mask ) )
# insert the current output response # insert the current output response
if resps_list is not None and resps_list[i] is not None: if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) ) inputs[i].append( ( "resp", resps_list[i] ) )
# store dropout mask
if "len" in self.capabilities and quant_level == 0:
t = random.random()
p = math.cos(t * math.pi * 0.5)
dropout_mask = _dropout_mask( resps_list[i], p=p )
inputs[i].append( ("timestep", torch.tensor([t], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
inputs[i].append( ("dropout_mask", dropout_mask ) )
else:
# in the event it's needed (it makes shit sound worse)
#inputs[i].append( ("timestep", torch.tensor([1.0], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
...
# Audio length prediction task # Audio length prediction task
# Sequence: <text><sep><rvq lvl><prom><sep><len> # Sequence: <text><sep><rvq lvl><prom><sep><len>
@ -1618,12 +1621,12 @@ class Base(nn.Module):
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ] tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
"""
timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ] if self.inject_timestep_embedding:
#timesteps = [ inputs[i][-1] if timestep is not None else None for i, timestep in enumerate(timesteps) ] timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ] timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
""" else:
timesteps = [] timesteps = []
classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ] classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]