fixes, again

mrq 2024-09-06 11:41:41 -05:00
parent 94cf81d38c
commit 341e19162b
3 changed files with 5 additions and 8 deletions

@@ -1055,8 +1055,6 @@ class Dataset(_Dataset):
# Base STT (<resp> => <text>)
elif task == "stt":
# easier to just keep it instead of wrangling around trying to remove it
# it might also help to provide a guidance prompt but who knows right now
proms = [
task
]
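
This hunk leaves the bare task string in the prompt list rather than stripping it out. A minimal sketch (not part of the diff) of why that is harmless, mirroring the isinstance filtering seen in the other hunks of this commit:

    import torch

    task = "stt"
    proms = [ task ]   # for STT the prompt list only carries the task marker

    # downstream passes over the prompt list skip non-tensor entries,
    # so the task string can stay put without extra bookkeeping
    audio_proms = [ p for p in proms if isinstance(p, torch.Tensor) ]
    assert audio_proms == []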

@@ -135,9 +135,8 @@ class AR_NAR(Base):
for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ):
continue
- if quant_level >= prom.shape[-1]:
- quant_levels[i] = prom.shape[-1] - 1
+ if quant_level >= prom.shape[-1]:
+ quant_levels[i] = prom.shape[-1] - 1
# apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
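
The clamp in this hunk lowers the sampled quantizer level whenever a prompt tensor carries fewer RVQ levels than requested, so later indexing into the prompt's last dimension stays in bounds. A minimal sketch (not part of the diff) of the clamping rule in isolation, with made-up shapes:

    import torch

    def clamp_quant_level( quant_level: int, proms: list ) -> int:
        # a prompt shaped (seq_len, n_rvq_levels) only supports levels 0 .. n_rvq_levels - 1
        for prom in proms:
            if not isinstance( prom, torch.Tensor ):
                continue
            if quant_level >= prom.shape[-1]:
                quant_level = prom.shape[-1] - 1
        return quant_level

    # a 2-level prompt forces a requested level of 7 down to 1
    assert clamp_quant_level( 7, [ "stt", torch.zeros( 100, 2, dtype=torch.long ) ] ) == 1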

@@ -280,7 +280,7 @@ class Classifiers(nn.Module):
xi = [
#x if l == 0 else
x if x.shape[-1] == max_size else
- torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
+ torch.cat( [x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
for x, l in zip(xi, levels)
]
return torch.stack( xi )
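
The padding here widens every per-level logit tensor to the largest classifier head so they can be stacked into one tensor; filling with -inf means the padded vocabulary slots get exactly zero probability after softmax. A minimal sketch (not part of the diff) with made-up sizes:

    import torch

    logits = [ torch.randn( 10, 1025 ), torch.randn( 10, 1024 ) ]   # different head sizes per level
    max_size = max( x.shape[-1] for x in logits )

    padded = [
        x if x.shape[-1] == max_size else
        torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf") ) ], dim=-1 )
        for x in logits
    ]
    stacked = torch.stack( padded )   # (levels, seq, max_size)

    # the -inf slots can never be sampled
    assert torch.all( stacked.softmax( dim=-1 )[1, :, 1024] == 0 )
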
@@ -1057,7 +1057,7 @@ class Base(nn.Module):
embedding = self.langs_emb( input )
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
- input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ])
+ input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
elif name == "tone" and self.tones_emb is not None:
@@ -1164,7 +1164,7 @@ class Base(nn.Module):
# list of tokens
if not isinstance(input, torch.Tensor):
- return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1
+ return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
# interleaved model
if self.interleave and name == "resp":
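
torch.tensor is a factory function, not a type, so the old isinstance check raises a TypeError the moment a list-of-tokens input reaches this branch; torch.Tensor is the class the check actually needs. A minimal sketch (not part of the diff) with a toy token list:

    import torch

    input = [ torch.zeros( 4, dtype=torch.long ), "stt", torch.zeros( 2, dtype=torch.long ) ]

    try:
        sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] )   # function, not a type
    except TypeError as e:
        print( "old expression raises:", e )

    # torch.Tensor is the class, so the tensors are counted; the + 1 mirrors the original expression
    assert sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1 == 7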