mrq 2025-02-24 21:03:23 -06:00
parent 918e0dbac1
commit a5a04c39ef

@@ -1640,19 +1640,18 @@ class Base(nn.Module):
 				if loss_factor == 0.0:
 					continue
 
-				# cringe way to deduce "requested" level
-				level = quant_level
-				for i in range( self.n_resp_levels ):
-					if classifier_level == f'NAR:{i}:{i}':
-						level = i
-						break
-
 				if logits[batch_index].dim() < 3:
 					nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
 					if name == "resp":
 						name = f'{name}[{quant_level}]'
 				elif not self.resp_parallel_training:
+					# cringe way to deduce "requested" level
+					level = quant_level
+					for i in range( self.n_resp_levels ):
+						if classifier_level.endswith(f':{i}:{i}'):
+							level = i
+							break
 					if name == "resp":
 						name = f'{name}[{level}]'
 					sequence = token if token.dim() <= 1 else token[:, level]
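The moved block maps a classifier name to the codebook level it targets, and the check is relaxed from the exact `NAR:{i}:{i}` form to any name ending in `:{i}:{i}`. A minimal standalone sketch of that matching follows; the helper name, its arguments, and the `default` fallback are hypothetical and not part of this commit.

# Hypothetical helper, for illustration only: deduce the "requested" codebook
# level from a classifier name such as 'NAR:3:3'.
def deduce_level( classifier_level: str, n_resp_levels: int, default: int = 0 ) -> int:
	for i in range( n_resp_levels ):
		# old check: classifier_level == f'NAR:{i}:{i}' (exact NAR names only)
		# new check: any classifier name ending in ':{i}:{i}' matches
		if classifier_level.endswith( f':{i}:{i}' ):
			return i
	return default

assert deduce_level( 'NAR:3:3', n_resp_levels=8 ) == 3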
@@ -1711,9 +1710,10 @@ class Base(nn.Module):
 					# cringe way to deduce "requested" level
 					level = 0
 					for i in range( self.n_resp_levels ):
-						if classifier_level == f'NAR:{i}:{i}':
+						if classifier_level.endswith(f':{i}:{i}'):
 							level = i
 							break
 					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
 					nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
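Once a level is deduced, both hunks slice a multi-level token tensor down to that single level via `x[:, level]` before the loss is computed. A minimal sketch with assumed shapes (not from this commit) of that indexing:

# Assumed shapes, for illustration only: a [T, n_resp_levels] token tensor is
# reduced to one target level, matching `x if x.dim() <= 1 else x[:, level]`.
import torch

n_frames, n_resp_levels = 4, 8
token = torch.randint( 0, 1024, ( n_frames, n_resp_levels ) )  # [T, levels]
level = 3

sequence = token if token.dim() <= 1 else token[:, level]  # -> shape [T]
assert sequence.shape == ( n_frames, )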