another optimization (within the dataloader because the similar utterance sampler was mondo slow)
This commit is contained in:
parent
5e9d1a5302
commit
00d1fed217
@@ -703,9 +703,14 @@ def _get_artifact_path(path):
 	return _replace_file_extension(path, _get_artifact_extension())
 
 _durations_map = {}
+_similar_map = {}
 
 def _get_duration_map( type="training" ):
 	return _durations_map[type] if type in _durations_map else {}
 
+def _get_similar_map( type="training" ):
+	return _similar_map[type] if type in _similar_map else {}
+
 def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
 	assert cfg.dataset.min_duration >= 1.0, "Minimum duration too low."
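The new `_similar_map` deliberately mirrors the existing `_durations_map`: one module-level dict per dataset split, filled once at load time and then shared by every `Dataset` instance. A minimal sketch of that pattern, with hypothetical map contents (the entries below are invented, not from a real dataset):

# Sketch of the module-level cache pattern used above.
_similar_map = {}

def _get_similar_map( type="training" ):
	# fall back to {} instead of raising KeyError for an unloaded split
	return _similar_map[type] if type in _similar_map else {}

_similar_map["training"] = {
	"speaker_a/utt_0001": ["speaker_a/utt_0007", "speaker_a/utt_0003"],
	"speaker_a/utt_0002": None, # no precomputed neighbors for this utterance
}

assert _get_similar_map("training")["speaker_a/utt_0001"][0] == "speaker_a/utt_0007"
assert _get_similar_map("validation") == {}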
@@ -716,10 +721,14 @@ def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
 	cached_durations_path = cached_dir / f"durations[{type}].json"
 	cached_paths_path = cached_dir / f"dataloader[{type}].json"
+	cached_similar_path = cached_dir / f"similar[{type}].json"
 
 	# load the duration table first, since this is independent from the loaded paths
 	if cached_durations_path.exists():
 		_durations_map[type] = json_read( cached_durations_path )
+	# load the similar paths table as well, since this is also independent
+	if cached_similar_path.exists():
+		_similar_map[type] = json_read( cached_similar_path )
 
 	# load the cached valid paths (if we're requesting cache use)
 	if cached_paths_path.exists() and cfg.dataset.cache:
@@ -734,6 +743,7 @@ def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
 	if not cached_dir.exists():
 		cached_dir.mkdir(parents=True, exist_ok=True)
 
+	json_write( _similar_map[type], cached_similar_path, truncate=True )
 	json_write( _durations_map[type], cached_durations_path, truncate=True )
 	json_write( paths, cached_paths_path, truncate=True )
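With the similar-utterance table serialized next to the duration table, the neighbor mapping only has to be built once per dataset; later runs just deserialize it. A stand-in for the round trip using plain `json` (the repo's `json_read`/`json_write` helpers are assumed to behave roughly like this, modulo extras such as `truncate=`; the cache location below is hypothetical):

import json
from pathlib import Path

def json_write_stand_in(obj, path):
	with open(path, "w", encoding="utf-8") as f:
		json.dump(obj, f)

def json_read_stand_in(path):
	with open(path, "r", encoding="utf-8") as f:
		return json.load(f)

cached_dir = Path("./.cache/dataset")  # hypothetical cache location
cached_dir.mkdir(parents=True, exist_ok=True)
cached_similar_path = cached_dir / "similar[training].json"

json_write_stand_in({"utt_0001": ["utt_0007", "utt_0003"]}, cached_similar_path)

# a later run skips the expensive rebuild entirely
if cached_similar_path.exists():
	similar_map = json_read_stand_in(cached_similar_path)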
@@ -769,9 +779,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 		return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
 
+	metadata_keys = list(metadata.keys())
 	def _validate( id, entry ):
-		phones = entry['phones'] if "phones" in entry else 0
-		duration = entry['duration'] if "duration" in entry else 0
+		phones = entry.get('phones', 0)
+		duration = entry.get('duration', 0)
+		similar = entry.get('similar', None)
 
 		k = key(id, entry)
@@ -780,6 +792,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 			_durations_map[type] = {}
 		_durations_map[type][k] = duration
 
+		# add to similar bucket
+		if type not in _similar_map:
+			_similar_map[type] = {}
+		_similar_map[type][k] = [ metadata_keys[idx] for idx in similar ] if similar else None
+
 		if not validate:
 			return True
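The line worth pausing on is `_similar_map[type][k] = [ metadata_keys[idx] for idx in similar ]`: each entry's `similar` field stores indices into the metadata dict's key order, and resolving those indices to actual utterance names once at load time is what lets the sampler avoid re-reading the metadata JSON per sample. A toy illustration with made-up metadata:

# Hypothetical per-speaker metadata; "similar" holds indices into the
# insertion order of the dict, matching the diff above.
metadata = {
	"utt_0001": {"duration": 2.1, "similar": [2, 1]},
	"utt_0002": {"duration": 3.4, "similar": [0]},
	"utt_0003": {"duration": 1.7, "similar": None},
}
metadata_keys = list(metadata.keys())

similar_map = {
	k: [ metadata_keys[idx] for idx in entry["similar"] ] if entry["similar"] else None
	for k, entry in metadata.items()
}

assert similar_map["utt_0001"] == ["utt_0003", "utt_0002"]
assert similar_map["utt_0003"] is None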
@@ -1188,43 +1205,28 @@ class Dataset(_Dataset):
 		if offset is None:
 			offset = cfg.dataset.prompt_similar_top_k_offset
 
+		root = Path( *path.parts[:-1] )
 		reference = path.name
+		similars = _similar_map[self.dataset_type].get(str(path), None)
 
-		if cfg.dataset.use_hdf5:
-			root = Path( *path.parts[:-1] )
-			path = Path( *path.parts[2:-1] )
-		else:
-			root = Path( *path.parts[:-1] )
-			path = Path(*path.parts[len(cfg.data_dir.parts):-1])
-
-		metadata = json_read( cfg.metadata_dir / path.with_suffix(".json"), default={} )
-
-		if reference not in metadata:
+		if not similars:
 			return None
 
-		reference_metadata = metadata[reference]
-
-		if "similar" not in reference_metadata:
-			return None
-
-		if len(reference_metadata["similar"]) >= offset:
+		if len(similars) >= offset:
 			offset = 0
 
 		# cringe stopgap
 		offset_end = offset + cfg.dataset.prompt_similar_top_k
-		if offset >= len( reference_metadata["similar"] ):
-			return None
-		if offset_end >= len( reference_metadata["similar"] ):
-			return None
-
-		metadata_keys = list(metadata.keys())
+		if offset >= len( similars ):
+			return None
+		if offset_end >= len( similars ):
+			return None
 
 		if cfg.dataset.prompt_similar_top_k > 1:
-			indices = reference_metadata["similar"][offset:offset_end]
-			index = random.choice( indices )
+			name = random.choice( similars[offset:offset_end] )
 		else:
-			index = reference_metadata["similar"][offset]
-			name = metadata_keys[index]
+			name = similars[offset]
 
 		path = root / name
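Net effect: the old path performed a `json_read` of the speaker's metadata file on every prompt sampled, just to turn similarity indices back into filenames; the new path is a single dict lookup against the table loaded once at startup. A rough sketch of the fast path (the function name and map contents are hypothetical):

import random

def sample_similar(path, similar_map, top_k=2, offset=0):
	# one O(1) dict lookup per call, no file I/O
	similars = similar_map.get(str(path), None)
	if not similars:
		return None
	if offset + top_k > len(similars):
		offset = 0
	if top_k > 1:
		return random.choice(similars[offset:offset + top_k])
	return similars[offset]

similar_map = {"data/a/utt_0001": ["utt_0007", "utt_0003", "utt_0009"]}
print(sample_similar("data/a/utt_0001", similar_map))  # utt_0007 or utt_0003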
@@ -409,6 +409,16 @@ class Attention(nn.Module):
 					f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
 					f" {attn_output.size()}"
 				)
+		elif mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
+			with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+				attn_output = torch.nn.functional.scaled_dot_product_attention(
+					query_states,
+					key_states,
+					value_states,
+					attn_mask=None, # ROCm FA2 through SDPA doesn't allow masks, bummer
+					dropout_p=dropout_rate,
+					is_causal=is_causal,
+				)
 		else:
 			# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
 			# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
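For context: `torch.nn.attention.sdpa_kernel` (available in recent PyTorch, roughly 2.3 onward) restricts which fused backend `scaled_dot_product_attention` may dispatch to inside the `with` block. Pinning `SDPBackend.FLASH_ATTENTION` forces the flash kernel or an error rather than a silent fallback, and since that path rejects explicit masks here, the branch passes `attn_mask=None` and relies on `is_causal` for masking. A self-contained sketch, assuming a CUDA/ROCm GPU with a flash-capable kernel:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# half precision on GPU, the regime the flash kernel supports
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# only the flash backend is permitted inside this block; SDPA raises
# instead of silently dispatching elsewhere if it can't be used
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
	out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)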
@@ -329,13 +329,13 @@ class Base_V2(nn.Module):
 		ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False
 
 		resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
+		len_parallel_training = False # config.experimental.len_parallel_training if config is not None else True
 		predict_causally = config.experimental.predict_causally if config is not None else False
 		monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
 		audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
-		use_streamlined_calc_loss = config.experimental.use_streamlined_calc_loss if config is not None else True
 
 		n_vocab = 256
 		n_tasks = config.tasks if config is not None else 8
@@ -395,6 +395,7 @@ class Base_V2(nn.Module):
 		self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
 		self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
 		self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
+		self.use_streamlined_calc_loss = True
 
 		self.stop_token = self.n_audio_tokens
 		self.mask_token = self.stop_token + 1
@@ -414,9 +415,9 @@ class Base_V2(nn.Module):
 			self.teaching = True
 			self.training = False
 
-		self.resp_parallel_training = resp_parallel_training
-		self.predict_causally = predict_causally
-
+		self.resp_parallel_training = resp_parallel_training
+		self.len_parallel_training = len_parallel_training
+		self.predict_causally = predict_causally
 		self.unified_position_ids = unified_position_ids
 		self.inject_timestep_embedding = False # results in bad output
 		self.masking_ratio = masking_ratio
@@ -425,7 +426,6 @@ class Base_V2(nn.Module):
 		self.audio_level_loss_factors = audio_level_loss_factors
 		self.logit_normalization = logit_normalization
 		self.use_segmented_attention_mask = use_segmented_attention_mask
-		self.use_streamlined_calc_loss = use_streamlined_calc_loss
 
 		self.sep = nn.Parameter(torch.randn(d_model))
@@ -901,287 +901,7 @@ class Base_V2(nn.Module):
 		self,
 		inputs: list,
 		logits,
-
-		quant_levels: list[int] | None = None,
-		compute_hard_loss = True,
-		compute_acc = True,
-	):
-		loss = {}
-		stats = {}
-
-		device = logits[0].device
-		batch_size = len(logits)
-		classifier_levels = self.get_input( inputs, "classifier_level" )
-
-		# handles tasks where the prompt has task tokens injected in the middle
-		def prompt_input_to_token( input, quant_level ):
-			if isinstance(input, str):
-				return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
-
-			return input
-
-		k_lo, k_hi = 1, 20
-		def _calc_loss( logit, sequence, causal = True, level = None ):
-			level_loss_factors = self.audio_level_loss_factors
-
-			# filter tokens that exceed the vocab size
-			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
-			# drop if all tokens are ignored
-			if torch.all(sequence == self.ignore_index):
-				return None, None
-
-			# shift if causal
-			if causal or self.predict_causally:
-				l = self.causal_size
-				logit = logit[..., :-l, :] # shift the target so that token n...
-				sequence = sequence[..., l:] # ...predicts token n + 1
-
-			batched = sequence.dim() > 1
-
-			# logit normalization
-			if self.logit_normalization:
-				# it would probably be better to unsqueeze then squeeze to avoid code duplication but who cares
-				if not batched:
-					logit = logit_normalization( logit, self.logit_normalization )
-				else:
-					for i, l in enumerate( logit ):
-						logit[i] = logit_normalization( l, self.logit_normalization )
-
-			# flatten batch
-			if batched:
-				logit = logit.reshape(-1, logit.shape[-1])
-				sequence = sequence.reshape(-1)
-
-			nll = None
-			acc_k_lo = None
-
-			if compute_hard_loss:
-				reduction = 'mean' if not batched else 'none'
-				weight = level_loss_factors[level] if level is not None and not batched else 1
-				loss_func = F.cross_entropy # to-do: add mse_loss
-				loss_kwargs = dict(ignore_index=self.ignore_index) if loss_func == F.cross_entropy else {}
-
-				nll = loss_func( logit, sequence, reduction=reduction, **loss_kwargs ) * weight
-				# manually weigh each level
-				if batched:
-					nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factors, device=device)
-
-			if compute_acc:
-				if logit.shape[0] >= k_lo:
-					accuracy_metric = MulticlassAccuracy(
-						logit.shape[-1],
-						top_k = 1,
-						average="micro",
-						multidim_average="global",
-						ignore_index = -100
-					).to(logit.device)
-					acc_k_lo = accuracy_metric( logit, sequence )
-
-				if logit.shape[0] >= k_hi:
-					accuracy_metric = MulticlassAccuracy(
-						logit.shape[-1],
-						top_k = 20,
-						average="micro",
-						multidim_average="global",
-						ignore_index = -100
-					).to(logit.device)
-					acc_k_hi = accuracy_metric( logit, sequence )
-
-			return nll, acc_k_lo, acc_k_hi
-
-		for batch_index, batch in enumerate(inputs):
-			quant_level = quant_levels[batch_index]
-			target = []
-			causal = True
-			task_type = "tts"
-			dropout_mask = None
-			classifier_level = None
-			output_len = 0
-
-			for name, input in batch:
-				if name == "task":
-					task_type = input
-				elif name == "dropout_mask":
-					dropout_mask = input
-				elif name == "classifier_level":
-					classifier_level = input
-
-			# autoregressive, causal
-			if classifier_level.startswith("AR:"):
-				causal = True
-			# nonautoregressive, parallel
-			elif classifier_level.startswith("NAR:"):
-				causal = False
-
-			it = 0
-			for name, input in batch:
-				token = None
-				ignored = False
-
-				# non-tokened tasks
-				if name in non_tokened_names:
-					continue
-				# prom can either be a tensor itself or a list of tensors and strings
-				if name == "prom":
-					# expand to list if not a list
-					proms = [ input ] if isinstance(input, torch.Tensor) else input
-					# iterate over the list to inject their tokens
-					token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
-
-					if logits[batch_index].dim() < 3 and token.dim() >= 2:
-						token = token[..., 0]
-				elif name == "resp":
-					token = input
-
-					# mask found, apply it
-					if dropout_mask is not None:
-						token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
-				# not a special input, inject as-is
-				else:
-					token = input
-
-				if not isinstance(token, torch.Tensor):
-					continue
-
-				if token.is_floating_point():
-					ignored = True
-
-				# grab range of our logits for later
-				seq_len = token.shape[0]
-				start, end = it, it+seq_len
-				it += seq_len + 1 # +1 to incorporate the separator
-
-				# deduce if a name for a task is an input or output
-				if name != task_outputs.get(task_type, name):
-					if self.ignore_inputs_for_loss:
-						ignored = True
-				else:
-					output_len = seq_len
-
-				if ignored:
-					# pruned
-					if self.config.loss_factors:
-						continue
-					# fill with ignored out tensor
-					token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
-
-				# perform loss calculation on the individual piece
-				if self.config.loss_factors:
-					loss_factor = self.loss_factor(name)
-
-					if loss_factor == 0.0:
-						continue
-
-					if logits[batch_index].dim() < 3:
-						nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][start:end], token.long(), causal )
-					elif not self.resp_parallel_training:
-						# cringe way to deduce "requested" level
-						level = quant_level
-						for i in range( self.n_resp_levels ):
-							if classifier_level.endswith(f':{i}:{i}'):
-								level = i
-								break
-
-						if name == "resp":
-							name = f'{name}[{level}]'
-
-						sequence = token if token.dim() <= 1 else token[:, level]
-						nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
-					else:
-						sequence = token.t()
-						nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
-
-					if nll is not None:
-						nll = nll.mean()
-
-					loss_key = f'{name}.nll'
-					acc_k_lo_key = f'{name}.acc[k={k_lo}]'
-					acc_k_hi_key = f'{name}.acc[k={k_hi}]'
-					if nll is not None:
-						if loss_key not in loss:
-							loss[loss_key] = []
-						loss[loss_key].append( nll * loss_factor )
-
-					if acc_k_lo is not None:
-						if acc_k_lo_key not in stats:
-							stats[acc_k_lo_key] = []
-						stats[acc_k_lo_key].append( acc_k_lo )
-
-					if acc_k_hi is not None:
-						if acc_k_hi_key not in stats:
-							stats[acc_k_hi_key] = []
-						stats[acc_k_hi_key].append( acc_k_hi )
-				# add to list
-				else:
-					target.append( token )
-
-			# perform loss calculation on the entire sequence
-			if not self.config.loss_factors:
-				if logits[batch_index].dim() < 3:
-					sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-					nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index], sequence, causal )
-				elif not self.resp_parallel_training:
-					# cringe way to deduce "requested" level
-					level = 0
-					for i in range( self.n_resp_levels ):
-						if classifier_level.endswith(f':{i}:{i}'):
-							level = i
-							break
-
-					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
-					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-					nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
-				else:
-					nlls = []
-					acc_k_los = []
-					acc_k_his = []
-
-					for level, logit in enumerate( logits[batch_index] ):
-						sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
-						sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-						nll, acc_k_lo, acc_k_hi = _calc_loss( logit, sequence, causal, level )
-
-						if nll:
-							nlls.append( nll )
-						if acc_k_lo:
-							acc_k_los.append( acc_k_lo )
-						if acc_k_hi:
-							acc_k_his.append( acc_k_hi )
-
-					if nlls:
-						nll = sum(nlls) / len(nlls)
-					if acc_k_los:
-						acc_k_lo = sum(acc_k_los) / len(acc_k_los)
-					if acc_k_his:
-						acc_k_hi = sum(acc_k_his) / len(acc_k_his)
-
-				if nll is not None:
-					if 'nll' not in loss:
-						loss['nll'] = []
-					loss["nll"].append( nll )
-
-				if acc_k_lo is not None:
-					if f'acc[k={k_lo}]' not in stats:
-						stats[f'acc[k={k_lo}]'] = []
-					stats[f"acc[k={k_lo}]"].append( acc_k_lo )
-
-				if acc_k_hi is not None:
-					if f'acc[k={k_hi}]' not in stats:
-						stats[f'acc[k={k_hi}]'] = []
-					stats[f"acc[k={k_hi}]"].append( acc_k_hi )
-
-		# average
-		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
-		stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
-
-		return LossStats(loss, stats)
-
-	# this is a specialized loss calculation that makes a lot of assumptions to try and streamline it by doing one loss calc instead of many
-	def calc_loss_specialized(
-		self,
-		inputs: list,
-		logits,
 		logits_aux = None,
 
 		quant_levels: list[int] | None = None,
 		compute_hard_loss = True,
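The deletion above is the bulk of the commit on the model side: the generic `calc_loss`, which issued a separate `F.cross_entropy` (and `MulticlassAccuracy`) call per input segment and per codebook level, is dropped in favor of the streamlined implementation that concatenates every target slice and computes the loss in one call. When the pieces are equally weighted the math is unchanged; the win is fewer kernel launches. A minimal sketch of the equivalence, with random data and standalone names:

import torch
import torch.nn.functional as F

vocab, segments, seq = 256, 8, 64
logits = [torch.randn(seq, vocab) for _ in range(segments)]
targets = [torch.randint(0, vocab, (seq,)) for _ in range(segments)]

# many small calls, as the deleted calc_loss did (simplified)
loss_many = sum(F.cross_entropy(l, t) for l, t in zip(logits, targets)) / segments

# one flattened call, as the streamlined path does (simplified; equal
# segment lengths and uniform weights make the two means identical)
loss_one = F.cross_entropy(torch.cat(logits), torch.cat(targets))

assert torch.allclose(loss_many, loss_one, atol=1e-5)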
@@ -1204,9 +924,13 @@ class Base_V2(nn.Module):
 		k_lo, k_hi = 1, 20
 		level_loss_factors = self.audio_level_loss_factors
 
 		# this could be one array of tuples but can't be assed
 		loss_targets = []
 		loss_logits = []
-		loss_levels = []
+		loss_factors = []
+		loss_names = []
+
+		resp_durations = []
+
 		for batch_index, batch in enumerate(inputs):
 			quant_level = quant_levels[batch_index]
@@ -1214,7 +938,6 @@ class Base_V2(nn.Module):
 			task_type = "tts"
 			dropout_mask = None
 			classifier_level = None
-			output_len = 0
 
 			for name, input in batch:
 				if name == "task":
@@ -1273,19 +996,45 @@ class Base_V2(nn.Module):
 				if name != task_outputs.get(task_type, name):
 					continue
 
 				output_len = seq_len
 
-				for level in range( self.n_resp_levels ):
-					if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
+				if token.dim() == 1:
+					loss_factor = self.loss_factor(name)
+					if loss_factor == 0.0:
 						continue
 
-					logit = logits[batch_index][level][start:end]
+					logit = logits[batch_index][start:end]
 					if causal or self.predict_causally:
 						l = self.causal_size
 						logit = logit[..., :-l, :] # shift the target so that token n...
 						token = sequence[..., l:] # ...predicts token n + 1
 
 					if self.logit_normalization:
 						logit = logit_normalization( logit, self.logit_normalization )
 
-					loss_targets.append( token[:, level].long() )
+					loss_targets.append( token.long() )
 					loss_logits.append( logit )
-					loss_levels.append( level )
+					loss_factors.append( loss_factor )
+					loss_names.append( name )
+				else:
+					if name == "resp" and self.len_parallel_training:
+						resp_durations.append( token.shape[0] )
+
+					for level in range( self.n_resp_levels ):
+						if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
+							continue
+
+						logit = logits[batch_index][level][start:end]
+						if causal or self.predict_causally:
+							l = self.causal_size
+							logit = logit[..., :-l, :] # shift the target so that token n...
+							token = sequence[..., l:] # ...predicts token n + 1
+
+						if self.logit_normalization:
+							logit = logit_normalization( logit, self.logit_normalization )
+
+						loss_targets.append( token[:, level].long() )
+						loss_logits.append( logit )
+						loss_factors.append( level_loss_factors[level] )
+						loss_names.append( name )
 
 				break
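The rewritten collection loop branches on `token.dim()`: a 1-D target (a non-audio stream) contributes a single entry weighted by its per-name `loss_factor`, while a 2-D audio target fans out into one entry per codebook level weighted by `level_loss_factors[level]`. A toy version of that dispatch, with all names and shapes invented:

import torch

level_loss_factors = [1.0, 0.8, 0.6, 0.5]
loss_targets, loss_factors = [], []

def collect(token, name_factor=1.0):
	if token.dim() == 1:
		# scalar stream: one target, weighted by the task/name factor
		loss_targets.append(token.long())
		loss_factors.append(name_factor)
	else:
		# audio codes: one target per codebook level
		for level in range(token.shape[-1]):
			loss_targets.append(token[:, level].long())
			loss_factors.append(level_loss_factors[level])

collect(torch.randint(0, 256, (10,)))     # e.g. a text/len stream
collect(torch.randint(0, 1024, (50, 4)))  # audio with 4 codebook levels
assert len(loss_targets) == 1 + 4 and loss_factors[1] == 1.0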
@@ -1304,14 +1053,14 @@ class Base_V2(nn.Module):
 			it = 0
 			weights = 0
 			bsz = len( loss_targets )
-			for seq, level in zip( loss_targets, loss_levels ):
+			for seq, loss_factor in zip( loss_targets, loss_factors ):
 				seq_len = seq.shape[0]
 				start = it
 				it += seq_len
 				end = it
 
-				nll += nlls[start:end].mean() * level_loss_factors[level]
-				weights += level_loss_factors[level]
+				nll += nlls[start:end].mean() * loss_factor
+				weights += loss_factor
 
 			# normalize by batch
 			nll /= bsz
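Consequently the reduction loop no longer re-indexes `level_loss_factors` by level; it consumes the factor that was stored alongside each target during collection, which also covers the new 1-D case where there is no level at all. Sketched in isolation (all tensors and factors invented):

import torch

nlls = torch.rand(10)                            # flattened per-token NLLs
loss_targets = [torch.zeros(4), torch.zeros(6)]  # two collected segments
loss_factors = [1.0, 0.5]                        # factor stored per segment

nll, weights, it = 0.0, 0.0, 0
for seq, loss_factor in zip(loss_targets, loss_factors):
	start, it = it, it + seq.shape[0]
	nll += nlls[start:it].mean() * loss_factor
	weights += loss_factor

nll /= len(loss_targets)  # normalize by batch size, as in the diff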
@@ -1341,6 +1090,7 @@ class Base_V2(nn.Module):
 					).to(loss_logit.device)
 					acc_k_hi = accuracy_metric( loss_logit, loss_target )
 
+		# to-do: re-add reporting split losses
 		if nll is not None:
 			if 'nll' not in loss:
 				loss['nll'] = []
@@ -1442,6 +1192,7 @@ class Base_V2(nn.Module):
 
 		if self.use_streamlined_calc_loss:
 			logits = self.audio_decoder( output.logits )
+			# to-do: get len logits
 		else:
 			logits = [ logit for logit in output.logits ]
 			grouped_logits = {}
@@ -1486,8 +1237,7 @@ class Base_V2(nn.Module):
 
 		# compute loss if the target is given
 		else:
-			loss_func = self.calc_loss_specialized if self.use_streamlined_calc_loss else self.calc_loss
-			loss, stats = loss_func( inputs=inputs, logits=logits, quant_levels=quant_levels )
+			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
 
 			# include any additional losses (for example: MoE router)
 			if output.loss is not None: