overhauled how the right resp level / classifier gets picked to avoid cringemath

mrq 2024-11-13 13:31:17 -06:00
parent 269648605e
commit 910033343c
3 changed files with 102 additions and 122 deletions


@ -260,7 +260,7 @@ class ModelExperimentalSettings:
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
masking_separate_embeddings: bool = False
masking_separate_embeddings: bool = False # to-do: explain
# classifier-free guidance shit
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
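
The cfg_cond_dropout_p field above is the standard hook for classifier-free guidance: with some probability, the text and audio conditioning are dropped during training so the model also learns an unconditional path. A minimal sketch of that idea (the helper name and call site are illustrative, not this repo's API):

import random

def maybe_drop_conditioning(text_tokens, prom_tokens, cfg_cond_dropout_p=0.2):
	# with probability cfg_cond_dropout_p, drop both the text and the audio
	# prompt for this sample so an unconditional distribution is also learned
	if random.random() < cfg_cond_dropout_p:
		return None, None
	return text_tokens, prom_tokens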


@ -883,6 +883,10 @@ def example_usage():
import numpy as np
import re
cfg.model.experimental.masking_train_p = 0.5
cfg.hyperparameters.batch_size = 1
cfg.hyperparameters.gradient_accumulation_steps = 1
setup_logging()
@ -896,7 +900,6 @@ def example_usage():
text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
batch_size = cfg.hyperparameters.batch_size
cfg.model.experimental.masking_train_p = 1.0
text_list = [ text ] * batch_size
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size


@ -46,6 +46,9 @@ LossStats = namedtuple('LossStats', ['loss', 'stats'])
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
"""
summed_embeddings_task = [ "stt" ]
special_tasks = [ "len", "stt" ]
def _dropout_mask( input, p=None ):
# cosine scheduling
if p is None:
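
The hunk cuts off inside _dropout_mask; a self-contained sketch of the cosine-scheduled mask the comment describes (the exact schedule and sampling here are assumptions, not the repo's code):

import math
import random
import torch

def dropout_mask_sketch(input: torch.Tensor, p: float | None = None) -> torch.Tensor:
	# cosine scheduling: sample a timestep t in [0, 1) and map it to a masking
	# probability p = cos(t * pi / 2), then mask each position independently
	if p is None:
		t = random.random()
		p = math.cos(t * math.pi * 0.5)
	return torch.rand(input.shape[0], device=input.device) < p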
@ -182,78 +185,26 @@ class AudioEmbedding(nn.Module):
l_tokens: list[int], # list of token counts (needed because the AR resps include a stop token)
token_dim: int, # dimensionality of the embedding
sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
capabilities: list[str] | None = None, # helper shit
l_names: list[str] = [], # names to map to indices
):
super().__init__()
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
# + resps cannot share the AR and NAR embeddings, since they also encode whether to predict the same level at the next token (AR) or the next level in place (NAR)
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# further experimentation is needed to see if this actually is useful
self.sums = sums
#
self.names = l_names
self.external_mode = external_mode
self.capabilities = capabilities
# set initial weights to zero
if self.external_mode == "inclusive":
for i, embedding in enumerate(self.embeddings):
embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
def external_embeddings(self, input: Tensor, quant_level: int | None = None ) -> Tensor:
if quant_level is None:
quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
# for AR, trim any stop tokens
has_stop_token = False
# this block apparently doesn't work
"""
if quant_level == 0:
stop_token = self.embeddings[0].weight.shape[0] - 1
stop_token_indices = (input == stop_token).nonzero()
has_stop_token = len(stop_token_indices) > 0
if has_stop_token:
input = input[:stop_token_indices.min().item()]
"""
has_stop_token = False
if quant_level == 0:
stop_token = self.embeddings[0].weight.shape[0] - 1
has_stop_token = input[-1] == stop_token
if has_stop_token:
input = input[:-1]
# get external embedding
embedding = encode_as_embedding( input, quant_level, sums=self.sums ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
# resize if necessary (in case the external embeddings do not match our model dim)
embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
# reintroduce stop token
if has_stop_token:
stop_token = self.internal_forward( torch.tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
embedding = torch.concat( [ embedding, stop_token ] )
return embedding
def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
if offset is None:
# prom
if self.capabilities is None:
offset = 0
elif "nar" not in self.capabilities:
offset = 0
elif quant_level > 0:
offset = 1
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
if sums is None:
sums = self.sums
# handle mapping from name
if name in self.names:
offset = self.names.index( name )
if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
@ -265,17 +216,6 @@ class AudioEmbedding(nn.Module):
return x
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
x = self.internal_forward( xi, offset = offset, quant_level = quant_level, sums = sums ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
if self.external_mode and xi.shape[0] > 0:
external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
if self.external_mode == "exclusive":
return external_embeddings
x += external_embeddings
return x
# time-step embedding
# for the NAR-len, since it most likely requires encoding the timestep
class TimeEmbedding(nn.Module):
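
The name-based lookup added to AudioEmbedding.forward above boils down to mapping a level name such as 'AR:0:0' or 'NAR:1:2' onto an index into the ModuleList, instead of juggling offsets. A stripped-down sketch of just that mechanism (sizes and names here are hypothetical):

import torch
import torch.nn as nn

class NamedEmbeddings(nn.Module):
	def __init__(self, l_tokens: list[int], token_dim: int, l_names: list[str]):
		super().__init__()
		self.embeddings = nn.ModuleList([nn.Embedding(n, token_dim) for n in l_tokens])
		self.names = l_names

	def forward(self, xi: torch.Tensor, name: str) -> torch.Tensor:
		# the name picks the embedding table directly, no offset arithmetic
		return self.embeddings[self.names.index(name)](xi)

emb = NamedEmbeddings([1025, 1024], 32, ['AR:0:0', 'NAR:0:1'])
x = emb(torch.tensor([1, 2, 3]), name='NAR:0:1')  # -> shape [3, 32]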
@ -304,14 +244,32 @@ class Classifiers(nn.Module):
self,
l_tokens: list[int], # list of token counts (needed because the AR resps include a stop token)
token_dim: int, # dimensionality of the embedding
l_names: list[str] | None = None, # list of names to map to each classifier
):
super().__init__()
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
self.names = l_names
def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
def indices(
self,
names
):
if isinstance( names[-1], int ):
return names
return [ self.names.index(name) for name in names ]
def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None ) -> Tensor:
dtype = xi.dtype
device = xi.device
if levels and isinstance( levels[-1], str ):
names = levels
levels = []
# map names to levels
if names and not levels:
levels = [ self.names.index(name) for name in names ]
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
# pad if needed
# to-do: validate that this causes ZERO issues
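
End to end, the name-aware Classifiers resolves names (or raw indices) to per-level heads, projects, then pads the narrower vocabularies so the logits can be handled uniformly downstream. A sketch under assumed sizes and an assumed pad value:

import torch
import torch.nn as nn
import torch.nn.functional as F

names = ['AR:0:0', 'NAR:0:1', 'stt']
heads = nn.ModuleList([nn.Linear(16, n) for n in (1025, 1024, 256)])

def classify(xi: list[torch.Tensor], levels: list[int | str]) -> list[torch.Tensor]:
	# accept either names or raw indices, as the updated forward() does
	levels = [names.index(l) if isinstance(l, str) else l for l in levels]
	logits = [heads[l](x) for x, l in zip(xi, levels)]
	# pad the narrower heads up to the widest vocab; -inf keeps the padded
	# classes effectively out of any softmax (the pad value is an assumption)
	max_size = max(l.shape[-1] for l in logits)
	return [F.pad(l, (0, max_size - l.shape[-1]), value=float('-inf')) for l in logits]

logits = classify([torch.randn(5, 16), torch.randn(7, 16)], ['AR:0:0', 'stt'])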
@ -349,11 +307,11 @@ class Metrics(nn.Module):
ignore_index=ignore_index,
) for n_tokens in l_tokens ])
def calc_accuracy( self, inputs, targets, quant_levels ):
return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
def calc_accuracy( self, inputs, targets, classifier_levels ):
return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
def calc_precision( self, inputs, targets, quant_levels ):
return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
def calc_precision( self, inputs, targets, classifier_levels ):
return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
def __call__(self, *args, **kwargs):
return dict(
@ -486,21 +444,29 @@ class Base(nn.Module):
if "nar" not in self.capabilities:
n_resp_tokens = n_audio_tokens + 1
l_tokens = [n_resp_tokens] * self.n_resp_levels
resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
# NAR-len model
elif "len" in self.capabilities and masking_separate_embeddings:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
if masking_separate_embeddings:
l_tokens += [n_resp_tokens]
resp_l_names += ['NAR:0:0']
# AR+NAR model
else:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
classifier_l_names = resp_l_names + ["stt"]
self.unified_position_ids = unified_position_ids
self.interleave = interleave
self.layerskip = layerskip
self.special_tasks = [ "len", "stt" ]
self.inject_timestep_embedding = False # results in bad output
self.masking_separate_embeddings = masking_separate_embeddings
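
A worked example of how the token counts and names line up after this change, assuming 1024 audio tokens and 8 RVQ levels:

n_audio_tokens, n_resp_levels = 1024, 8
n_resp_tokens = n_audio_tokens + 1  # + stop/mask token (assuming causal_size > 0)

# AR+NAR (and NAR-len without separate masking embeddings)
l_tokens     = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range(n_resp_levels - 1)]

# NAR-len with masking_separate_embeddings: one extra table for the masked level 0
l_tokens     += [n_resp_tokens]
resp_l_names += ['NAR:0:0']

classifier_l_names = resp_l_names + ['stt']
assert len(l_tokens) == len(resp_l_names) == 9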
@ -534,14 +500,11 @@ class Base(nn.Module):
self.proms_emb = AudioEmbedding(
[n_audio_tokens] * self.n_resp_levels, d_model,
sums=audio_embedding_sums,
external_mode=audio_embedding_mode,
capabilities=None,
)
self.resps_emb = AudioEmbedding(
l_tokens, d_model,
sums=audio_embedding_sums,
external_mode=audio_embedding_mode,
capabilities=self.capabilities,
l_names=resp_l_names,
)
if self.version >= 3:
@ -842,7 +805,7 @@ class Base(nn.Module):
self.metrics = None
else:
self.classifier = None
self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model )
self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model, l_names=classifier_l_names )
self.accuracy_metric = None
self.precision_metric = None
self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
@ -1002,6 +965,7 @@ class Base(nn.Module):
quant_level = quant_levels[i] if quant_levels is not None else 0
task_type = task_list[i] if task_list is not None else "tts"
timestep = time_list[i] if time_list is not None else None
classifier_level = None
# insert task type as a string
inputs[i].append( ( "task", task_type ) )
@ -1012,7 +976,7 @@ class Base(nn.Module):
# Base-line TTS task
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
@ -1022,19 +986,22 @@ class Base(nn.Module):
# insert RVQ level guidance token if the model is versioned for it
if self.rvq_l_emb is not None and not self.interleave:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
classifier_level = "AR:0:0" if quant_level == 0 else f'NAR:{quant_level-1}:{quant_level}'
# insert input audio prompt
if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
# insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
"""
# insert timestep token
if timestep is not None:
# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
"""
# store timestep information
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
"""
"""
classifier_level = "NAR:0:0"
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
@ -1050,6 +1017,7 @@ class Base(nn.Module):
dropout_mask = _dropout_mask( resps_list[i], p )
inputs[i].append( ("dropout_mask", dropout_mask ) )
inputs[i].append( ("classifier_level", classifier_level) )
# Audio length prediction task
# Sequence: <text><sep><rvq lvl><prom><sep><len>
elif task_type == "len":
@ -1080,6 +1048,8 @@ class Base(nn.Module):
elif resps_list is not None and resps_list[i] is not None:
# yes this could be encoded better
inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
inputs[i].append( ("classifier_level", "stt") )
# Speech-to-Text prediction task
# Sequence: <resp><sep><rvq lvl><sep><text>
elif task_type == "stt":
@ -1095,6 +1065,8 @@ class Base(nn.Module):
# insert the output text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
inputs[i].append( ("classifier_level", "stt") )
else:
raise Exception(f'Unrecognized task: {task_type}')
return inputs
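
Pulling the scattered assignments together, the classifier_level tag each batch item ends up with follows this pattern (a condensed sketch mirroring the branches above; 'len' and 'stt' both point at the text head):

def pick_classifier_level(task_type: str, quant_level: int, is_masking: bool) -> str:
	if task_type in ("len", "stt"):
		return "stt"                    # special tasks share the text classifier
	if is_masking:
		return "NAR:0:0"                # NAR-len / masked-training path
	if quant_level == 0:
		return "AR:0:0"                 # causal level 0
	return f"NAR:{quant_level-1}:{quant_level}"  # NAR predicting the next level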
@ -1131,7 +1103,6 @@ class Base(nn.Module):
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [1, self.resp_levels]
summed_embeddings_task = [ "stt" ]
x_list = []
for batch_index, batch_input in enumerate(inputs):
@ -1140,6 +1111,7 @@ class Base(nn.Module):
task_type = "tts"
input_prom = None
classifier_level = None
dropout_mask = None
timestep = None
@ -1147,6 +1119,8 @@ class Base(nn.Module):
for name, input in batch_input:
if name == "dropout_mask":
dropout_mask = input
elif name == "classifier_level":
classifier_level = input
for name, input in batch_input:
# technically could provide a map of input_name => embedding, but some embeddings require additional processing
@ -1179,8 +1153,9 @@ class Base(nn.Module):
if self.interleave:
embeddings = [ self.resps_emb(
input[:, :l+1],
offset = 0,
quant_level = l
#offset = 0,
#quant_level = l,
name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
) for l in range( input.shape[-1] ) ]
embedding = _interleave_sequence_reshape( embeddings )
@ -1190,16 +1165,18 @@ class Base(nn.Module):
embedding = self.resps_emb(
# if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
offset = -1 if self.masking_separate_embeddings else 0, # pick last
quant_level = 0,
#offset = -1 if self.masking_separate_embeddings else 0, # pick last
#quant_level = 0,
name = classifier_level,
)
# cheat-y way to handle performing STT across all levels
elif task_type in summed_embeddings_task:
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
embedding = sum([ self.resps_emb(
input[:, :l+1],
offset = 0 if l == 0 else 1, # or maybe set to 1
quant_level = l,
#offset = 0 if l == 0 else 1, # or maybe set to 1
#quant_level = l,
name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
sums = False
) for l in range( input.shape[-1] - 1 ) ])
else:
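
In isolation, the torch.where call on the NAR-len path swaps every masked position for the stop/mask token before embedding, so the model learns to fill those positions in:

import torch

tokens       = torch.tensor([12, 7, 513, 90])
dropout_mask = torch.tensor([False, True, True, False])
mask_token   = 1024  # hypothetical id of the stop/mask token

masked_input = torch.where(dropout_mask, mask_token, tokens)
# masked_input -> tensor([  12, 1024, 1024,   90])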
@ -1210,6 +1187,7 @@ class Base(nn.Module):
quant_level
)
else:
"""
offset = 0
if "nar" not in self.capabilities:
offset = 0
@ -1221,6 +1199,13 @@ class Base(nn.Module):
offset = offset,
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
)
"""
embedding = self.resps_emb(
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
offset = 0 if classifier_level.startswith("AR:") else 1,
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
)
# apply token dropout
if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
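
The guard above only applies token dropout when the current quant level falls inside token_dropout_rvq_levels; a sketch of the whole operation (the dropout-token id and replacement strategy here are assumptions):

import torch

def apply_token_dropout(tokens, quant_level, rate=0.05, rvq_levels=(1, 8), dropout_token=1023):
	if rate <= 0.0 or not (rvq_levels[0] <= quant_level <= rvq_levels[1]):
		return tokens
	drop = torch.rand(tokens.shape[0], device=tokens.device) < rate
	return torch.where(drop, dropout_token, tokens)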
@ -1283,16 +1268,14 @@ class Base(nn.Module):
# there's a better way
if not self.unified_position_ids:
x_list = []
non_tokens = ["task", "dropout_mask", "classifier_level"]
last_input = ["resp", "len"]
def get_input_token_length( name, input ):
# task token
if isinstance(input, str):
return 1
# a mask
if name in ["dropout_mask"]:
return 0
# list of tokens
if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
@ -1302,12 +1285,12 @@ class Base(nn.Module):
return input.shape[0] * input.shape[1]
# ending input will not have a separator later
return input.shape[0] + (0 if name in ["resp", "len"] else 1)
return input.shape[0] + (0 if name in last_input else 1)
for batch_index, batch_input in enumerate(inputs):
batch = torch.cat( [
torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
for name, input in batch_input if name != "task"
for name, input in batch_input if name not in non_tokens
] )
delta = ids[batch_index].shape[0] - batch.shape[0]
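
A toy version of what inputs_to_position_ids produces when unified_position_ids is disabled, following get_input_token_length above: positions restart per segment, with +1 for the separator on everything except the trailing resp/len segment:

import torch

segment_lengths = [4, 6, 3]  # e.g. text, prom, resp token counts

position_ids = torch.cat([
	torch.arange(n + (0 if i == len(segment_lengths) - 1 else 1))
	for i, n in enumerate(segment_lengths)
])
# -> [0..4, 0..6, 0..2]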
@ -1325,17 +1308,14 @@ class Base(nn.Module):
inputs: list,
logits,
quant_levels: int | list[int] | Tensor | None = None,
quant_levels: list[int] | None = None,
):
loss = dict(ce = dict())
stats = dict(acc = dict())
device = logits[0].device
batch_size = len(logits)
summed_embeddings_task = [ "stt" ]
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
classifier_levels = self.get_input( inputs, "classifier_level" )
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
@ -1369,7 +1349,7 @@ class Base(nn.Module):
if name == "task":
task_type = input
task_list.append( input )
if task_type in ["len", "stt"]:
if task_type in special_tasks:
causal = True
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
@ -1423,7 +1403,7 @@ class Base(nn.Module):
loss = dict(
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
stats = self.metrics( inputs, targets, classifier_levels ) if self.metrics is not None else dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
@ -1432,7 +1412,7 @@ class Base(nn.Module):
loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
stats = self.metrics( logits, target_list, self.classifiers.indices( classifier_levels ) ) if self.metrics is not None else dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
)
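
The batch loss above has to be a per-sequence cross entropy, summed then divided by the batch size, because each sequence may have been routed through a different classifier head (and therefore a different vocab size). A standalone sketch with made-up shapes:

import torch
import torch.nn.functional as F

logits  = [torch.randn(5, 1025), torch.randn(7, 1024)]            # ragged vocab sizes
targets = [torch.randint(0, 1025, (5,)), torch.randint(0, 1024, (7,))]
ignore_index = -100

nll = sum(
	F.cross_entropy(l, t, ignore_index=ignore_index) for t, l in zip(targets, logits)
) / len(logits)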
@ -1466,7 +1446,7 @@ class Base(nn.Module):
# meta-input, no corresponding token at the moment
if name == "task":
task_name = input
if task_type in ["len", "stt"]:
if task_type in special_tasks:
causal = True
continue
# do not use resp as-is
@ -1529,7 +1509,7 @@ class Base(nn.Module):
else:
loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
if self.metrics is not None:
metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
metrics = self.metrics( batch["logits"], batch["targets"], self.classifiers.indices( classifier_levels ) )
stats["acc"][name] = metrics["acc"]
else:
stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
@ -1540,7 +1520,7 @@ class Base(nn.Module):
self,
inputs: list,
quant_levels: int | list[int] | Tensor | None = None,
quant_levels: list[int] | None = None,
state: dict | list | None = None,
layer_skip_variables: dict | None = None,
@ -1549,7 +1529,7 @@ class Base(nn.Module):
output_hidden_states: bool = False,
):
# return early if it's "good" enough
# lambda because we need to capture the classifier_quant_levels and mask
# lambda because we need to capture the classifier_levels and mask
exited_layer = self.n_layers
def layer_skip_lambda( layer, logits ):
nonlocal exited_layer
@ -1576,7 +1556,7 @@ class Base(nn.Module):
if self.classifier is not None:
x = self.classifier(x) # * m
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_quant_levels) # * m
logits = self.classifiers(logits, levels = classifier_levels) # * m
# calculate metrics
metrics = calculate_entropix_metrics( logits )
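
The actual early-exit criterion comes from calculate_entropix_metrics, which this diff does not show; purely as an illustration of the layer-skip idea, an entropy threshold on the current layer's logits might look like this:

import torch
import torch.nn.functional as F

def should_exit_early(logits: torch.Tensor, threshold: float = 0.1) -> bool:
	# low entropy on the final position means this layer is already confident
	probs = F.softmax(logits[-1, :], dim=-1)
	entropy = -(probs * torch.log(probs + 1e-9)).sum()
	return entropy.item() < threshold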
@ -1628,10 +1608,7 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs
#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
classifier_levels = self.get_input( inputs, name="classifier_level" )
if self.inject_timestep_embedding:
timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
@ -1664,11 +1641,11 @@ class Base(nn.Module):
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_quant_levels) # * m
logits = self.classifiers(logits, levels = classifier_levels) # * m
if hidden_states is not None:
for i, state in enumerate( hidden_states ):
hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) # * m
hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_levels) # * m
# Remove padding
logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
@ -1716,7 +1693,7 @@ class Base(nn.Module):
self,
logits: list[Tensor], # logit scores
prev_list: list[Tensor] | None = None, # previous tokens
quant_levels: int | list[int] | Tensor | None = None, # to-do: derive this from the prev_list
quant_levels: list[int] | None = None, # to-do: derive this from the prev_list
**sampling_kwargs,
):
# yikes
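
For the "derive this from the prev_list" to-do: the convention used elsewhere in this file (e.g. AudioEmbedding.forward) infers the quant level from the tensor shape, so the derivation could plausibly be:

import torch

def derive_quant_levels(prev_list: list[torch.Tensor]) -> list[int]:
	return [0 if prev.dim() == 1 else prev.shape[-1] - 1 for prev in prev_list]

derive_quant_levels([torch.zeros(10, dtype=torch.int64), torch.zeros(10, 3, dtype=torch.int64)])
# -> [0, 2]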