fixes, again

This commit is contained in:
mrq 2024-09-06 11:41:41 -05:00
parent 94cf81d38c
commit 341e19162b
3 changed files with 5 additions and 8 deletions

View File

@@ -1055,8 +1055,6 @@ class Dataset(_Dataset):
# Base STT (<resp> => <text>) # Base STT (<resp> => <text>)
elif task == "stt": elif task == "stt":
# easier to just keep it instead of wrangling around trying to remove it
# it might also help to provide a guidance prompt but who knows right now
proms = [ proms = [
task task
] ]

View File

@@ -135,9 +135,8 @@ class AR_NAR(Base):
for j, prom in enumerate( proms ): for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ): if not isinstance( prom, torch.Tensor ):
continue continue
if quant_level >= prom.shape[-1]:
if quant_level >= prom.shape[-1]: quant_levels[i] = prom.shape[-1] - 1
quant_levels[i] = prom.shape[-1] - 1
# apply token dropout error compensation # apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):

View File

@@ -280,7 +280,7 @@ class Classifiers(nn.Module):
xi = [ xi = [
#x if l == 0 else #x if l == 0 else
x if x.shape[-1] == max_size else x if x.shape[-1] == max_size else
torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 ) torch.cat( [x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
for x, l in zip(xi, levels) for x, l in zip(xi, levels)
] ]
return torch.stack( xi ) return torch.stack( xi )
@@ -1057,7 +1057,7 @@ class Base(nn.Module):
embedding = self.langs_emb( input ) embedding = self.langs_emb( input )
elif name == "prom": elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input proms = [ input ] if isinstance(input, torch.Tensor) else input
input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ]) input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] ) embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
elif name == "tone" and self.tones_emb is not None: elif name == "tone" and self.tones_emb is not None:
@@ -1164,7 +1164,7 @@ class Base(nn.Module):
# list of tokens # list of tokens
if not isinstance(input, torch.Tensor): if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1 return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
# interleaved model # interleaved model
if self.interleave and name == "resp": if self.interleave and name == "resp":