fixed some logic errors with training (grabbing wrong quant level...)
parent eafa622be2 · commit b0158a61d5
@@ -664,10 +664,12 @@ class Base(nn.Module):
             elif name == "lang" and self.langs_emb is not None:
                 embedding = self.langs_emb( input )
             elif name == "prom":
+                # get RVQ level 0, or up to targetted RVQ level inference
                 embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None )
             elif name == "tone" and self.tones_emb is not None:
                 embedding = self.tones_emb( input )
             elif name == "resp":
+                # get RVQ level 0, or up to targetted RVQ level inference
                 embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
             else:
                 continue
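
The two added comments annotate the same selection rule in both branches. A minimal sketch of that rule, assuming codes is a [seq_len, n_levels] tensor of RVQ codebook indices (select_codes and its names are illustrative, not the repo's API):

import torch

def select_codes(codes: torch.Tensor, quant_level: int) -> torch.Tensor:
    # RVQ level 0 (or an already-flat 1D sequence) is embedded as-is;
    # higher levels embed every level below the target: codes[:, :quant_level]
    if codes.dim() == 1 or quant_level == 0:
        return codes
    return codes[:, :quant_level]
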
@@ -688,7 +690,10 @@ class Base(nn.Module):
         # old, "naive" way, no loss factoring
         if not self.config.loss_factors:
             target_list = []
+
             for batch_index, batch in enumerate(inputs):
+                quant_level = quant_levels[batch_index]
+                prev_quant_level = 0 if quant_level == 0 else quant_level - 1
                 target = []
                 for name, input in batch:
                     if name == "prom":
@@ -697,9 +702,9 @@ class Base(nn.Module):
                             target.append( torch.full_like(input[..., 0], self.ignore_index) )
                         # we *CAN* directly map to proms
                         else:
-                            target.append( input if input.dim() == 1 else input[:, quant_level-1] )
+                            target.append( input if input.dim() == 1 else input[:, prev_quant_level] )
                     elif name == "resp":
-                        target.append( input if input.dim() == 1 else input[:, quant_level-1] )
+                        target.append( input if input.dim() == 1 else input[:, quant_level] )
                     elif name in ["text", "quant_level", "lang", "tone"]:
                         target.append( input )
 
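
This hunk is the core of the fix. The old code indexed both prompt and response targets with quant_level-1, so response targets were off by one level everywhere, and at quant_level == 0 Python's negative indexing silently grabbed the last RVQ level instead of the first. A hedged sketch of the corrected indexing (pick_target and its names are illustrative, not the repo's API):

import torch

def pick_target(name: str, input: torch.Tensor, quant_level: int) -> torch.Tensor:
    # clamp to level 0 so quant_level == 0 no longer wraps around to input[:, -1]
    prev_quant_level = 0 if quant_level == 0 else quant_level - 1
    if input.dim() == 1:
        return input
    if name == "prom":
        return input[:, prev_quant_level]  # prompt: previous RVQ level
    return input[:, quant_level]           # resp: current RVQ level
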
@@ -729,7 +734,7 @@ class Base(nn.Module):
             )
         else:
             self.loss = dict(
-                nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
+                nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
             )
             self.stats = dict(
                 acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
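
Here the stray * loss_factor multiplier is dropped from the branch that does no loss factoring, leaving a plain cross-entropy averaged over the batch. A sketch of the resulting reduction, assuming each logits[i] is [seq_len, n_classes] and each target_list[i] is [seq_len] (naive_nll is an illustrative name):

import torch
import torch.nn.functional as F

def naive_nll(logits, target_list, ignore_index=-100):
    # unweighted mean of per-sequence cross-entropy; positions marked
    # with ignore_index contribute nothing to the loss
    batch_size = len(logits)
    return sum(
        F.cross_entropy(logit, target, ignore_index=ignore_index)
        for target, logit in zip(target_list, logits)
    ) / batch_size
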
@@ -752,7 +757,8 @@ class Base(nn.Module):
         batch_size = len( inputs )
 
         for i, batch in enumerate( inputs ):
-            quant_level = quant_levels[i] if quant_levels is not None else None
+            quant_level = quant_levels[i]
+            prev_quant_level = 0 if quant_level == 0 else quant_level - 1
 
             it = 0
             for name, input in batch:
@@ -760,8 +766,8 @@ class Base(nn.Module):
                 if name == "resp":
                     input = input if input.dim() == 1 else input[:, quant_level]
                 # select prom level
-                elif name == "prom" and quant_level is not None:
-                    input = input[:, quant_level]
+                elif name == "prom":
+                    input = input[:, prev_quant_level]
 
                 seq_len = input.shape[0]
 
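
These last two hunks make the stats/accuracy path mirror the training targets: quant_levels is now assumed to always be provided, responses are read at the current level, and prompts at the previous one. A small illustrative helper (not in the repo) capturing that selection:

import torch

def select_for_stats(name: str, input: torch.Tensor, quant_level: int) -> torch.Tensor:
    prev_quant_level = 0 if quant_level == 0 else quant_level - 1
    if name == "resp":
        return input if input.dim() == 1 else input[:, quant_level]
    if name == "prom":
        return input[:, prev_quant_level]  # same previous-level rule as the loss targets
    return input
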