when the
This commit is contained in:
parent
918e0dbac1
commit
a5a04c39ef
|
@ -1640,19 +1640,18 @@ class Base(nn.Module):
|
||||||
if loss_factor == 0.0:
|
if loss_factor == 0.0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# cringe way to deduce "requested" level
|
|
||||||
level = quant_level
|
|
||||||
for i in range( self.n_resp_levels ):
|
|
||||||
if classifier_level == f'NAR:{i}:{i}':
|
|
||||||
level = i
|
|
||||||
break
|
|
||||||
|
|
||||||
if logits[batch_index].dim() < 3:
|
if logits[batch_index].dim() < 3:
|
||||||
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
|
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
|
||||||
|
|
||||||
if name == "resp":
|
if name == "resp":
|
||||||
name = f'{name}[{quant_level}]'
|
name = f'{name}[{quant_level}]'
|
||||||
elif not self.resp_parallel_training:
|
elif not self.resp_parallel_training:
|
||||||
|
# cringe way to deduce "requested" level
|
||||||
|
level = quant_level
|
||||||
|
for i in range( self.n_resp_levels ):
|
||||||
|
if classifier_level.endswith(f':{i}:{i}'):
|
||||||
|
level = i
|
||||||
|
break
|
||||||
if name == "resp":
|
if name == "resp":
|
||||||
name = f'{name}[{level}]'
|
name = f'{name}[{level}]'
|
||||||
sequence = token if token.dim() <= 1 else token[:, level]
|
sequence = token if token.dim() <= 1 else token[:, level]
|
||||||
|
@ -1711,9 +1710,10 @@ class Base(nn.Module):
|
||||||
# cringe way to deduce "requested" level
|
# cringe way to deduce "requested" level
|
||||||
level = 0
|
level = 0
|
||||||
for i in range( self.n_resp_levels ):
|
for i in range( self.n_resp_levels ):
|
||||||
if classifier_level == f'NAR:{i}:{i}':
|
if classifier_level.endswith(f':{i}:{i}'):
|
||||||
level = i
|
level = i
|
||||||
break
|
break
|
||||||
|
|
||||||
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||||
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||||
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
||||||
|
|
Loading…
Reference in New Issue
Block a user