another optimization (within the dataloader because the similar utterance sampler was mondo slow)

This commit is contained in:
mrq 2025-03-08 17:10:50 -06:00
parent 5e9d1a5302
commit 00d1fed217
3 changed files with 88 additions and 326 deletions

View File

@ -703,9 +703,14 @@ def _get_artifact_path(path):
return _replace_file_extension(path, _get_artifact_extension())
_durations_map = {}
_similar_map = {}
def _get_duration_map( type="training" ):
return _durations_map[type] if type in _durations_map else {}
def _get_similar_map( type="training" ):
return _similar_map[type] if type in _similar_map else {}
def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
assert cfg.dataset.min_duration >= 1.0, "Minimum duration too low."
@ -716,10 +721,14 @@ def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset
cached_durations_path = cached_dir / f"durations[{type}].json"
cached_paths_path = cached_dir / f"dataloader[{type}].json"
cached_similar_path = cached_dir / f"similar[{type}].json"
# load the duration table first, since this is independent from the loaded paths
if cached_durations_path.exists():
_durations_map[type] = json_read( cached_durations_path )
# load the similar paths table as well, since this is also independent
if cached_similar_path.exists():
_similar_map[type] = json_read( cached_similar_path )
# load the cached valid paths (if we're requesting cache use)
if cached_paths_path.exists() and cfg.dataset.cache:
@ -734,6 +743,7 @@ def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset
if not cached_dir.exists():
cached_dir.mkdir(parents=True, exist_ok=True)
json_write( _similar_map[type], cached_similar_path, truncate=True )
json_write( _durations_map[type], cached_durations_path, truncate=True )
json_write( paths, cached_paths_path, truncate=True )
@ -769,9 +779,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
metadata_keys = list(metadata.keys())
def _validate( id, entry ):
phones = entry['phones'] if "phones" in entry else 0
duration = entry['duration'] if "duration" in entry else 0
phones = entry.get('phones', 0)
duration = entry.get('duration', 0)
similar = entry.get('similar', None)
k = key(id, entry)
@ -780,6 +792,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
_durations_map[type] = {}
_durations_map[type][k] = duration
# add to similar bucket
if type not in _similar_map:
_similar_map[type] = {}
_similar_map[type][k] = [ metadata_keys[idx] for idx in similar ] if similar else None
if not validate:
return True
@ -1188,43 +1205,28 @@ class Dataset(_Dataset):
if offset is None:
offset = cfg.dataset.prompt_similar_top_k_offset
root = Path( *path.parts[:-1] )
reference = path.name
similars = _similar_map[self.dataset_type].get(str(path), None)
if cfg.dataset.use_hdf5:
root = Path( *path.parts[:-1] )
path = Path( *path.parts[2:-1] )
else:
root = Path( *path.parts[:-1] )
path = Path(*path.parts[len(cfg.data_dir.parts):-1])
metadata = json_read( cfg.metadata_dir / path.with_suffix(".json"), default={} )
if reference not in metadata:
if not similars:
return None
reference_metadata = metadata[reference]
if "similar" not in reference_metadata:
return None
if len(reference_metadata["similar"]) >= offset:
if len(similars) >= offset:
offset = 0
# cringe stopgap
offset_end = offset + cfg.dataset.prompt_similar_top_k
if offset >= len( reference_metadata["similar"] ):
return None
if offset_end >= len( reference_metadata["similar"] ):
return None
metadata_keys = list(metadata.keys())
if offset >= len( similars ):
return None
if offset_end >= len( similars ):
return None
if cfg.dataset.prompt_similar_top_k > 1:
indices = reference_metadata["similar"][offset:offset_end]
index = random.choice( indices )
name = random.choice( similars[offset:offset_end] )
else:
index = reference_metadata["similar"][offset]
name = metadata_keys[index]
name = similars[offset]
path = root / name

View File

@ -409,6 +409,16 @@ class Attention(nn.Module):
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
elif mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=None, # ROCm FA2 through SDPA doesn't allow masks, bummer
dropout_p=dropout_rate,
is_causal=is_causal,
)
else:
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.

View File

@ -329,13 +329,13 @@ class Base_V2(nn.Module):
ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False
resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
len_parallel_training = False # config.experimental.len_parallel_training if config is not None else True
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
logit_normalization = config.experimental.logit_normalization if config is not None else 0
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
use_streamlined_calc_loss = config.experimental.use_streamlined_calc_loss if config is not None else True
n_vocab = 256
n_tasks = config.tasks if config is not None else 8
@ -395,6 +395,7 @@ class Base_V2(nn.Module):
self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
self.use_streamlined_calc_loss = True
self.stop_token = self.n_audio_tokens
self.mask_token = self.stop_token + 1
@ -414,9 +415,9 @@ class Base_V2(nn.Module):
self.teaching = True
self.training = False
self.resp_parallel_training = resp_parallel_training
self.predict_causally = predict_causally
self.resp_parallel_training = resp_parallel_training
self.len_parallel_training = len_parallel_training
self.unified_position_ids = unified_position_ids
self.inject_timestep_embedding = False # results in bad output
self.masking_ratio = masking_ratio
@ -425,7 +426,6 @@ class Base_V2(nn.Module):
self.audio_level_loss_factors = audio_level_loss_factors
self.logit_normalization = logit_normalization
self.use_segmented_attention_mask = use_segmented_attention_mask
self.use_streamlined_calc_loss = use_streamlined_calc_loss
self.sep = nn.Parameter(torch.randn(d_model))
@ -901,287 +901,7 @@ class Base_V2(nn.Module):
self,
inputs: list,
logits,
quant_levels: list[int] | None = None,
compute_hard_loss = True,
compute_acc = True,
):
loss = {}
stats = {}
device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
if isinstance(input, str):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
return input
k_lo, k_hi = 1, 20
def _calc_loss( logit, sequence, causal = True, level = None ):
level_loss_factors = self.audio_level_loss_factors
# filter tokens that exceed the vocab size
sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
# drop if all tokens are ignored
if torch.all(sequence == self.ignore_index):
return None, None
# shift if causal
if causal or self.predict_causally:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
sequence = sequence[..., l:] # ...predicts token n + 1
batched = sequence.dim() > 1
# logit normalization
if self.logit_normalization:
# it would probably be better to unsqueeze then squeeze to avoid code duplication but who cares
if not batched:
logit = logit_normalization( logit, self.logit_normalization )
else:
for i, l in enumerate( logit ):
logit[i] = logit_normalization( l, self.logit_normalization )
# flatten batch
if batched:
logit = logit.reshape(-1, logit.shape[-1])
sequence = sequence.reshape(-1)
nll = None
acc_k_lo = None
if compute_hard_loss:
reduction = 'mean' if not batched else 'none'
weight = level_loss_factors[level] if level is not None and not batched else 1
loss_func = F.cross_entropy # to-do: add mse_loss
loss_kwargs = dict(ignore_index=self.ignore_index) if loss_func == F.cross_entropy else {}
nll = loss_func( logit, sequence, reduction=reduction, **loss_kwargs ) * weight
# manually weigh each level
if batched:
nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factors, device=device)
if compute_acc:
if logit.shape[0] >= k_lo:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 1,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
acc_k_lo = accuracy_metric( logit, sequence )
if logit.shape[0] >= k_hi:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 20,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
acc_k_hi = accuracy_metric( logit, sequence )
return nll, acc_k_lo, acc_k_hi
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
target = []
causal = True
task_type = "tts"
dropout_mask = None
classifier_level = None
output_len = 0
for name, input in batch:
if name == "task":
task_type = input
elif name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input
# autoregressive, causal
if classifier_level.startswith("AR:"):
causal = True
# nonautoregressive, parallel
elif classifier_level.startswith("NAR:"):
causal = False
it = 0
for name, input in batch:
token = None
ignored = False
# non-tokened tasks
if name in non_tokened_names:
continue
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
if logits[batch_index].dim() < 3 and token.dim() >= 2:
token = token[..., 0]
elif name == "resp":
token = input
# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
# not a special input, inject as-is
else:
token = input
if not isinstance(token, torch.Tensor):
continue
if token.is_floating_point():
ignored = True
# grab range of our logits for later
seq_len = token.shape[0]
start, end = it, it+seq_len
it += seq_len + 1 # +1 to incorporate the separator
# deduce if a name for a task is an input or output
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
else:
output_len = seq_len
if ignored:
# pruned
if self.config.loss_factors:
continue
# fill with ignored out tensor
token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
# perform loss calculation on the individual piece
if self.config.loss_factors:
loss_factor = self.loss_factor(name)
if loss_factor == 0.0:
continue
if logits[batch_index].dim() < 3:
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][start:end], token.long(), causal )
elif not self.resp_parallel_training:
# cringe way to deduce "requested" level
level = quant_level
for i in range( self.n_resp_levels ):
if classifier_level.endswith(f':{i}:{i}'):
level = i
break
if name == "resp":
name = f'{name}[{level}]'
sequence = token if token.dim() <= 1 else token[:, level]
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
else:
sequence = token.t()
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
if nll is not None:
nll = nll.mean()
loss_key = f'{name}.nll'
acc_k_lo_key = f'{name}.acc[k={k_lo}]'
acc_k_hi_key = f'{name}.acc[k={k_hi}]'
if nll is not None:
if loss_key not in loss:
loss[loss_key] = []
loss[loss_key].append( nll * loss_factor )
if acc_k_lo is not None:
if acc_k_lo_key not in stats:
stats[acc_k_lo_key] = []
stats[acc_k_lo_key].append( acc_k_lo )
if acc_k_hi is not None:
if acc_k_hi_key not in stats:
stats[acc_k_hi_key] = []
stats[acc_k_hi_key].append( acc_k_hi )
# add to list
else:
target.append( token )
# perform loss calculation on the entire sequence
if not self.config.loss_factors:
if logits[batch_index].dim() < 3:
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index], sequence, causal )
elif not self.resp_parallel_training:
# cringe way to deduce "requested" level
level = 0
for i in range( self.n_resp_levels ):
if classifier_level.endswith(f':{i}:{i}'):
level = i
break
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
else:
nlls = []
acc_k_los = []
acc_k_his = []
for level, logit in enumerate( logits[batch_index] ):
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, acc_k_lo, acc_k_hi = _calc_loss( logit, sequence, causal, level )
if nll:
nlls.append( nll )
if acc_k_lo:
acc_k_los.append( acc_k_lo )
if acc_k_hi:
acc_k_his.append( acc_k_hi )
if nlls:
nll = sum(nlls) / len(nlls)
if acc_k_los:
acc_k_lo = sum(acc_k_los) / len(acc_k_los)
if acc_k_his:
acc_k_hi = sum(acc_k_his) / len(acc_k_his)
if nll is not None:
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if acc_k_lo is not None:
if f'acc[k={k_lo}]' not in stats:
stats[f'acc[k={k_lo}]'] = []
stats[f"acc[k={k_lo}]"].append( acc_k_lo )
if acc_k_hi is not None:
if f'acc[k={k_hi}]' not in stats:
stats[f'acc[k={k_hi}]'] = []
stats[f"acc[k={k_hi}]"].append( acc_k_hi )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
return LossStats(loss, stats)
# this is a specialized loss calculation that makes a lot of assumptions to try and streamline it by doing one loss calc instead of many
def calc_loss_specialized(
self,
inputs: list,
logits,
logits_aux = None,
quant_levels: list[int] | None = None,
compute_hard_loss = True,
@ -1204,9 +924,13 @@ class Base_V2(nn.Module):
k_lo, k_hi = 1, 20
level_loss_factors = self.audio_level_loss_factors
# this could be one array of tuples but can't be assed
loss_targets = []
loss_logits = []
loss_levels = []
loss_factors = []
loss_names = []
resp_durations = []
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
@ -1214,7 +938,6 @@ class Base_V2(nn.Module):
task_type = "tts"
dropout_mask = None
classifier_level = None
output_len = 0
for name, input in batch:
if name == "task":
@ -1273,19 +996,45 @@ class Base_V2(nn.Module):
if name != task_outputs.get(task_type, name):
continue
output_len = seq_len
for level in range( self.n_resp_levels ):
if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
if token.dim() == 1:
loss_factor = self.loss_factor(name)
if loss_factor == 0.0:
continue
logit = logits[batch_index][level][start:end]
logit = logits[batch_index][start:end]
if causal or self.predict_causally:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
token = sequence[..., l:] # ...predicts token n + 1
if self.logit_normalization:
logit = logit_normalization( logit, self.logit_normalization )
loss_targets.append( token[:, level].long() )
loss_targets.append( token.long() )
loss_logits.append( logit )
loss_levels.append( level )
loss_factors.append( loss_factor )
loss_names.append( name )
else:
if name == "resp" and self.len_parallel_training:
resp_durations.append( token.shape[0] )
for level in range( self.n_resp_levels ):
if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
continue
logit = logits[batch_index][level][start:end]
if causal or self.predict_causally:
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
token = sequence[..., l:] # ...predicts token n + 1
if self.logit_normalization:
logit = logit_normalization( logit, self.logit_normalization )
loss_targets.append( token[:, level].long() )
loss_logits.append( logit )
loss_factors.append( level_loss_factors[level] )
loss_names.append( name )
break
@ -1304,14 +1053,14 @@ class Base_V2(nn.Module):
it = 0
weights = 0
bsz = len( loss_targets )
for seq, level in zip( loss_targets, loss_levels ):
for seq, loss_factor in zip( loss_targets, loss_factors ):
seq_len = seq.shape[0]
start = it
it += seq_len
end = it
nll += nlls[start:end].mean() * level_loss_factors[level]
weights += level_loss_factors[level]
nll += nlls[start:end].mean() * loss_factor
weights += loss_factor
# normalize by batch
nll /= bsz
@ -1341,6 +1090,7 @@ class Base_V2(nn.Module):
).to(loss_logit.device)
acc_k_hi = accuracy_metric( loss_logit, loss_target )
# to-do: re-add reporting split losses
if nll is not None:
if 'nll' not in loss:
loss['nll'] = []
@ -1442,6 +1192,7 @@ class Base_V2(nn.Module):
if self.use_streamlined_calc_loss:
logits = self.audio_decoder( output.logits )
# to-do: get len logits
else:
logits = [ logit for logit in output.logits ]
grouped_logits = {}
@ -1486,8 +1237,7 @@ class Base_V2(nn.Module):
# compute loss if the target is given
else:
loss_func = self.calc_loss_specialized if self.use_streamlined_calc_loss else self.calc_loss
loss, stats = loss_func( inputs=inputs, logits=logits, quant_levels=quant_levels )
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
# include any additional losses (for example: MoE router)
if output.loss is not None: