ugh
This commit is contained in:
parent 4073656293
commit b05a905b95
@@ -142,6 +142,10 @@ def load_engines(training=True):
 			for k in erase:
 				del state[k]
 
+			# resize text embedding
+			if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
+				state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
+
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 		hyper_config = model.hyper_config
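The new block above truncates the checkpoint's text-embedding weight when the configured text vocabulary (cfg.model.text_tokens) is smaller than the one stored in the checkpoint, so load_state_dict does not fail on a shape mismatch. A minimal standalone sketch of the same idea, with illustrative names (new_vocab_size, emb) that are not from the repo:

import torch
import torch.nn as nn

new_vocab_size = 256
emb = nn.Embedding(new_vocab_size, 1024)

state = { "weight": torch.randn(512, 1024) }                 # stand-in for a checkpoint with a larger vocabulary
if state["weight"].shape[0] != new_vocab_size:
	state["weight"] = state["weight"][:new_vocab_size]       # keep only the first new_vocab_size rows

emb.load_state_dict(state)                                   # shapes now agree, so strict loading succeeds

As written, the slice only handles shrinking the vocabulary; a checkpoint with fewer rows than the configured size would still mismatch.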
@@ -703,8 +703,9 @@ class Base(nn.Module):
 			target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
 
+		batch_size = len(target_list)
 		# modify only for the AR so it can properly behave like a transformer
-		for i in range(len(target_list)):
+		for i in range(batch_size):
 			if quant_levels is not None and quant_levels[i] > 0:
 				continue
 
@@ -725,10 +726,10 @@ class Base(nn.Module):
 			)
 		else:
 			self.loss = dict(
-				nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch)
+				nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
 			)
 			self.stats = dict(
-				acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch)
+				acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 			)
 
 		return
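The two hunks above change the denominator of the summed per-sequence losses from len(batch) to batch_size = len(target_list), i.e. the number of sequences actually being scored. A minimal sketch of that normalization, with illustrative shapes and without the repo's loss_factor:

import torch
import torch.nn.functional as F

ignore_index = -100
logits      = [ torch.randn(7, 32), torch.randn(5, 32) ]                  # per-sequence [timesteps, vocab]
target_list = [ torch.randint(0, 32, (7,)), torch.randint(0, 32, (5,)) ]  # per-sequence token ids

batch_size = len(target_list)
nll = sum(
	F.cross_entropy(inputs, targets, ignore_index=ignore_index)
	for targets, inputs in zip(target_list, logits)
) / batch_size
print(nll)    # mean per-sequence negative log-likelihood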
@@ -745,6 +746,8 @@ class Base(nn.Module):
 		self.stats = dict(acc = dict())
 
 		info = {}
+		batch_size = len( inputs )
+
 		for i, batch in enumerate( inputs ):
 			quant_level = quant_levels[i] if quant_levels is not None else None
 
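This hunk captures batch_size = len( inputs ) before the per-sample loop. The reason to take it outside the loop: inside the loop, batch is a single element, so len(batch) counts that element's keys rather than the number of samples. A small illustration with placeholder data; the dict layout is only inferred from the later hunk that indexes batch["targets"] and batch["logits"]:

inputs = [
	{ "targets": ["t0a", "t0b"], "logits": ["l0a", "l0b"] },
	{ "targets": ["t1a"],        "logits": ["l1a"] },
	{ "targets": ["t2a"],        "logits": ["l2a"] },
]

batch_size = len( inputs )        # 3: the number of samples in the batch
for i, batch in enumerate( inputs ):
	print( len( batch ) )         # 2: the number of keys in one sample's dict, not the batch size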
@@ -799,8 +802,8 @@ class Base(nn.Module):
 			# probably consumes less memory due to not having to allocate memory
 			# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
 			else:
-				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch)
-				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch)
+				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
+				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
 
 			# accuracy sometimes breaks for mamba
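The comment in this hunk notes that keeping a separate loss per name "opens the way to scale loss per RVQ level". A hedged sketch of what such scaling could look like; level_weights, the shapes, and the averaging here are hypothetical and not part of this commit:

import torch
import torch.nn.functional as F

level_weights = { 0: 1.0, 1: 0.5 }      # hypothetical per-RVQ-level factors
quant_level = 1

targets = [ torch.randint(0, 32, (6,)), torch.randint(0, 32, (4,)) ]
logits  = [ torch.randn(6, 32),         torch.randn(4, 32) ]

loss = sum(
	F.cross_entropy(l, t) * level_weights[quant_level]
	for t, l in zip(targets, logits)
) / len(targets)
print(loss)                              # per-sequence mean, weighted by the sequence's RVQ level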