From 00d1fed2170b91621ac7b47c31513962f181be0f Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 8 Mar 2025 17:10:50 -0600
Subject: [PATCH] another optimization (within the dataloader because the
 similar utterance sampler was mondo slow)

---
 vall_e/data.py              |  56 ++++---
 vall_e/models/arch/llama.py |  10 ++
 vall_e/models/base_v2.py    | 348 +++++-------------------------------
 3 files changed, 88 insertions(+), 326 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index f30c70b..bcf1747 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -703,9 +703,14 @@ def _get_artifact_path(path):
 	return _replace_file_extension(path, _get_artifact_extension())
 
 _durations_map = {}
+_similar_map = {}
+
 def _get_duration_map( type="training" ):
 	return _durations_map[type] if type in _durations_map else {}
 
+def _get_similar_map( type="training" ):
+	return _similar_map[type] if type in _similar_map else {}
+
 def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
 	assert cfg.dataset.min_duration >= 1.0, "Minimum duration too low."
@@ -716,10 +721,14 @@ def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset
 	cached_durations_path = cached_dir / f"durations[{type}].json"
 	cached_paths_path = cached_dir / f"dataloader[{type}].json"
+	cached_similar_path = cached_dir / f"similar[{type}].json"
 
 	# load the duration table first, since this is independent from the loaded paths
 	if cached_durations_path.exists():
 		_durations_map[type] = json_read( cached_durations_path )
+	# load the similar paths table as well, since this is also independent
+	if cached_similar_path.exists():
+		_similar_map[type] = json_read( cached_similar_path )
 
 	# load the cached valid paths (if we're requesting cache use)
 	if cached_paths_path.exists() and cfg.dataset.cache:
@@ -734,6 +743,7 @@ def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset
 		if not cached_dir.exists():
 			cached_dir.mkdir(parents=True, exist_ok=True)
 
+		json_write( _similar_map[type], cached_similar_path, truncate=True )
 		json_write( _durations_map[type], cached_durations_path, truncate=True )
 		json_write( paths, cached_paths_path, truncate=True )
 
@@ -769,9 +779,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 			return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
 
+	metadata_keys = list(metadata.keys())
 	def _validate( id, entry ):
-		phones = entry['phones'] if "phones" in entry else 0
-		duration = entry['duration'] if "duration" in entry else 0
+		phones = entry.get('phones', 0)
+		duration = entry.get('duration', 0)
+		similar = entry.get('similar', None)
 
 		k = key(id, entry)
 
@@ -780,6 +792,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 			_durations_map[type] = {}
 		_durations_map[type][k] = duration
 
+		# add to similar bucket
+		if type not in _similar_map:
+			_similar_map[type] = {}
+		_similar_map[type][k] = [ metadata_keys[idx] for idx in similar ] if similar else None
+
 		if not validate:
 			return True
 
@@ -1188,43 +1205,28 @@ class Dataset(_Dataset):
 		if offset is None:
 			offset = cfg.dataset.prompt_similar_top_k_offset
 
+		root = Path( *path.parts[:-1] )
 		reference = path.name
+		similars = _similar_map[self.dataset_type].get(str(path), None)
 
-		if cfg.dataset.use_hdf5:
-			root = Path( *path.parts[:-1] )
-			path = Path( *path.parts[2:-1] )
-		else:
-			root = Path( *path.parts[:-1] )
-			path = Path(*path.parts[len(cfg.data_dir.parts):-1])
-
-		metadata = json_read( cfg.metadata_dir / path.with_suffix(".json"), default={} )
-
-		if reference not in metadata:
+		if not similars:
 			return None
 
-		reference_metadata = metadata[reference]
-
-		if "similar" not in reference_metadata:
-			return None
-
-		if len(reference_metadata["similar"]) >= offset:
+		if len(similars) >= offset:
 			offset = 0 # cringe stopgap
 
 		offset_end = offset + cfg.dataset.prompt_similar_top_k
 
-		if offset >= len( reference_metadata["similar"] ):
-			return None
-		if offset_end >= len( reference_metadata["similar"] ):
-			return None
-
-		metadata_keys = list(metadata.keys())
+		if offset >= len( similars ):
+			return None
+		if offset_end >= len( similars ):
+			return None
 
 		if cfg.dataset.prompt_similar_top_k > 1:
-			indices = reference_metadata["similar"][offset:offset_end]
-			index = random.choice( indices )
+			name = random.choice( similars[offset:offset_end] )
 		else:
-			index = reference_metadata["similar"][offset]
-			name = metadata_keys[index]
+			name = similars[offset]
 
 		path = root / name
 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index a19992c..c1fb2f0 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -409,6 +409,16 @@ class Attention(nn.Module):
 				f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
 				f" {attn_output.size()}"
 			)
+		elif mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
+			with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+				attn_output = torch.nn.functional.scaled_dot_product_attention(
+					query_states,
+					key_states,
+					value_states,
+					attn_mask=None, # ROCm FA2 through SDPA doesn't allow masks, bummer
+					dropout_p=dropout_rate,
+					is_causal=is_causal,
+				)
 		else:
 			# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
 			# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
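
Note: as a rough illustration of what the new FLASH_ATTENTION branch above does, here is a minimal standalone sketch (not part of the patch): SDPA dispatch is pinned to the flash-attention kernel and causality is expressed through is_causal rather than an explicit attn_mask, since that backend (notably the ROCm FA2 path) rejects masks. The tensor shapes are made up, and the flash kernel is only eligible on a CUDA/ROCm device with fp16/bf16 inputs.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

device, dtype = "cuda", torch.float16  # flash attention is only offered for half-precision GPU tensors
bsz, n_heads, q_len, head_dim = 2, 8, 16, 64  # illustrative sizes
q = torch.randn(bsz, n_heads, q_len, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# restrict dispatch to the flash-attention kernel, mirroring the sdpa_kernel(...) context in the patch
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,   # no explicit mask allowed on this backend
        dropout_p=0.0,
        is_causal=True,   # causality handled by the kernel itself
    )
print(out.shape)  # torch.Size([2, 8, 16, 64])
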
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index aa71a57..e608aef 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -329,13 +329,13 @@ class Base_V2(nn.Module):
 		ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False
 		resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
+		len_parallel_training = False # config.experimental.len_parallel_training if config is not None else True
 		predict_causally = config.experimental.predict_causally if config is not None else False
 		monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
 		audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
-		use_streamlined_calc_loss = config.experimental.use_streamlined_calc_loss if config is not None else True
 
 		n_vocab = 256
 		n_tasks = config.tasks if config is not None else 8
@@ -395,6 +395,7 @@ class Base_V2(nn.Module):
 		self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
 		self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
 		self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
+		self.use_streamlined_calc_loss = True
 
 		self.stop_token = self.n_audio_tokens
 		self.mask_token = self.stop_token + 1
@@ -414,9 +415,9 @@ class Base_V2(nn.Module):
 		self.teaching = True
 		self.training = False
 
-		self.resp_parallel_training = resp_parallel_training
 		self.predict_causally = predict_causally
-
+		self.resp_parallel_training = resp_parallel_training
+		self.len_parallel_training = len_parallel_training
 		self.unified_position_ids = unified_position_ids
 		self.inject_timestep_embedding = False # results in bad output
 		self.masking_ratio = masking_ratio
@@ -425,7 +426,6 @@ class Base_V2(nn.Module):
 		self.audio_level_loss_factors = audio_level_loss_factors
 		self.logit_normalization = logit_normalization
 		self.use_segmented_attention_mask = use_segmented_attention_mask
-		self.use_streamlined_calc_loss = use_streamlined_calc_loss
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -901,287 +901,7 @@ class Base_V2(nn.Module):
 		self,
 		inputs: list,
 		logits,
-		
-		quant_levels: list[int] | None = None,
-		compute_hard_loss = True,
-		compute_acc = True,
-	):
-		loss = {}
-		stats = {}
-
-		device = logits[0].device
-		batch_size = len(logits)
-		classifier_levels = self.get_input( inputs, "classifier_level" )
-
-		# handles tasks where the prompt has task tokens injected in the middle
-		def prompt_input_to_token( input, quant_level ):
-			if isinstance(input, str):
-				return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
-
-			return input
-
-		k_lo, k_hi = 1, 20
-		def _calc_loss( logit, sequence, causal = True, level = None ):
-			level_loss_factors = self.audio_level_loss_factors
-
-			# filter tokens that exceed the vocab size
-			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
-			# drop if all tokens are ignored
-			if torch.all(sequence == self.ignore_index):
-				return None, None
-
-			# shift if causal
-			if causal or self.predict_causally:
-				l = self.causal_size
-				logit = logit[..., :-l, :] # shift the target so that token n...
-				sequence = sequence[..., l:] # ...predicts token n + 1
-
-			batched = sequence.dim() > 1
-
-			# logit normalization
-			if self.logit_normalization:
-				# it would probably be better to unsqueeze then squeeze to avoid code duplication but who cares
-				if not batched:
-					logit = logit_normalization( logit, self.logit_normalization )
-				else:
-					for i, l in enumerate( logit ):
-						logit[i] = logit_normalization( l, self.logit_normalization )
-
-			# flatten batch
-			if batched:
-				logit = logit.reshape(-1, logit.shape[-1])
-				sequence = sequence.reshape(-1)
-
-			nll = None
-			acc_k_lo = None
-
-			if compute_hard_loss:
-				reduction = 'mean' if not batched else 'none'
-				weight = level_loss_factors[level] if level is not None and not batched else 1
-				loss_func = F.cross_entropy # to-do: add mse_loss
-				loss_kwargs = dict(ignore_index=self.ignore_index) if loss_func == F.cross_entropy else {}
-
-				nll = loss_func( logit, sequence, reduction=reduction, **loss_kwargs ) * weight
-				# manually weigh each level
-				if batched:
-					nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factors, device=device)
-
-			if compute_acc:
-				if logit.shape[0] >= k_lo:
-					accuracy_metric = MulticlassAccuracy(
-						logit.shape[-1],
-						top_k = 1,
-						average="micro",
-						multidim_average="global",
-						ignore_index = -100
-					).to(logit.device)
-					acc_k_lo = accuracy_metric( logit, sequence )
-
-				if logit.shape[0] >= k_hi:
-					accuracy_metric = MulticlassAccuracy(
-						logit.shape[-1],
-						top_k = 20,
-						average="micro",
-						multidim_average="global",
-						ignore_index = -100
-					).to(logit.device)
-					acc_k_hi = accuracy_metric( logit, sequence )
-
-			return nll, acc_k_lo, acc_k_hi
-
-		for batch_index, batch in enumerate(inputs):
-			quant_level = quant_levels[batch_index]
-			target = []
-			causal = True
-			task_type = "tts"
-			dropout_mask = None
-			classifier_level = None
-			output_len = 0
-
-			for name, input in batch:
-				if name == "task":
-					task_type = input
-				elif name == "dropout_mask":
-					dropout_mask = input
-				elif name == "classifier_level":
-					classifier_level = input
-
-			# autoregressive, causal
-			if classifier_level.startswith("AR:"):
-				causal = True
-			# nonautoregressive, parallel
-			elif classifier_level.startswith("NAR:"):
-				causal = False
-
-			it = 0
-			for name, input in batch:
-				token = None
-				ignored = False
-
-				# non-tokened tasks
-				if name in non_tokened_names:
-					continue
-				# prom can either be a tensor itself or a list of tensors and strings
-				if name == "prom":
-					# expand to list if not a list
-					proms = [ input ] if isinstance(input, torch.Tensor) else input
-					# iterate over the list to inject their tokens
-					token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
-
-					if logits[batch_index].dim() < 3 and token.dim() >= 2:
-						token = token[..., 0]
-				elif name == "resp":
-					token = input
-
-					# mask found, apply it
-					if dropout_mask is not None:
-						token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
-				# not a special input, inject as-is
-				else:
-					token = input
-
-				if not isinstance(token, torch.Tensor):
-					continue
-
-				if token.is_floating_point():
-					ignored = True
-
-				# grab range of our logits for later
-				seq_len = token.shape[0]
-				start, end = it, it+seq_len
-				it += seq_len + 1 # +1 to incorporate the separator
-
-				# deduce if a name for a task is an input or output
-				if name != task_outputs.get(task_type, name):
-					if self.ignore_inputs_for_loss:
-						ignored = True
-				else:
-					output_len = seq_len
-
-				if ignored:
-					# pruned
-					if self.config.loss_factors:
-						continue
-					# fill with ignored out tensor
-					token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
-
-				# perform loss calculation on the individual piece
-				if self.config.loss_factors:
-					loss_factor = self.loss_factor(name)
-
-					if loss_factor == 0.0:
-						continue
-
-					if logits[batch_index].dim() < 3:
-						nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][start:end], token.long(), causal )
-					elif not self.resp_parallel_training:
-						# cringe way to deduce "requested" level
-						level = quant_level
-						for i in range( self.n_resp_levels ):
-							if classifier_level.endswith(f':{i}:{i}'):
-								level = i
-								break
-
-						if name == "resp":
-							name = f'{name}[{level}]'
-
-						sequence = token if token.dim() <= 1 else token[:, level]
-						nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
-					else:
-						sequence = token.t()
-						nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
-
-						if nll is not None:
-							nll = nll.mean()
-
-					loss_key = f'{name}.nll'
-					acc_k_lo_key = f'{name}.acc[k={k_lo}]'
-					acc_k_hi_key = f'{name}.acc[k={k_hi}]'
-					if nll is not None:
-						if loss_key not in loss:
-							loss[loss_key] = []
-						loss[loss_key].append( nll * loss_factor )
-
-					if acc_k_lo is not None:
-						if acc_k_lo_key not in stats:
-							stats[acc_k_lo_key] = []
-						stats[acc_k_lo_key].append( acc_k_lo )
-
-					if acc_k_hi is not None:
-						if acc_k_hi_key not in stats:
-							stats[acc_k_hi_key] = []
-						stats[acc_k_hi_key].append( acc_k_hi )
-				# add to list
-				else:
-					target.append( token )
-
-
-			# perform loss calculation on the entire sequence
-			if not self.config.loss_factors:
-				if logits[batch_index].dim() < 3:
-					sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-					nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index], sequence, causal )
-				elif not self.resp_parallel_training:
-					# cringe way to deduce "requested" level
-					level = 0
-					for i in range( self.n_resp_levels ):
-						if classifier_level.endswith(f':{i}:{i}'):
-							level = i
-							break
-
-					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
-					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-					nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
-				else:
-					nlls = []
-					acc_k_los = []
-					acc_k_his = []
-
-					for level, logit in enumerate( logits[batch_index] ):
-						sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
-						sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-						nll, acc_k_lo, acc_k_hi = _calc_loss( logit, sequence, causal, level )
-
-						if nll:
-							nlls.append( nll )
-						if acc_k_lo:
-							acc_k_los.append( acc_k_lo )
-						if acc_k_hi:
-							acc_k_his.append( acc_k_hi )
-
-					if nlls:
-						nll = sum(nlls) / len(nlls)
-					if acc_k_los:
-						acc_k_lo = sum(acc_k_los) / len(acc_k_los)
-					if acc_k_his:
-						acc_k_hi = sum(acc_k_his) / len(acc_k_his)
-
-				if nll is not None:
-					if 'nll' not in loss:
-						loss['nll'] = []
-					loss["nll"].append( nll )
-
-				if acc_k_lo is not None:
-					if f'acc[k={k_lo}]' not in stats:
-						stats[f'acc[k={k_lo}]'] = []
-					stats[f"acc[k={k_lo}]"].append( acc_k_lo )
-
-				if acc_k_hi is not None:
-					if f'acc[k={k_hi}]' not in stats:
-						stats[f'acc[k={k_hi}]'] = []
-					stats[f"acc[k={k_hi}]"].append( acc_k_hi )
-
-		# average
-		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
-		stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
-
-		return LossStats(loss, stats)
-
-	# this is a specialized loss calculation that makes a lot of assumptions to try and streamline it by doing one loss calc instead of many
-	def calc_loss_specialized(
-		self,
-		inputs: list,
-		logits,
+		logits_aux = None,
 
 		quant_levels: list[int] | None = None,
 		compute_hard_loss = True,
 		compute_acc = True,
@@ -1204,9 +924,13 @@ class Base_V2(nn.Module):
 		k_lo, k_hi = 1, 20
 		level_loss_factors = self.audio_level_loss_factors
 
+		# this could be one array of tuples but can't be assed
 		loss_targets = []
 		loss_logits = []
-		loss_levels = []
+		loss_factors = []
+		loss_names = []
+
+		resp_durations = []
 
 		for batch_index, batch in enumerate(inputs):
 			quant_level = quant_levels[batch_index]
@@ -1214,7 +938,6 @@ class Base_V2(nn.Module):
 			task_type = "tts"
 			dropout_mask = None
 			classifier_level = None
-			output_len = 0
 
 			for name, input in batch:
 				if name == "task":
@@ -1273,19 +996,45 @@ class Base_V2(nn.Module):
 				if name != task_outputs.get(task_type, name):
 					continue
 
-				output_len = seq_len
-
-				for level in range( self.n_resp_levels ):
-					if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
+				if token.dim() == 1:
+					loss_factor = self.loss_factor(name)
+					if loss_factor == 0.0:
 						continue
 
-					logit = logits[batch_index][level][start:end]
+					logit = logits[batch_index][start:end]
+					if causal or self.predict_causally:
+						l = self.causal_size
+						logit = logit[..., :-l, :] # shift the target so that token n...
+						token = sequence[..., l:] # ...predicts token n + 1
+
 					if self.logit_normalization:
 						logit = logit_normalization( logit, self.logit_normalization )
 
-					loss_targets.append( token[:, level].long() )
+					loss_targets.append( token.long() )
 					loss_logits.append( logit )
-					loss_levels.append( level )
+					loss_factors.append( loss_factor )
+					loss_names.append( name )
+				else:
+					if name == "resp" and self.len_parallel_training:
+						resp_durations.append( token.shape[0] )
+
+					for level in range( self.n_resp_levels ):
+						if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
+							continue
+
+						logit = logits[batch_index][level][start:end]
+						if causal or self.predict_causally:
+							l = self.causal_size
+							logit = logit[..., :-l, :] # shift the target so that token n...
+							token = sequence[..., l:] # ...predicts token n + 1
+
+						if self.logit_normalization:
+							logit = logit_normalization( logit, self.logit_normalization )
+
+						loss_targets.append( token[:, level].long() )
+						loss_logits.append( logit )
+						loss_factors.append( level_loss_factors[level] )
+						loss_names.append( name )
 
 				break
@@ -1304,14 +1053,14 @@ class Base_V2(nn.Module):
 			it = 0
 			weights = 0
 			bsz = len( loss_targets )
-			for seq, level in zip( loss_targets, loss_levels ):
+			for seq, loss_factor in zip( loss_targets, loss_factors ):
 				seq_len = seq.shape[0]
 				start = it
 				it += seq_len
 				end = it
 
-				nll += nlls[start:end].mean() * level_loss_factors[level]
-				weights += level_loss_factors[level]
+				nll += nlls[start:end].mean() * loss_factor
+				weights += loss_factor
 
 			# normalize by batch
 			nll /= bsz
@@ -1341,6 +1090,7 @@ class Base_V2(nn.Module):
 				).to(loss_logit.device)
 				acc_k_hi = accuracy_metric( loss_logit, loss_target )
 
+		# to-do: re-add reporting split losses
 		if nll is not None:
 			if 'nll' not in loss:
 				loss['nll'] = []
@@ -1442,6 +1192,7 @@ class Base_V2(nn.Module):
 		if self.use_streamlined_calc_loss:
 			logits = self.audio_decoder( output.logits )
+			# to-do: get len logits
 		else:
 			logits = [ logit for logit in output.logits ]
 			grouped_logits = {}
@@ -1486,8 +1237,7 @@ class Base_V2(nn.Module):
 		# compute loss if the target is given
 		else:
-			loss_func = self.calc_loss_specialized if self.use_streamlined_calc_loss else self.calc_loss
-			loss, stats = loss_func( inputs=inputs, logits=logits, quant_levels=quant_levels )
+			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
 
 			# include any additional losses (for example: MoE router)
 			if output.loss is not None:
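
Note: a rough sketch (not from the patch) of the "one loss calc instead of many" idea the now-default streamlined calc_loss follows above: collect every per-segment target/logit pair while walking the batch, run a single unreduced cross entropy over the concatenation, then re-split and weight each segment afterwards. All shapes, segment lengths, and weight values below are illustrative only.

import torch
import torch.nn.functional as F

ignore_index = -100
vocab = 1024

# pretend these were gathered per segment while iterating the batch inputs
loss_logits  = [torch.randn(37, vocab), torch.randn(52, vocab)]
loss_targets = [torch.randint(0, vocab, (37,)), torch.randint(0, vocab, (52,))]
loss_factors = [1.0, 0.5]  # e.g. per-level loss weights

# one big cross entropy over the concatenated sequence, kept unreduced so each
# segment can be averaged and weighted separately afterwards
logits  = torch.cat(loss_logits)
targets = torch.cat(loss_targets)
nlls = F.cross_entropy(logits, targets, ignore_index=ignore_index, reduction="none")

nll, weights, it = 0.0, 0.0, 0
for target, factor in zip(loss_targets, loss_factors):
    start, it = it, it + target.shape[0]
    nll += nlls[start:it].mean() * factor
    weights += factor

nll = nll / len(loss_targets)  # normalize by the number of collected segments
print(float(nll))
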