This commit is contained in:
mrq 2024-06-05 21:02:05 -05:00
parent 4073656293
commit b05a905b95
2 changed files with 12 additions and 5 deletions

View File

@ -142,6 +142,10 @@ def load_engines(training=True):
for k in erase:
del state[k]
# resize text embedding
if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
hyper_config = model.hyper_config

View File

@ -703,8 +703,9 @@ class Base(nn.Module):
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
batch_size = len(target_list)
# modify only for the AR so it can properly behave like a transformer
for i in range(len(target_list)):
for i in range(batch_size):
if quant_levels is not None and quant_levels[i] > 0:
continue
@ -725,10 +726,10 @@ class Base(nn.Module):
)
else:
self.loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch)
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
self.stats = dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch)
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
)
return
@ -745,6 +746,8 @@ class Base(nn.Module):
self.stats = dict(acc = dict())
info = {}
batch_size = len( inputs )
for i, batch in enumerate( inputs ):
quant_level = quant_levels[i] if quant_levels is not None else None
@ -799,8 +802,8 @@ class Base(nn.Module):
# probably consumes less memory due to not having to allocate memory
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else:
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch)
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch)
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
# accuracy sometimes breaks for mamba