one more time (could have sworn i tested it with batch size > 1)
parent 6cea840710
commit 93044829af
@@ -835,6 +835,7 @@ def example_usage():
	# cfg.model.experimental.masking_train_p = 0.5
	cfg.hyperparameters.batch_size = 1
	cfg.hyperparameters.gradient_accumulation_steps = 1
	cfg.model.experimental.use_raw_text_p = 0

	setup_logging()
|
@@ -1297,10 +1297,19 @@ class Base_V2(nn.Module):
 		acc_k_hi = None

 		if compute_hard_loss:
-			weight = torch.tensor( [ level_loss_factors[level] for level in loss_levels ], device=logit.device )
-			nll = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
-			nll = nll.view( batch_size, 1 if not self.resp_parallel_training else self.n_resp_levels, -1 ).mean(dim=-1) * weight
-			nll = nll.mean()
+			nll = 0
+			nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
+
+			it = 0
+			for seq, level in zip( loss_targets, loss_levels ):
+				seq_len = seq.shape[0]
+				start = it
+				it += seq_len
+				end = it
+
+				nll += nlls[start:end].mean() * level_loss_factors[level]
+
+			nll /= len( loss_targets )

 		if compute_acc:
 			n_vocab = loss_logit.shape[-1]
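As a reading of this change: the previous reduction took the unreduced cross-entropy over the packed targets and reshaped it with view( batch_size, ..., -1 ), which only holds when every packed sequence has the same length; with batch size > 1 and variable-length sequences that view breaks (or silently mixes tokens from different sequences). The new code instead walks the concatenated targets sequence by sequence, averages each slice on its own, and weights it by its level's loss factor. The sketch below reproduces just that loop on toy tensors; the tensor sizes, vocabulary size, and level_loss_factors values are placeholder assumptions, not taken from the repository.

# Minimal sketch (not the repository's code): per-sequence NLL reduction,
# mirroring the loop introduced above. Shapes, vocabulary size, and the
# level_loss_factors values are assumed for illustration.
import torch
import torch.nn.functional as F

ignore_index = -100
level_loss_factors = [ 1.0, 0.5, 0.25 ]   # assumed per-level weights

# Two sequences of *different* lengths, packed into one flat target tensor.
# The old view(batch_size, n_levels, -1) reduction assumed equal lengths.
loss_targets = [ torch.randint(0, 32, (7,)), torch.randint(0, 32, (11,)) ]
loss_levels  = [ 0, 2 ]                   # assumed codebook level per sequence
loss_target  = torch.cat( loss_targets )                 # (18,)
loss_logit   = torch.randn( loss_target.shape[0], 32 )   # (18, n_vocab)

# Unreduced cross-entropy over the packed tokens, then reduce per sequence.
nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=ignore_index )

nll = 0
it = 0
for seq, level in zip( loss_targets, loss_levels ):
	start = it
	it += seq.shape[0]
	# mean over this sequence only, weighted by its level's loss factor
	nll += nlls[start:it].mean() * level_loss_factors[level]

nll /= len( loss_targets )
print( nll )

Reducing each sequence to its own mean before applying the level factor keeps short and long sequences on equal footing, which the old reshape could not guarantee once a batch packs sequences of unequal length.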