fixed some logic errors with training (grabbing wrong quant level...)

This commit is contained in:
mrq 2024-06-07 20:34:36 -05:00
parent eafa622be2
commit b0158a61d5

View File

@ -664,10 +664,12 @@ class Base(nn.Module):
elif name == "lang" and self.langs_emb is not None:
embedding = self.langs_emb( input )
elif name == "prom":
# get RVQ level 0, or up to targetted RVQ level inference
embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None )
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
# get RVQ level 0, or up to targetted RVQ level inference
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
else:
continue
@ -688,7 +690,10 @@ class Base(nn.Module):
# old, "naive" way, no loss factoring
if not self.config.loss_factors:
target_list = []
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
prev_quant_level = 0 if quant_level == 0 else quant_level - 1
target = []
for name, input in batch:
if name == "prom":
@ -697,9 +702,9 @@ class Base(nn.Module):
target.append( torch.full_like(input[..., 0], self.ignore_index) )
# we *CAN* directly map to proms
else:
target.append( input if input.dim() == 1 else input[:, quant_level-1] )
target.append( input if input.dim() == 1 else input[:, prev_quant_level] )
elif name == "resp":
target.append( input if input.dim() == 1 else input[:, quant_level-1] )
target.append( input if input.dim() == 1 else input[:, quant_level] )
elif name in ["text", "quant_level", "lang", "tone"]:
target.append( input )
@ -729,7 +734,7 @@ class Base(nn.Module):
)
else:
self.loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
self.stats = dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
@ -752,7 +757,8 @@ class Base(nn.Module):
batch_size = len( inputs )
for i, batch in enumerate( inputs ):
quant_level = quant_levels[i] if quant_levels is not None else None
quant_level = quant_levels[i]
prev_quant_level = 0 if quant_level == 0 else quant_level - 1
it = 0
for name, input in batch:
@ -760,8 +766,8 @@ class Base(nn.Module):
if name == "resp":
input = input if input.dim() == 1 else input[:, quant_level]
# select prom level
elif name == "prom" and quant_level is not None:
input = input[:, quant_level]
elif name == "prom":
input = input[:, prev_quant_level]
seq_len = input.shape[0]