Compare commits

2 commits: 462f71e2f7 ... 0d809561c6

* 0d809561c6
* 2fb2b732fc
@@ -72,6 +72,7 @@ The optimizer used *mostly* doesn't matter, as AdamW seems to get moving faster,
 * `APOLLO` needs more testing, but seemed adequate in cursory tests
 * `Muon` requires much more testing, but absolutely cannot be used for predicting tokens in place (NAR demasking), and requires `cfg.model.experimental.predict_causally=True`
 * I honestly don't think it gives good enough results from cursory tests for this application
+* `Adagrad` surprisingly seems to "fix" (for now) my problems with the loss / accuracy bouncing.
 
 ## Try Me
 
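The note above boils down to: when the loss / accuracy starts bouncing, it may be worth dropping in Adagrad for a sanity run. A minimal sketch with plain PyTorch optimizers; the learning rates are illustrative only and are not values taken from this repo:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(8, 8))]

# candidate that appeared to settle the bouncing loss / accuracy
optimizer = torch.optim.Adagrad(params, lr=1e-2)

# the usual default mentioned above
# optimizer = torch.optim.AdamW(params, lr=1e-4)

loss = (params[0] ** 2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```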
@@ -54,6 +54,8 @@ try:
 heads = config.num_attention_heads
 dim_head = getattr(config, "head_dim", dim // heads)
 kv_heads = config.num_key_value_heads
+causal = False # config.causal # to-do: handle split-causal attention like I do for normal attention
+# for now though leave it as false since the mask transformer variant of VALL-E is much more preferable to the causal variant
 
 # to-do: figure out these settings best for VALL-E
 compress_block_size = 16
@@ -83,6 +85,8 @@ try:
 num_selected_blocks = num_selected_blocks,
 num_compressed_mem_kv = num_compressed_mem_kv,
+
+causal = causal,
 
 norm = False, # pre/post norm is done here already
 use_diff_topk = True,
 use_triton_kernel = False,
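For reference, the settings these two hunks touch end up as keyword arguments to the (assumed) sparse-attention module constructed inside the `try:` block. The sketch below only shows how the new `causal` flag rides along with the existing values; `config.hidden_size` is an assumption about where `dim` comes from, not something shown in the diff:

```python
def build_sparse_attention_kwargs(config):
    # mirror the derivations in the hunk above
    dim = config.hidden_size                      # assumption: standard transformers config field
    heads = config.num_attention_heads
    dim_head = getattr(config, "head_dim", dim // heads)
    kv_heads = config.num_key_value_heads

    # keep the NAR / mask-transformer path non-causal for now, as the diff hard-codes
    causal = False  # config.causal, once split-causal attention is handled

    return dict(
        dim = dim,
        heads = heads,
        dim_head = dim_head,
        kv_heads = kv_heads,
        causal = causal,
        norm = False,              # pre/post norm is handled by the surrounding model
        compress_block_size = 16,  # to-do per the diff: tune for VALL-E
        use_diff_topk = True,
        use_triton_kernel = False,
    )
```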
@@ -24,12 +24,14 @@ class Config(BaseConfig):
 self,
 attn_mode = "sdpa",
 output_norm = True,
+causal = True,
 *args, **kwargs
 ):
 super().__init__(*args, **kwargs)
 
 self.attn_mode = attn_mode
 self.output_norm = output_norm
+self.causal = causal
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 batch, num_key_value_heads, slen, head_dim = hidden_states.shape
@@ -458,8 +458,11 @@ class Base_V2(nn.Module):
 is_encoder_decoder=False,
 is_decoder=True,
 #gradient_checkpointing=self.gradient_checkpointing,
+
+# extra parameters
 output_norm = not per_level_normalization, # moves the LN out to the decoder
 attn_mode = attention_backend,
+causal = self.causal,
 )
 self.model = LlamaModel(self.model_config)
 
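Taken together with the two hunks above, this change just threads one boolean from `Base_V2` down to the attention code. A hedged sketch of the resulting call, assuming `Config` is the patched class from the earlier hunk and that `BaseConfig` forwards any extra kwargs to the underlying transformers config:

```python
config = Config(
    attn_mode = "sdpa",
    output_norm = True,        # or `not per_level_normalization`, as in Base_V2 above
    causal = False,            # the NAR / mask-transformer variant stays non-causal
)
assert config.causal is False  # attention layers can now branch on this flag
```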
@@ -903,7 +906,7 @@ class Base_V2(nn.Module):
 sequence = sequence.reshape(-1)
 
 nll = None
-metrics = None
+acc_k1 = None
 
 if compute_hard_loss:
 reduction = 'mean' if not batched else 'none'
@@ -917,14 +920,23 @@ class Base_V2(nn.Module):
 if compute_acc:
 accuracy_metric = MulticlassAccuracy(
 logit.shape[-1],
-top_k = min(logit.shape[0], 10),
+top_k = 1,
 average="micro",
 multidim_average="global",
 ignore_index = -100
 ).to(logit.device)
-metrics = accuracy_metric( logit, sequence )
+acc_k1 = accuracy_metric( logit, sequence )
+
+accuracy_metric = MulticlassAccuracy(
+logit.shape[-1],
+top_k = min(logit.shape[0], 80),
+average="micro",
+multidim_average="global",
+ignore_index = -100
+).to(logit.device)
+acc_k80 = accuracy_metric( logit, sequence )
 
-return nll, metrics
+return nll, acc_k1, acc_k80
 
 for batch_index, batch in enumerate(inputs):
 quant_level = quant_levels[batch_index]
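This hunk swaps the single accuracy stat for two separate metrics: strict top-1 accuracy and a much looser top-80 accuracy. A self-contained illustration of the split, assuming the `MulticlassAccuracy` used here is torchmetrics' class of the same name (the constructor arguments match the ones in the diff):

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

num_classes = 1024                       # e.g. an audio codebook size; illustrative only
logits = torch.randn(50, num_classes)    # (sequence length, vocab)
targets = torch.randint(0, num_classes, (50,))

acc_k1 = MulticlassAccuracy(num_classes, top_k=1, average="micro",
                            multidim_average="global", ignore_index=-100)(logits, targets)
acc_k80 = MulticlassAccuracy(num_classes, top_k=min(logits.shape[0], 80), average="micro",
                             multidim_average="global", ignore_index=-100)(logits, targets)

# top-80 accuracy is an upper bound on top-1: a position counts as correct if the
# target appears anywhere in the 80 highest-scoring classes.
assert acc_k80 >= acc_k1
```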
@@ -1010,7 +1022,7 @@ class Base_V2(nn.Module):
 continue
 
 if logits[batch_index].dim() < 3:
-nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][start:end], token.long(), causal )
 elif not self.resp_parallel_training:
 # cringe way to deduce "requested" level
 level = quant_level
@@ -1023,25 +1035,31 @@ class Base_V2(nn.Module):
 name = f'{name}[{level}]'
 
 sequence = token if token.dim() <= 1 else token[:, level]
-nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
 else:
 sequence = token.t()
-nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 
 if nll is not None:
 nll = nll.mean()
 
 loss_key = f'{name}.nll'
-acc_key = f'{name}.acc'
+acc_k1_key = f'{name}.acc[k=1]'
+acc_k80_key = f'{name}.acc[k=80]'
 if nll is not None:
 if loss_key not in loss:
 loss[loss_key] = []
 loss[loss_key].append( nll * loss_factor )
 
-if metrics is not None:
-if acc_key not in stats:
-stats[acc_key] = []
-stats[acc_key].append( metrics )
+if acc_k1 is not None:
+if acc_k1_key not in stats:
+stats[acc_k1_key] = []
+stats[acc_k1_key].append( acc_k1 )
+
+if acc_k80 is not None:
+if acc_k80_key not in stats:
+stats[acc_k80_key] = []
+stats[acc_k80_key].append( acc_k80 )
 # add to list
 else:
 target.append( token )
@@ -1051,7 +1069,7 @@ class Base_V2(nn.Module):
 if not self.config.loss_factors:
 if logits[batch_index].dim() < 3:
 sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index], sequence, causal )
 elif not self.resp_parallel_training:
 # cringe way to deduce "requested" level
 level = 0
@@ -1062,35 +1080,45 @@ class Base_V2(nn.Module):
 
 sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
+nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
 else:
 nlls = []
-accs = []
+acc_k1s = []
+acc_k80s = []
 
 for level, logit in enumerate( logits[batch_index] ):
 sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-nll, metrics = _calc_loss( logit, sequence, causal, level )
+nll, acc_k1, acc_k80 = _calc_loss( logit, sequence, causal, level )
 
 if nll:
 nlls.append( nll )
-if metrics:
-accs.append( metrics )
+if acc_k1:
+acc_k1s.append( acc_k1 )
+if acc_k80:
+acc_k80s.append( acc_k80 )
 
 if nlls:
 nll = sum(nlls) / len(nlls)
-if accs:
-metrics = sum(accs) / len(accs)
+if acc_k1s:
+acc_k1 = sum(acc_k1s) / len(acc_k1s)
+if acc_k80s:
+acc_k80 = sum(acc_k80s) / len(acc_k80s)
 
 if nll is not None:
 if 'nll' not in loss:
 loss['nll'] = []
 loss["nll"].append( nll )
 
-if metrics is not None:
-if 'acc' not in stats:
-stats['acc'] = []
-stats["acc"].append( metrics )
+if acc_k1 is not None:
+if 'acc[k=1]' not in stats:
+stats['acc[k=1]'] = []
+stats["acc[k=1]"].append( acc_k1 )
+
+if acc_k80 is not None:
+if 'acc[k=80]' not in stats:
+stats['acc[k=80]'] = []
+stats["acc[k=80]"].append( acc_k80 )
 
 # average
 loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
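The bookkeeping pattern above is the same for every key: append one tensor per batch into a list keyed by the stat name, then take the mean at the end. A tiny, self-contained illustration with the new per-k keys; the values are made up purely for demonstration:

```python
import torch

stats = {}
for acc_k1, acc_k80 in [(torch.tensor(0.42), torch.tensor(0.91)),
                        (torch.tensor(0.40), torch.tensor(0.95))]:
    stats.setdefault('acc[k=1]', []).append(acc_k1)
    stats.setdefault('acc[k=80]', []).append(acc_k80)

# same reduction the diff applies to `loss` / `stats`: mean over the collected entries
stats = { name: sum(values) / len(values) for name, values in stats.items() }
print(stats)   # {'acc[k=1]': tensor(0.4100), 'acc[k=80]': tensor(0.9300)}
```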