This commit is contained in:
mrq 2024-06-05 21:02:05 -05:00
parent 4073656293
commit b05a905b95
2 changed files with 12 additions and 5 deletions

View File

@ -142,6 +142,10 @@ def load_engines(training=True):
for k in erase: for k in erase:
del state[k] del state[k]
# resize text embedding
if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
model.load_state_dict(state, strict=cfg.trainer.strict_loading) model.load_state_dict(state, strict=cfg.trainer.strict_loading)
hyper_config = model.hyper_config hyper_config = model.hyper_config

View File

@ -703,8 +703,9 @@ class Base(nn.Module):
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) ) target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
batch_size = len(target_list)
# modify only for the AR so it can properly behave like a transformer # modify only for the AR so it can properly behave like a transformer
for i in range(len(target_list)): for i in range(batch_size):
if quant_levels is not None and quant_levels[i] > 0: if quant_levels is not None and quant_levels[i] > 0:
continue continue
@ -725,10 +726,10 @@ class Base(nn.Module):
) )
else: else:
self.loss = dict( self.loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch) nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
) )
self.stats = dict( self.stats = dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch) acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
) )
return return
@ -745,6 +746,8 @@ class Base(nn.Module):
self.stats = dict(acc = dict()) self.stats = dict(acc = dict())
info = {} info = {}
batch_size = len( inputs )
for i, batch in enumerate( inputs ): for i, batch in enumerate( inputs ):
quant_level = quant_levels[i] if quant_levels is not None else None quant_level = quant_levels[i] if quant_levels is not None else None
@ -799,8 +802,8 @@ class Base(nn.Module):
# probably consumes less memory due to not having to allocate memory # probably consumes less memory due to not having to allocate memory
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed) # this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else: else:
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch) self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch) self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
# accuracy sometimes breaks for mamba # accuracy sometimes breaks for mamba