fixed some logic errors with training (grabbing wrong quant level...)
parent eafa622be2
commit b0158a61d5
@@ -664,10 +664,12 @@ class Base(nn.Module):
             elif name == "lang" and self.langs_emb is not None:
                 embedding = self.langs_emb( input )
             elif name == "prom":
+                # get RVQ level 0, or up to targetted RVQ level inference
                 embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None )
             elif name == "tone" and self.tones_emb is not None:
                 embedding = self.tones_emb( input )
             elif name == "resp":
+                # get RVQ level 0, or up to targetted RVQ level inference
                 embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
             else:
                 continue
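For context, here is a minimal, hypothetical sketch (not the repo's actual proms_emb/resps_emb implementation) of a summed multi-level RVQ embedding consistent with how the sliced codes `input[:, :quant_level]` are consumed above:

    # Hypothetical sketch only -- names and structure are assumptions.
    # One nn.Embedding per codebook level; the per-level embeddings are summed.
    import torch
    from torch import nn

    class MultiLevelEmbedding(nn.Module):
        def __init__(self, n_tokens: int, d_model: int, n_levels: int = 8):
            super().__init__()
            self.embeddings = nn.ModuleList(
                nn.Embedding(n_tokens, d_model) for _ in range(n_levels)
            )

        def forward(self, codes: torch.Tensor) -> torch.Tensor:
            # codes: [seq_len] for a single level, or [seq_len, levels_kept]
            if codes.dim() == 1:
                return self.embeddings[0](codes)
            # sum the per-level embeddings for every level that was kept
            return sum(self.embeddings[l](codes[:, l]) for l in range(codes.shape[-1]))

    # usage: quant_level == 2 keeps levels 0..1 of the codes
    emb = MultiLevelEmbedding(n_tokens=1024, d_model=512)
    codes = torch.randint(0, 1024, (75, 8))
    h = emb(codes[:, :2])   # -> [75, 512]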
@@ -688,7 +690,10 @@ class Base(nn.Module):
         # old, "naive" way, no loss factoring
         if not self.config.loss_factors:
             target_list = []
+
             for batch_index, batch in enumerate(inputs):
+                quant_level = quant_levels[batch_index]
+                prev_quant_level = 0 if quant_level == 0 else quant_level - 1
                 target = []
                 for name, input in batch:
                     if name == "prom":
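Side note: the new `prev_quant_level` expression is simply the previous RVQ level clamped at zero, i.e. equivalent to:

    prev_quant_level = max(quant_level - 1, 0)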
@@ -697,9 +702,9 @@ class Base(nn.Module):
                             target.append( torch.full_like(input[..., 0], self.ignore_index) )
                         # we *CAN* directly map to proms
                         else:
-                            target.append( input if input.dim() == 1 else input[:, quant_level-1] )
+                            target.append( input if input.dim() == 1 else input[:, prev_quant_level] )
                     elif name == "resp":
-                        target.append( input if input.dim() == 1 else input[:, quant_level-1] )
+                        target.append( input if input.dim() == 1 else input[:, quant_level] )
                     elif name in ["text", "quant_level", "lang", "tone"]:
                         target.append( input )

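As a sanity check on the indexing this hunk corrects, a toy example (shapes assumed: response codes stored as [seq_len, n_rvq_levels]): at quant_level k the response target should be column k, and the old `quant_level - 1` indexing is especially wrong at level 0, where Python's -1 silently selects the last RVQ level:

    import torch

    seq_len, n_levels = 4, 8
    resp_codes = torch.arange(seq_len * n_levels).reshape(seq_len, n_levels)

    quant_level = 0
    prev_quant_level = 0 if quant_level == 0 else quant_level - 1

    target = resp_codes[:, quant_level]          # corrected: column 0 at level 0
    old_target = resp_codes[:, quant_level - 1]  # old indexing: wraps to column 7

    assert torch.equal(target, resp_codes[:, 0])
    assert torch.equal(old_target, resp_codes[:, n_levels - 1])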
@@ -729,7 +734,7 @@ class Base(nn.Module):
                 )
             else:
                 self.loss = dict(
-                    nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
+                    nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
                 )
                 self.stats = dict(
                     acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
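To make the corrected branch concrete, a small self-contained illustration (toy shapes and values, nothing from the repo) of what the fixed `nll` line computes: per-sample cross entropy with ignore_index masking the positions that should not contribute (e.g. the prompt span), averaged over the batch:

    import torch
    import torch.nn.functional as F

    ignore_index = -100
    logits = [torch.randn(6, 1024), torch.randn(5, 1024)]            # [seq, vocab] per sample
    target_list = [
        torch.tensor([ignore_index, ignore_index, 10, 11, 12, 13]),  # prompt positions masked
        torch.tensor([ignore_index, 7, 8, 9, 4]),
    ]
    batch_size = len(logits)

    nll = sum(
        F.cross_entropy(inputs, targets, ignore_index=ignore_index)
        for targets, inputs in zip(target_list, logits)
    ) / batch_size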
@@ -752,7 +757,8 @@ class Base(nn.Module):
         batch_size = len( inputs )

         for i, batch in enumerate( inputs ):
-            quant_level = quant_levels[i] if quant_levels is not None else None
+            quant_level = quant_levels[i]
+            prev_quant_level = 0 if quant_level == 0 else quant_level - 1

             it = 0
             for name, input in batch:
@@ -760,8 +766,8 @@ class Base(nn.Module):
                 if name == "resp":
                     input = input if input.dim() == 1 else input[:, quant_level]
                 # select prom level
-                elif name == "prom" and quant_level is not None:
-                    input = input[:, quant_level]
+                elif name == "prom":
+                    input = input[:, prev_quant_level]

                 seq_len = input.shape[0]

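The same indexing rule as in the target construction above, restated as a standalone helper (a hedged sketch, structure illustrative only): responses are sliced at the current RVQ level, prompts at the previous one, clamped at level 0.

    import torch

    def select_level(name: str, input: torch.Tensor, quant_level: int) -> torch.Tensor:
        # prompts condition on the previous codebook level (clamped at 0),
        # responses are predicted at the current level
        prev_quant_level = 0 if quant_level == 0 else quant_level - 1
        if name == "resp":
            return input if input.dim() == 1 else input[:, quant_level]
        if name == "prom":
            return input if input.dim() == 1 else input[:, prev_quant_level]
        return input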