Compare commits

...

2 Commits

4 changed files with 59 additions and 24 deletions

View File

@@ -72,6 +72,7 @@ The optimizer used *mostly* doesn't matter, as AdamW seems to get moving faster,
 * `APOLLO` needs more testing, but seemed adequate in cursory tests
 * `Muon` requires much more testing, but absolutely cannot be used for predicting tokens in place (NAR demasking), and requires `cfg.model.experimental.predict_causally=True`
   * I honestly don't think it gives good enough results from cursory tests for this application
+* `Adagrad` surprisingly seems to "fix" (for now) my problems with the loss / accuracy bouncing.
 ## Try Me
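As a point of reference for the optimizer notes above, here is a minimal sketch of switching between the stock `torch.optim` choices (AdamW, and the Adagrad run that seems to smooth out the bouncing loss/accuracy); `APOLLO` and `Muon` come from third-party packages and are omitted. The helper name, learning rates, and string-based selection are illustrative assumptions, not the repo's actual config-driven optimizer plumbing.

```python
# Hypothetical helper, not the repo's actual optimizer selection logic.
import torch

def build_optimizer(model: torch.nn.Module, name: str = "adagrad") -> torch.optim.Optimizer:
    params = model.parameters()
    if name == "adamw":
        # gets moving faster, but prone to the loss/accuracy bouncing noted above
        return torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-2)
    if name == "adagrad":
        # the optimizer that, so far, seems to avoid that bouncing
        return torch.optim.Adagrad(params, lr=1e-2)
    raise ValueError(f"unknown optimizer: {name}")
```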

View File

@@ -54,6 +54,8 @@ try:
 heads = config.num_attention_heads
 dim_head = getattr(config, "head_dim", dim // heads)
 kv_heads = config.num_key_value_heads
+causal = False # config.causal # to-do: handle split-causal attention like I do for normal attention
+# for now though leave it as false since the mask transformer variant of VALL-E is much more preferable to the causal variant
 # to-do: figure out these settings best for VALL-E
 compress_block_size = 16
@@ -83,6 +85,8 @@ try:
 num_selected_blocks = num_selected_blocks,
 num_compressed_mem_kv = num_compressed_mem_kv,
+causal = causal,
 norm = False, # pre/post norm is done here already
 use_diff_topk = True,
 use_triton_kernel = False,
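The to-do above mentions handling "split-causal" attention the way the normal attention path already does. Purely as an illustration of that idea, and not the repo's implementation, such a mask keeps a prefix segment (e.g. the text and audio prompt) fully visible while the remaining positions attend causally; something along these lines is what would eventually replace the hard-coded `causal = False`:

```python
import torch

def split_causal_mask(seq_len: int, prefix_len: int) -> torch.Tensor:
    # True = attendable. Every query can see the whole prefix;
    # positions after the prefix only see up to themselves.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask[:, :prefix_len] = True
    return mask
```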

View File

@@ -24,12 +24,14 @@ class Config(BaseConfig):
 self,
 attn_mode = "sdpa",
 output_norm = True,
+causal = True,
 *args, **kwargs
 ):
 super().__init__(*args, **kwargs)
 self.attn_mode = attn_mode
 self.output_norm = output_norm
+self.causal = causal
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 batch, num_key_value_heads, slen, head_dim = hidden_states.shape
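The new `causal` flag stored on this config is forwarded from `Base_V2` in the next file (`causal = self.causal`). How the attention wrapper consumes it is not part of these hunks; as a rough sketch of what a flag like this typically toggles under the default `attn_mode = "sdpa"`, it would choose between a causal and a full mask when calling PyTorch's scaled-dot-product attention (an assumption, not a confirmed reading of the repo's attention code):

```python
import torch
import torch.nn.functional as F

def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool) -> torch.Tensor:
    # q/k/v: [batch, heads, seq, head_dim]; is_causal=True applies a lower-triangular mask.
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
```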

View File

@@ -458,8 +458,11 @@ class Base_V2(nn.Module):
 is_encoder_decoder=False,
 is_decoder=True,
 #gradient_checkpointing=self.gradient_checkpointing,
+# extra parameters
 output_norm = not per_level_normalization, # moves the LN out to the decoder
 attn_mode = attention_backend,
+causal = self.causal,
 )
 self.model = LlamaModel(self.model_config)
@@ -903,7 +906,7 @@ class Base_V2(nn.Module):
 sequence = sequence.reshape(-1)
 nll = None
-metrics = None
+acc_k1 = None
 if compute_hard_loss:
 reduction = 'mean' if not batched else 'none'
@@ -917,14 +920,23 @@ class Base_V2(nn.Module):
 if compute_acc:
 accuracy_metric = MulticlassAccuracy(
 logit.shape[-1],
-top_k = min(logit.shape[0], 10),
+top_k = 1,
 average="micro",
 multidim_average="global",
 ignore_index = -100
 ).to(logit.device)
-metrics = accuracy_metric( logit, sequence )
+acc_k1 = accuracy_metric( logit, sequence )
+accuracy_metric = MulticlassAccuracy(
+logit.shape[-1],
+top_k = min(logit.shape[0], 80),
+average="micro",
+multidim_average="global",
+ignore_index = -100
+).to(logit.device)
+acc_k80 = accuracy_metric( logit, sequence )
-return nll, metrics
+return nll, acc_k1, acc_k80
 for batch_index, batch in enumerate(inputs):
 quant_level = quant_levels[batch_index]
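The hunk above replaces the single top-10 accuracy with two separate metrics: an exact top-1 accuracy and a top-80 accuracy (capped by the number of positions via `min(logit.shape[0], 80)`). A standalone sketch of the same call pattern, assuming `MulticlassAccuracy` here is the torchmetrics implementation and that logits come in as `[num_positions, vocab_size]` against integer targets with `-100` as the ignore value:

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

vocab_size, seq_len = 1024, 100                      # illustrative sizes only
logits = torch.randn(seq_len, vocab_size)            # per-position token logits
targets = torch.randint(0, vocab_size, (seq_len,))   # ground-truth token ids
targets[:10] = -100                                  # e.g. positions excluded from the loss

acc_k1 = MulticlassAccuracy(vocab_size, top_k=1, average="micro",
                            multidim_average="global", ignore_index=-100)
acc_k80 = MulticlassAccuracy(vocab_size, top_k=80, average="micro",
                             multidim_average="global", ignore_index=-100)

print(acc_k1(logits, targets), acc_k80(logits, targets))
```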
@@ -1010,7 +1022,7 @@ class Base_V2(nn.Module):
 continue
 if logits[batch_index].dim() < 3:
-nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][start:end], token.long(), causal )
 elif not self.resp_parallel_training:
 # cringe way to deduce "requested" level
 level = quant_level
@@ -1023,25 +1035,31 @@ class Base_V2(nn.Module):
 name = f'{name}[{level}]'
 sequence = token if token.dim() <= 1 else token[:, level]
-nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
 else:
 sequence = token.t()
-nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 if nll is not None:
 nll = nll.mean()
 loss_key = f'{name}.nll'
-acc_key = f'{name}.acc'
+acc_k1_key = f'{name}.acc[k=1]'
+acc_k80_key = f'{name}.acc[k=80]'
 if nll is not None:
 if loss_key not in loss:
 loss[loss_key] = []
 loss[loss_key].append( nll * loss_factor )
-if metrics is not None:
-if acc_key not in stats:
-stats[acc_key] = []
-stats[acc_key].append( metrics )
+if acc_k1 is not None:
+if acc_k1_key not in stats:
+stats[acc_k1_key] = []
+stats[acc_k1_key].append( acc_k1 )
+if acc_k80 is not None:
+if acc_k80_key not in stats:
+stats[acc_k80_key] = []
+stats[acc_k80_key].append( acc_k80 )
 # add to list
 else:
 target.append( token )
@@ -1051,7 +1069,7 @@ class Base_V2(nn.Module):
 if not self.config.loss_factors:
 if logits[batch_index].dim() < 3:
 sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index], sequence, causal )
 elif not self.resp_parallel_training:
 # cringe way to deduce "requested" level
 level = 0
@@ -1062,35 +1080,45 @@ class Base_V2(nn.Module):
 sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
 else:
 nlls = []
-accs = []
+acc_k1s = []
+acc_k80s = []
 for level, logit in enumerate( logits[batch_index] ):
 sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-nll, metrics = _calc_loss( logit, sequence, causal, level )
+nll, acc_k1, acc_k80 = _calc_loss( logit, sequence, causal, level )
 if nll:
 nlls.append( nll )
-if metrics:
-accs.append( metrics )
+if acc_k1:
+acc_k1s.append( acc_k1 )
+if acc_k80:
+acc_k80s.append( acc_k80 )
 if nlls:
 nll = sum(nlls) / len(nlls)
-if accs:
-metrics = sum(accs) / len(accs)
+if acc_k1s:
+acc_k1 = sum(acc_k1s) / len(acc_k1s)
+if acc_k80s:
+acc_k80 = sum(acc_k80s) / len(acc_k80s)
 if nll is not None:
 if 'nll' not in loss:
 loss['nll'] = []
 loss["nll"].append( nll )
-if metrics is not None:
-if 'acc' not in stats:
-stats['acc'] = []
-stats["acc"].append( metrics )
+if acc_k1 is not None:
+if 'acc[k=1]' not in stats:
+stats['acc[k=1]'] = []
+stats["acc[k=1]"].append( acc_k1 )
+if acc_k80 is not None:
+if 'acc[k=80]' not in stats:
+stats['acc[k=80]'] = []
+stats["acc[k=80]"].append( acc_k80 )
 # average
 loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }