one more time (could have sworn i tested it with batch size > 1)

mrq 2025-03-07 19:14:33 -06:00
parent 6cea840710
commit 93044829af
2 changed files with 14 additions and 4 deletions


@@ -835,6 +835,7 @@ def example_usage():
 	# cfg.model.experimental.masking_train_p = 0.5
 	cfg.hyperparameters.batch_size = 1
 	cfg.hyperparameters.gradient_accumulation_steps = 1
+	cfg.model.experimental.use_raw_text_p = 0
 	setup_logging()


@@ -1297,10 +1297,19 @@ class Base_V2(nn.Module):
 			acc_k_hi = None
 			if compute_hard_loss:
-				weight = torch.tensor( [ level_loss_factors[level] for level in loss_levels ], device=logit.device )
-				nll = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
-				nll = nll.view( batch_size, 1 if not self.resp_parallel_training else self.n_resp_levels, -1 ).mean(dim=-1) * weight
-				nll = nll.mean()
+				nll = 0
+				nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
+				it = 0
+				for seq, level in zip( loss_targets, loss_levels ):
+					seq_len = seq.shape[0]
+					start = it
+					it += seq_len
+					end = it
+					nll += nlls[start:end].mean() * level_loss_factors[level]
+				nll /= len( loss_targets )
 			if compute_acc:
 				n_vocab = loss_logit.shape[-1]
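
For context, the change replaces a reshape-based reduction (nll.view( batch_size, ... ).mean(dim=-1) * weight) with a per-sequence loop. Below is a minimal, self-contained sketch (not the repository's code) of why the old reduction only held up for batch size 1: with variable-length sequences, the flat per-token NLL cannot be viewed as (batch_size, n_levels, seq_len). The shapes, vocabulary size, and level_loss_factors values here are assumptions for illustration.

import torch
import torch.nn.functional as F

n_vocab = 8
level_loss_factors = [1.0, 0.5]  # hypothetical per-level weights

# Two sequences of *different* lengths, each tagged with a codebook level.
loss_targets = [torch.randint(0, n_vocab, (5,)), torch.randint(0, n_vocab, (3,))]
loss_levels  = [0, 1]

# Logits for the flattened concatenation of both sequences.
loss_target = torch.cat(loss_targets)
loss_logit  = torch.randn(loss_target.shape[0], n_vocab)

# Old reduction: nll.view(batch_size, n_levels, -1) assumes equal sequence
# lengths; with 5 + 3 = 8 tokens it either raises a RuntimeError or silently
# splits tokens across the wrong sequences before averaging.

# New reduction: slice the flat per-token NLL back into per-sequence spans,
# average each span, weight it by its level's factor, then average over sequences.
nlls = F.cross_entropy(loss_logit, loss_target, reduction='none')
nll, it = 0.0, 0
for seq, level in zip(loss_targets, loss_levels):
    start, it = it, it + seq.shape[0]
    nll += nlls[start:it].mean() * level_loss_factors[level]
nll /= len(loss_targets)
print(nll)  # scalar tensor: mean of per-sequence weighted means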