diff --git a/vall_e/config.py b/vall_e/config.py
index 4a4d71a..d5eb5d1 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -260,7 +260,7 @@ class ModelExperimentalSettings:
 	masking_train_p: float = 0.0 # odds of training with masking
 	masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
 
-	masking_separate_embeddings: bool = False
+	masking_separate_embeddings: bool = False # to-do: explain
 
 	# classifier-free guidance shit
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 3648e98..ca4cd77 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -883,6 +883,10 @@ def example_usage():
 	import numpy as np
 	import re
+
+	cfg.model.experimental.masking_train_p = 0.5
+	cfg.hyperparameters.batch_size = 1
+	cfg.hyperparameters.gradient_accumulation_steps = 1
 
 	setup_logging()
 
@@ -896,7 +900,6 @@ def example_usage():
 	text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
 	batch_size = cfg.hyperparameters.batch_size
-	cfg.model.experimental.masking_train_p = 1.0
 
 	text_list = [ text ] * batch_size
 	proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0f49c59..a76d64f 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -46,6 +46,9 @@ LossStats = namedtuple('LossStats', ['loss', 'stats'])
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """
 
+summed_embeddings_task = [ "stt" ]
+special_tasks = [ "len", "stt" ]
+
 def _dropout_mask( input, p=None ):
 	# cosine scheduling
 	if p is None:
@@ -182,78 +185,26 @@ class AudioEmbedding(nn.Module):
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
 		sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
-		external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
-
-		capabilities: list[str] | None = None, # helper shit
+		l_names: list[str] = [], # names to map to indices
 	):
 		super().__init__()
 		# array of embeddings
 		#   proms are [0, resp_levels]
 		#   resp are split to where [0] is for the AR, and [1:] are reserved for NAR
 		#       + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
+		#
+		self.names = l_names
 
-		self.external_mode = external_mode
-		self.capabilities = capabilities
-
-		# set initial weights to zero
-		if self.external_mode == "inclusive":
-			for i, embedding in enumerate(self.embeddings):
-				embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
-
-	def external_embeddings(self, input: Tensor, quant_level: int | None = None ) -> Tensor:
-		if quant_level is None:
-			quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
-
-		# for AR, trim any stop tokens
-		has_stop_token = False
-
-		# this block apparently doesn't work
-		"""
-		if quant_level == 0:
-			stop_token = self.embeddings[0].weight.shape[0] - 1
-			stop_token_indices = (input == stop_token).nonzero()
-			has_stop_token = len(stop_token_indices) > 0
-
-		if has_stop_token:
-			input = input[:stop_token_indices.min().item()]
-		"""
-		has_stop_token = False
-
-		if quant_level == 0:
-			stop_token = self.embeddings[0].weight.shape[0] - 1
-			has_stop_token = input[-1] == stop_token
-
-		if has_stop_token:
-			input = input[:-1]
-
-		# get external embedding
-		embedding = encode_as_embedding( input, quant_level, sums=self.sums ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
-		# resize if necessary (in case the external embeddings do not match our model dim)
-		embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
-
-		# reintroduce stop token
-		if has_stop_token:
-			stop_token = self.internal_forward( torch.tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
-			embedding = torch.concat( [ embedding, stop_token ] )
-
-		return embedding
-
-	def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
-		if offset is None:
-			# prom
-			if self.capabilities is None:
-				offset = 0
-			elif "nar" not in self.capabilities:
-				offset = 0
-			elif quant_level > 0:
-				offset = 1
-
+	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
 		if sums is None:
 			sums = self.sums
 
+		# handle mapping from name
+		if name in self.names:
+			offset = self.names.index( name )
+
 		if quant_level is None:
 			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
@@ -265,17 +216,6 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
-	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
-		x = self.internal_forward( xi, offset = offset, quant_level = quant_level, sums = sums ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
-
-		if self.external_mode and xi.shape[0] > 0:
-			external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
-			if self.external_mode == "exclusive":
-				return external_embeddings
-			x += external_embeddings
-
-		return x
-
 # time-step embedding
 # for the NAR-len, since it probably most likely requires encoding the timestep
 class TimeEmbedding(nn.Module):
@@ -304,14 +244,32 @@ class Classifiers(nn.Module):
 		self,
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
+		l_names: list[str] | None = None, # list of names to map to each classifier
 	):
 		super().__init__()
 		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
+		self.names = l_names
 
-	def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
+	def indices(
+		self,
+		names
+	):
+		if isinstance( names[-1], int ):
+			return names
+		return [ self.names.index(name) for name in names ]
+
+	def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None ) -> Tensor:
 		dtype = xi.dtype
 		device = xi.device
 
+		if levels and isinstance( levels[-1], str ):
+			names = levels
+			levels = []
+
+		# map names to levels
+		if names and not levels:
+			levels = [ self.names.index(name) for name in names ]
+
 		xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
 		# pad if needed
 		# to-do: validate that this causes ZERO issues
@@ -349,11 +307,11 @@ class Metrics(nn.Module):
 			ignore_index=ignore_index,
 		) for n_tokens in l_tokens ])
 
-	def calc_accuracy( self, inputs, targets, quant_levels ):
-		return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+	def calc_accuracy( self, inputs, targets, classifier_levels ):
+		return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
 
-	def calc_precision( self, inputs, targets, quant_levels ):
-		return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+	def calc_precision( self, inputs, targets, classifier_levels ):
+		return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
 
 	def __call__(self, *args, **kwargs):
 		return dict(
@@ -486,21 +444,29 @@ class Base(nn.Module):
 		if "nar" not in self.capabilities:
 			n_resp_tokens = n_audio_tokens + 1
 			l_tokens = [n_resp_tokens] * self.n_resp_levels
+			resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
 		# NAR-len model
 		elif "len" in self.capabilities and masking_separate_embeddings:
 			# +1 to include the stop or mask token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
-			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+			resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+
+			if masking_separate_embeddings:
+				l_tokens += [n_resp_tokens]
+				resp_l_names += ['NAR:0:0']
 		# AR+NAR model
 		else:
 			# +1 to include the stop or mask token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
 			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+			resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+
+		classifier_l_names = resp_l_names + ["stt"]
 
 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
 		self.layerskip = layerskip
-		self.special_tasks = [ "len", "stt" ]
 		self.inject_timestep_embedding = False # results in bad output
 		self.masking_separate_embeddings = masking_separate_embeddings
@@ -534,14 +500,11 @@ class Base(nn.Module):
 			self.proms_emb = AudioEmbedding(
 				[n_audio_tokens] * self.n_resp_levels, d_model,
 				sums=audio_embedding_sums,
-				external_mode=audio_embedding_mode,
-				capabilities=None,
 			)
 			self.resps_emb = AudioEmbedding(
 				l_tokens, d_model,
 				sums=audio_embedding_sums,
-				external_mode=audio_embedding_mode,
-				capabilities=self.capabilities,
+				l_names=resp_l_names,
 			)
 
 		if self.version >= 3:
@@ -842,7 +805,7 @@ class Base(nn.Module):
 			self.metrics = None
 		else:
 			self.classifier = None
-			self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model )
+			self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model, l_names=classifier_l_names )
 			self.accuracy_metric = None
 			self.precision_metric = None
 			self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
@@ -1002,6 +965,7 @@ class Base(nn.Module):
 			quant_level = quant_levels[i] if quant_levels is not None else 0
 			task_type = task_list[i] if task_list is not None else "tts"
 			timestep = time_list[i] if time_list is not None else None
+			classifier_level = None
 
 			# insert task type as a string
 			inputs[i].append( ( "task", task_type ) )
@@ -1012,7 +976,7 @@ class Base(nn.Module):
 			# Base-line TTS task
 			# Sequence:
 			# prom /may/ include tokens inside to help guide things, per SpeechX
-			if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
+			if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
 				# insert the text prompt
 				if text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
@@ -1022,19 +986,22 @@ class Base(nn.Module):
 				# insert RVQ level guidance token if the model is versioned for it
 				if self.rvq_l_emb is not None and not self.interleave:
 					inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
+
+					classifier_level = "AR:0:0" if quant_level == 0 else f'NAR:{quant_level-1}:{quant_level}'
 				# insert input audio prompt
 				if proms_list is not None and proms_list[i] is not None:
 					inputs[i].append( ( "prom", proms_list[i] ) )
 				# insert tone token if we're trained for it
 				if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
 					inputs[i].append( ( "tone", tone_list[i] ) )
-				# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
-				"""
 				# insert timestep token
 				if timestep is not None:
+					# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
+					"""
 					# store timestep information
 					inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
-				"""
+					"""
+					classifier_level = "NAR:0:0"
 				# insert the current output response
 				if resps_list is not None and resps_list[i] is not None:
 					inputs[i].append( ( "resp", resps_list[i] ) )
@@ -1050,6 +1017,7 @@ class Base(nn.Module):
 					dropout_mask = _dropout_mask( resps_list[i], p )
 					inputs[i].append( ("dropout_mask", dropout_mask ) )
 
+				inputs[i].append( ("classifier_level", classifier_level) )
 			# Audio length prediction task
 			# Sequence:
 			elif task_type == "len":
@@ -1080,6 +1048,8 @@ class Base(nn.Module):
 				elif resps_list is not None and resps_list[i] is not None:
 					# yes this could be encoded better
 					inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
+
+				inputs[i].append( ("classifier_level", "stt") )
 			# Speech-to-Text prediction task
 			# Sequence:
 			elif task_type == "stt":
@@ -1095,6 +1065,8 @@ class Base(nn.Module):
 				# insert the output text prompt
 				if text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
+
+				inputs[i].append( ("classifier_level", "stt") )
 			else:
 				raise Exception(f'Unrecognized task: {task_type}')
 
 		return inputs
@@ -1131,7 +1103,6 @@ class Base(nn.Module):
 		if not token_dropout_rvq_levels:
 			token_dropout_rvq_levels = [1, self.resp_levels]
 
-		summed_embeddings_task = [ "stt" ]
 
 		x_list = []
 		for batch_index, batch_input in enumerate(inputs):
@@ -1140,6 +1111,7 @@ class Base(nn.Module):
 			task_type = "tts"
 			input_prom = None
+			classifier_level = None
 			dropout_mask = None
 			timestep = None
 
@@ -1147,6 +1119,8 @@ class Base(nn.Module):
 			for name, input in batch_input:
 				if name == "dropout_mask":
 					dropout_mask = input
+				elif name == "classifier_level":
+					classifier_level = input
 
 			for name, input in batch_input:
 				# technically can provide a map for input_name => embedding, but some embedding requires additional processing
@@ -1179,8 +1153,9 @@ class Base(nn.Module):
 					if self.interleave:
 						embeddings = [ self.resps_emb(
 							input[:, :l+1],
-							offset = 0,
-							quant_level = l
+							#offset = 0,
+							#quant_level = l,
+							name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
 						) for l in range( input.shape[-1] ) ]
 
 						embedding = _interleave_sequence_reshape( embeddings )
@@ -1190,16 +1165,18 @@ class Base(nn.Module):
 						embedding = self.resps_emb(
 							# if masked use masked token, else original token
 							torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
-							offset = -1 if self.masking_separate_embeddings else 0, # pick last
-							quant_level = 0,
+							#offset = -1 if self.masking_separate_embeddings else 0, # pick last
+							#quant_level = 0,
+							name = classifier_level,
 						)
 					# cheat-y way to handle performing STT across all levels
 					elif task_type in summed_embeddings_task:
 						# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
 						embedding = sum([ self.resps_emb(
 							input[:, :l+1],
-							offset = 0 if l == 0 else 1, # or maybe set to 1
-							quant_level = l,
+							#offset = 0 if l == 0 else 1, # or maybe set to 1
+							#quant_level = l,
+							name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
 							sums = False
 						) for l in range( input.shape[-1] - 1 ) ])
 					else:
@@ -1210,6 +1187,7 @@ class Base(nn.Module):
 								quant_level
 							)
 						else:
+							"""
 							offset = 0
 							if "nar" not in self.capabilities:
 								offset = 0
@@ -1221,6 +1199,13 @@ class Base(nn.Module):
 								offset = offset,
 								quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
 							)
+							"""
+
+							embedding = self.resps_emb(
+								input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
+								offset = 0 if classifier_level.startswith("AR:") else 1,
+								quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
+							)
 
 				# apply token dropout
 				if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
@@ -1283,16 +1268,14 @@ class Base(nn.Module):
 		# there's a better way
 		if not self.unified_position_ids:
 			x_list = []
+			non_tokens = ["task", "dropout_mask", "classifier_level"]
+			last_input = ["resp", "len"]
 
 			def get_input_token_length( name, input ):
 				# task token
 				if isinstance(input, str):
 					return 1
 
-				# a mask
-				if name in ["dropout_mask"]:
-					return 0
-
 				# list of tokens
 				if not isinstance(input, torch.Tensor):
 					return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
@@ -1302,12 +1285,12 @@ class Base(nn.Module):
 					return input.shape[0] * input.shape[1]
 
 				# ending input will not have a separator later
-				return input.shape[0] + (0 if name in ["resp", "len"] else 1)
+				return input.shape[0] + (0 if name in last_input else 1)
 
 			for batch_index, batch_input in enumerate(inputs):
 				batch = torch.cat( [ torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
-					for name, input in batch_input if name != "task"
+					for name, input in batch_input if name not in non_tokens
 				] )
 
 				delta = ids[batch_index].shape[0] - batch.shape[0]
@@ -1325,17 +1308,14 @@ class Base(nn.Module):
 		inputs: list,
 		logits,
-		quant_levels: int | list[int] | Tensor | None = None,
+		quant_levels: list[int] | None = None,
 	):
 		loss = dict(ce = dict())
 		stats = dict(acc = dict())
 
 		device = logits[0].device
 		batch_size = len(logits)
 
-		summed_embeddings_task = [ "stt" ]
-		tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
-		is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
-		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
+		classifier_levels = self.get_input( inputs, "classifier_level" )
 
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
@@ -1369,7 +1349,7 @@ class Base(nn.Module):
 				if name == "task":
 					task_type = input
 					task_list.append( input )
-					if task_type in ["len", "stt"]:
+					if task_type in special_tasks:
 						causal = True
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1423,7 +1403,7 @@ class Base(nn.Module):
 				loss = dict(
 					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
 				)
-				stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
+				stats = self.metrics( inputs, targets, classifier_levels ) if self.metrics is not None else dict(
					acc = self.accuracy_metric( inputs, target ),
 					# precision = self.precision_metric( inputs, target ),
 				)
@@ -1432,7 +1412,7 @@ class Base(nn.Module):
 				loss = dict(
 					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 				)
-				stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
+				stats = self.metrics( logits, target_list, self.classifiers.indices( classifier_levels ) ) if self.metrics is not None else dict(
 					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 				)
@@ -1466,7 +1446,7 @@ class Base(nn.Module):
 				# meta-input, no corresponding token at the moment
 				if name == "task":
 					task_name = input
-					if task_type in ["len", "stt"]:
+					if task_type in special_tasks:
 						causal = True
 					continue
 				# do not use resp as-is
@@ -1529,7 +1509,7 @@ class Base(nn.Module):
 			else:
 				loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
 				if self.metrics is not None:
-					metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
+					metrics = self.metrics( batch["logits"], batch["targets"], self.classifiers.indices( classifier_levels ) )
 					stats["acc"][name] = metrics["acc"]
 				else:
 					stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
@@ -1540,7 +1520,7 @@ class Base(nn.Module):
 		self,
 		inputs: list,
 
-		quant_levels: int | list[int] | Tensor | None = None,
+		quant_levels: list[int] | None = None,
 		state: dict | list | None = None,
 
 		layer_skip_variables: dict | None = None,
@@ -1549,7 +1529,7 @@ class Base(nn.Module):
 		output_hidden_states: bool = False,
 	):
 		# return early if it's "good" enough"
-		# lambda because we need to capture the classifier_quant_levels and mask
+		# lambda because we need to capture the classifier_levels and mask
 		exited_layer = self.n_layers
 		def layer_skip_lambda( layer, logits ):
 			nonlocal exited_layer
@@ -1576,7 +1556,7 @@ class Base(nn.Module):
 			if self.classifier is not None:
 				x = self.classifier(x) # * m
 			elif self.classifiers is not None:
-				logits = self.classifiers(logits, levels = classifier_quant_levels) # * m
+				logits = self.classifiers(logits, levels = classifier_levels) # * m
 
 			# calculate metrics
 			metrics = calculate_entropix_metrics( logits )
@@ -1628,10 +1608,7 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
-
-		tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
-		is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
-		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
+		classifier_levels = self.get_input( inputs, name="classifier_level" )
 
 		if self.inject_timestep_embedding:
 			timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
@@ -1664,11 +1641,11 @@ class Base(nn.Module):
 		# to-do: piece-wise classification, now that there's a head for text
 		# although again, one single monolithic head would be preferable instead......
 		elif self.classifiers is not None:
-			logits = self.classifiers(logits, levels = classifier_quant_levels) # * m
+			logits = self.classifiers(logits, levels = classifier_levels) # * m
 
 			if hidden_states is not None:
 				for i, state in enumerate( hidden_states ):
-					hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) # * m
+					hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_levels) # * m
 
 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
@@ -1716,7 +1693,7 @@ class Base(nn.Module):
 		self,
 		logits: list[Tensor], # logit scores
 		prev_list: list[Tensor] | None = None, # previous tokens
-		quant_levels: int | list[int] | Tensor | None = None, # to-do: derive this from the prev_list
+		quant_levels: list[int] | None = None, # to-do: derive this from the prev_list
 		**sampling_kwargs,
 	):
 		# yikes
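
Note (editorial addition, not part of the patch): the sketch below illustrates the name-based classifier selection this diff introduces. Heads are registered under names such as 'AR:0:0', 'NAR:0:1', 'NAR:0:0', and 'stt' (resp_l_names / classifier_l_names), call sites now pass those names instead of raw RVQ levels, and Classifiers.indices() maps them back to integer indices for the metrics. The TinyClassifiers class, token counts, and dimensions are invented for illustration; only the mapping logic mirrors the patch.

# Illustrative sketch only (hypothetical class and dimensions); mirrors the
# name -> index mapping added to Classifiers in this patch.
import torch
import torch.nn as nn

class TinyClassifiers(nn.Module):
	def __init__(self, l_tokens, token_dim, l_names):
		super().__init__()
		# one output head per named level, as in Classifiers( l_tokens, d_model, l_names=... )
		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
		self.names = l_names

	def indices(self, names):
		# pass integer levels through untouched, otherwise map names to head indices
		if isinstance(names[-1], int):
			return names
		return [self.names.index(name) for name in names]

	def forward(self, xi, levels=None, names=None):
		# accept names passed through the `levels` argument, as the patched call sites do
		if levels and isinstance(levels[-1], str):
			names, levels = levels, []
		if names and not levels:
			levels = [self.names.index(name) for name in names]
		return [self.proj[l](x) for x, l in zip(xi, levels)]

# e.g. an AR+NAR model with two RVQ levels plus the text ("stt") head
classifier_l_names = ['AR:0:0', 'NAR:0:1', 'stt']
heads = TinyClassifiers(l_tokens=[1025, 1024, 256], token_dim=16, l_names=classifier_l_names)

hidden = [torch.randn(10, 16), torch.randn(12, 16)]   # two sequences in a batch
logits = heads(hidden, levels=['AR:0:0', 'NAR:0:1'])  # per-sequence head selection by name
print([l.shape for l in logits])                      # [torch.Size([10, 1025]), torch.Size([12, 1024])]
print(heads.indices(['stt', 'AR:0:0']))               # [2, 0]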