ugh
This commit is contained in:
parent 4073656293
commit b05a905b95
@@ -142,6 +142,10 @@ def load_engines(training=True):
 		for k in erase:
 			del state[k]
 
+		# resize text embedding
+		if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
+			state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
+
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 		hyper_config = model.hyper_config
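For reference, a minimal standalone sketch of what the added resize does, with hypothetical shapes (in the real code, state comes from the checkpoint and the row count from cfg.model.text_tokens): the text-embedding weight is truncated to the configured number of rows before load_state_dict, so a checkpoint trained with a larger text vocabulary still loads. This only handles shrinking the table; growing it would need newly initialized rows.

import torch

# hypothetical checkpoint: text embedding trained with 512 tokens, 1024-dim
state = { "text_emb.weight": torch.randn(512, 1024) }
text_tokens = 256  # stand-in for cfg.model.text_tokens

if text_tokens != state["text_emb.weight"].shape[0]:
	# keep only the first `text_tokens` rows; the embedding dim is untouched
	state["text_emb.weight"] = state["text_emb.weight"][:text_tokens]

assert state["text_emb.weight"].shape == (256, 1024)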
@@ -703,8 +703,9 @@ class Base(nn.Module):
 
 				target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
 
+			batch_size = len(target_list)
 			# modify only for the AR so it can properly behave like a transformer
-			for i in range(len(target_list)):
+			for i in range(batch_size):
 				if quant_levels is not None and quant_levels[i] > 0:
 					continue
 
@@ -725,10 +726,10 @@ class Base(nn.Module):
 				)
 			else:
 				self.loss = dict(
-					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch)
+					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
 				)
 				self.stats = dict(
-					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch)
+					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 				)
 
 			return
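A rough, self-contained sketch of the normalization after this change, with made-up shapes and a constant loss_factor (both are assumptions, not the real values): the cross-entropy of each sequence in the batch is summed, then divided by batch_size, the number of sequences, rather than by len(batch).

import torch
import torch.nn.functional as F

ignore_index = -100   # assumed stand-in for self.ignore_index
loss_factor  = 1.0    # assumed constant; the real code scales by a per-loss factor

# two sequences of different lengths, 10-class logits (made-up shapes)
target_list = [ torch.randint(0, 10, (7,)), torch.randint(0, 10, (5,)) ]
logits      = [ torch.randn(7, 10), torch.randn(5, 10) ]

batch_size = len(target_list)
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
print(nll)  # mean of the per-sequence cross-entropies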
@@ -745,6 +746,8 @@ class Base(nn.Module):
 			self.stats = dict(acc = dict())
 
 			info = {}
+			batch_size = len( inputs )
+
 			for i, batch in enumerate( inputs ):
 				quant_level = quant_levels[i] if quant_levels is not None else None
 
@@ -799,8 +802,8 @@ class Base(nn.Module):
 			# probably consumes less memory due to not having to allocate memory
 			# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
 			else:
-				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch)
-				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch)
+				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
+				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
 
 			# accuracy sometimes breaks for mamba
 
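As a side note on why len(batch) looks like the wrong divisor in this branch: batch comes from for i, batch in enumerate( inputs ) and is indexed like a mapping (batch["targets"], batch["logits"]), so if it is a plain dict, len(batch) counts its keys rather than anything batch-sized. A tiny illustration with a hypothetical two-key entry:

# hypothetical batch entry shaped like the one indexed above
batch = { "targets": [0, 1, 2], "logits": [0, 1, 2] }

print( len(batch) )             # 2  -- number of keys, not a batch dimension
print( len(batch["targets"]) )  # 3  -- the per-entry sequence count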