diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 21a7e81..7c86d0b 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -511,6 +511,7 @@ class Base(nn.Module): self.interleave = interleave self.layerskip = layerskip self.special_tasks = [ "len", "stt" ] + self.inject_timestep_embedding = False # results in bad output self.text_emb = Embedding(n_text_tokens, d_model) self.langs_emb = None @@ -1026,6 +1027,12 @@ class Base(nn.Module): # Sequence: # prom /may/ include tokens inside to help guide things, per SpeechX if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks: + # pick a random timestep + if "len" in self.capabilities and quant_level == 0: + timestep = random.random() + else: + timestep = 1.0 + # insert the text prompt if text_list is not None and text_list[i] is not None: inputs[i].append( ( "text", text_list[i] ) ) @@ -1041,22 +1048,18 @@ class Base(nn.Module): # insert tone token if we're trained for it if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None: inputs[i].append( ( "tone", tone_list[i] ) ) + # insert timestep token + if "len" in self.capabilities and quant_level == 0: + # cosine schedule + dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) ) + + # store timestep information + inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) ) + # store dropout mask + inputs[i].append( ("dropout_mask", dropout_mask ) ) # insert the current output response if resps_list is not None and resps_list[i] is not None: inputs[i].append( ( "resp", resps_list[i] ) ) - - # store dropout mask - if "len" in self.capabilities and quant_level == 0: - t = random.random() - p = math.cos(t * math.pi * 0.5) - dropout_mask = _dropout_mask( resps_list[i], p=p ) - - inputs[i].append( ("timestep", torch.tensor([t], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) ) - inputs[i].append( ("dropout_mask", dropout_mask ) ) - else: - # in the event it's needed (it makes shit sound worse) - #inputs[i].append( ("timestep", torch.tensor([1.0], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) ) - ... # Audio length prediction task # Sequence: @@ -1618,12 +1621,12 @@ class Base(nn.Module): position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ] - """ - timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ] - #timesteps = [ inputs[i][-1] if timestep is not None else None for i, timestep in enumerate(timesteps) ] - timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ] - """ - timesteps = [] + + if self.inject_timestep_embedding: + timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ] + timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ] + else: + timesteps = [] classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]