sanity cleanups with weird off-by-one-ness, cleaned up and validated vall_e.models.experimental works again

This commit is contained in:
mrq 2024-07-27 15:36:05 -05:00
parent 06e948aec1
commit ce8bb1e4f7
6 changed files with 398 additions and 305 deletions

View File

@ -34,6 +34,9 @@ _logger = logging.getLogger(__name__)
# fold into a typical LLM sequence (one embedding rather than split embeddings) # fold into a typical LLM sequence (one embedding rather than split embeddings)
def fold_inputs( def fold_inputs(
text_list = [], text_list = [],
lang_list = [],
task_list = [],
tone_list = [],
prom_list = [], prom_list = [],
resp_list = [], resp_list = [],
targ_list = [], targ_list = [],
@ -42,12 +45,13 @@ def fold_inputs(
sep = 3, sep = 3,
stop = 3, stop = 3,
config = None,
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.max_levels,
quant_levels = None, quant_levels = None,
): ):
if config is None:
config = cfg.model
def _create_mask(l, device): def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1) stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
@ -61,53 +65,124 @@ def fold_inputs(
m = m.to(x) m = m.to(x)
return x, m return x, m
def process_prom_or_task(i, prom):
if prom is None:
return
if isinstance(prom, str):
task = get_task_symmap()[f'<{input}>']
seq = torch.Tensor([task_start + task]).to(device=device, dtype=dtype)
input_ids[i].append( seq )
input_ids[i].append( sep )
return
# deinterleaved
if quant_levels is not None:
quant_level = quant_levels[i]
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to(device=device, dtype=dtype)
else:
seq = prom[:, quant_level].to(device=device, dtype=dtype).clone()
for idx, token in enumerate( seq ):
token += prom_start + ( config.audio_tokens * quant_level )
# interleaved
else:
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to(device=device, dtype=dtype)
else:
seq = prom.flatten().to(device=device, dtype=dtype)
for idx, token in enumerate( seq ):
token += prom_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
input_ids[i].append( seq )
input_ids[i].append( sep )
"""
if quant_levels is not None:
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
"""
device = text_list[0].device device = text_list[0].device
dtype = torch.int64
batch_size = len(text_list) batch_size = len(text_list)
input_ids = [ [] for _ in range(batch_size) ] input_ids = [ [] for _ in range(batch_size) ]
offset = 0 offset = 0
sep = torch.Tensor([ sep ]) sep = torch.Tensor([ sep ]).to(device=device, dtype=dtype)
stop = torch.Tensor([ stop ]) stop = torch.Tensor([ stop ]).to(device=device, dtype=dtype)
text_start = 0
text_end = text_start + config.text_tokens
lang_start = text_end
lang_end = lang_start + config.langs
rvq_start = lang_end
rvq_end = rvq_start + config.resp_levels
prom_start = rvq_end
prom_end = prom_start + config.audio_tokens * config.resp_levels
task_start = prom_end
task_end = task_start + config.tasks
tone_start = task_end
tone_end = tone_start + config.tones
resp_start = tone_end
resp_end = resp_start + config.audio_tokens * config.resp_levels
# text tokens
for i, text in enumerate(text_list): for i, text in enumerate(text_list):
seq = text.to("cpu", dtype=torch.int64) if isinstance(text, torch.Tensor):
seq = text + text_start
else:
seq = torch.Tensor([text_start + text]).to(device=device, dtype=dtype)
input_ids[i].append( seq )
input_ids[i].append( sep )
# lang tokens
for i, lang in enumerate(lang_list):
if isinstance(lang, torch.Tensor):
seq = lang + lang_start
else:
seq = torch.Tensor([lang_start + lang]).to(device=device, dtype=dtype)
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
offset = text_tokens
# inject target quant_level # inject target quant_level
if quant_levels is not None: if quant_levels is not None:
for i, rvq in enumerate( quant_levels ): for i, rvq in enumerate( quant_levels ):
seq = torch.Tensor([offset + rvq]).to("cpu", dtype=torch.int64) if isinstance(rvq, torch.Tensor):
seq = rvq + rvq_start
else:
seq = torch.Tensor([rvq_start + rvq]).to(device=device, dtype=dtype)
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
offset = text_tokens + audio_rvq_levels # prom / task tokens
for i, prom in enumerate(prom_list): for i, prom in enumerate(prom_list):
# deinterleaved # list of proms with a possible task token
if quant_levels is not None: if isinstance(prom, list):
quant_level = quant_levels[i] for p in prom:
if ignore_index is not None: process_prom_or_task(i, p)
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64) # raw tensor
else:
seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * quant_level )
# interleaved
else: else:
if ignore_index is not None: process_prom_or_task(i, prom)
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
else:
seq = prom.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
# tone tokens
for i, tone in enumerate(tone_list):
if isinstance(tone, torch.Tensor):
seq = tone + tone_start
else:
seq = torch.Tensor([tone_start + tone]).to(device=device, dtype=dtype)
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( sep ) input_ids[i].append( sep )
offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) # resp tokens
for i, resp in enumerate(resp_list): for i, resp in enumerate(resp_list):
# deinterleaved # deinterleaved
if quant_levels is not None: if quant_levels is not None:
@ -115,54 +190,51 @@ def fold_inputs(
quant_level = quant_levels[i] - 1 quant_level = quant_levels[i] - 1
# way to signal we want to inference for rvq level 0 # way to signal we want to inference for rvq level 0
# without it, it's a random chance for any level to be selected again # without it, it's a random chance for any level to be selected again
if quant_level < 0: if quant_level < 0:
continue continue
seq = sep
else: else:
# my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples # my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
if isinstance(resp, list): if isinstance(resp, list):
seq = resp[quant_level].to("cpu", dtype=torch.int64) seq = resp[quant_level].to(device=device, dtype=dtype).clone()
else: else:
seq = resp[:, quant_level].to("cpu", dtype=torch.int64) seq = resp[:, quant_level].to(device=device, dtype=dtype).clone()
for idx, token in enumerate( seq ): for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * quant_level ) token += resp_start + ( config.audio_tokens * quant_level )
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( stop ) input_ids[i].append( stop )
# interleaved # interleaved
else: else:
seq = resp.flatten().to("cpu", dtype=torch.int64) seq = resp.flatten().to(device=device, dtype=dtype)
for idx, token in enumerate( seq ): for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) token += resp_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( stop ) input_ids[i].append( stop )
# targ list
for i, resp in enumerate(targ_list): for i, resp in enumerate(targ_list):
# deinterleaved # deinterleaved
if quant_levels is not None: if quant_levels is not None:
quant_level = quant_levels[i] quant_level = quant_levels[i]
seq = resp[:, quant_level].to("cpu", dtype=torch.int64) seq = resp[:, quant_level].to(device=device, dtype=dtype)
for idx, token in enumerate( seq ): for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * quant_level ) token += resp_start + ( config.audio_tokens * quant_level )
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( stop ) input_ids[i].append( stop )
# interleaved # interleaved
else: else:
seq = resp.flatten().to("cpu", dtype=torch.int64) seq = resp.flatten().to(device=device, dtype=dtype)
for idx, token in enumerate( seq ): for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) token += resp_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
input_ids[i].append( seq ) input_ids[i].append( seq )
input_ids[i].append( stop ) input_ids[i].append( stop )
for i, batch in enumerate(input_ids): for i, batch in enumerate(input_ids):
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64) input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype)
return list_to_tensor(input_ids) return list_to_tensor(input_ids)
@ -174,20 +246,62 @@ def unfold_outputs(
sep = 3, sep = 3,
stop = 3, stop = 3,
text_tokens = 256, config = None,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.max_levels,
quant_levels = None, quant_levels = None,
): ):
def bin_to_rvqs( tokens ):
length = len(tokens)
"""
if length % config.resp_levels == 0:
tokens = torch.Tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t()
"""
bins = [ [] for _ in range(config.resp_levels) ]
for pos in range( length ):
rvq = pos % config.resp_levels
bins[rvq].append( tokens[pos] )
nearest = ( len(bins) // config.resp_levels ) * config.resp_levels
bins = bins[:nearest]
return torch.Tensor(bins).t().to(device=device, dtype=dtype)
if config is None:
config = cfg.model
device = output_ids.device device = output_ids.device
dtype = torch.int64
batch_size = output_ids.shape[0] batch_size = output_ids.shape[0]
text_list = [ [] for _ in range(batch_size) ] text_list = [ [] for _ in range(batch_size) ]
rvq_list = [ [] for _ in range(batch_size) ]
lang_list = [ [] for _ in range(batch_size) ]
task_list = [ [] for _ in range(batch_size) ]
tone_list = [ [] for _ in range(batch_size) ]
prom_list = [ [] for _ in range(batch_size) ] prom_list = [ [] for _ in range(batch_size) ]
resp_list = [ [] for _ in range(batch_size) ] resp_list = [ [] for _ in range(batch_size) ]
text_start = 0
text_end = text_start + config.text_tokens
lang_start = text_end
lang_end = lang_start + config.langs
rvq_start = lang_end
rvq_end = rvq_start + config.resp_levels
prom_start = rvq_end
prom_end = prom_start + config.audio_tokens * config.resp_levels
task_start = prom_end
task_end = task_start + config.tasks
tone_start = task_end
tone_end = tone_start + config.tones
resp_start = tone_end
resp_end = resp_start + config.audio_tokens * config.resp_levels
for i, batch in enumerate( output_ids ): for i, batch in enumerate( output_ids ):
# crigne logic to handle prefix resp for rvq levels > 0 # cringe logic to handle prefix resp for rvq levels > 0
# a better way is to observe if the rvq level increased # a better way is to observe if the rvq level increased
should_flush = False should_flush = False
flushed = False flushed = False
@ -201,49 +315,51 @@ def unfold_outputs(
continue continue
if 0 <= id and id < text_tokens: # text tokens
text_list[i].append( id ) if text_start <= id and id < text_end:
elif text_tokens + audio_rvq_levels <= id and id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels): text_list[i].append( (id - text_start) % config.text_tokens )
prom_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens ) # lang tokens
elif text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) <= id: elif lang_start <= id and id < lang_end:
resp_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens ) lang_list[i].append( (id - lang_start) % config.langs )
# rvq levels
elif rvq_start <= id and id < rvq_end:
rvq_list[i].append( (id - rvq_start) % config.resp_levels )
# prom tokens
elif prom_start <= id and id < prom_end:
prom_list[i].append( (id - prom_start) % config.audio_tokens )
# task tokens
elif task_start <= id and id < task_end:
task_list[i].append( (id - task_start) % config.tasks )
# lang tokens
elif tone_start <= id and id < tone_end:
tone_list[i].append( (id - tone_start) % config.tones )
# resp tokens
elif resp_start <= id and id < resp_end:
resp_list[i].append( (id - resp_start) % config.audio_tokens )
if not flushed: if not flushed:
should_flush = True should_flush = True
if quant_levels is not None: if quant_levels is not None:
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64) prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=dtype)
resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64) resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=dtype)
else: else:
prom_len = len(prom_list[i]) prom_list[i] = bin_to_rvqs( prom_list[i] )
if prom_len % audio_rvq_levels == 0 and False: resp_list[i] = bin_to_rvqs( resp_list[i] )
prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( prom_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( prom_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
resp_len = len(resp_list[i]) text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=dtype)
if len(resp_list[i]) % audio_rvq_levels == 0 and False: task_list[i] = torch.Tensor( task_list[i] ).to(device=device, dtype=dtype)
resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t() lang_list[i] = torch.Tensor( lang_list[i] ).to(device=device, dtype=dtype)
else: tone_list[i] = torch.Tensor( tone_list[i] ).to(device=device, dtype=dtype)
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( resp_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( resp_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)
return dict( return dict(
text_list=text_list, text_list=text_list,
prom_list=prom_list, prom_list=prom_list,
resp_list=resp_list resp_list=resp_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
) )
# to-do: clean up this symmap mess # to-do: clean up this symmap mess
@ -1072,7 +1188,7 @@ def _create_dataloader(dataset, training):
shuffle=False, shuffle=False,
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size, batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
drop_last=training, drop_last=training,
sampler=dataset.sampler, sampler=dataset.sampler if training else None,
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict( ) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
batch_sampler=dataset.sampler, batch_sampler=dataset.sampler,
) )

View File

@ -137,50 +137,35 @@ class AR_NAR(Base):
# is training # is training
if training: if training:
# specifies how to sample probabilities of which RVQ levels to train against
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal" p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
# determines which RVQ level to target per batch # determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ] quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
token_dropout_error = self.config.experimental.token_dropout_error token_dropout_error = self.config.experimental.token_dropout_error
# RVQ levels to apply token dropout on
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# implicitly set it to all levels
if not token_dropout_rvq_levels: if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels] token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
if p_rvq_levels == "equal": p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
if not p_rvq_levels:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] if p_rvq_levels == "equal":
else: # if p_rvq_levels == "auto": p_rvq_levels = [ i for i in range( lo, hi ) ]
# makes higher levels less likely else:
""" # yuck
def generate( lo=0, hi=8 ): p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
index = lo
p = random.random()
for i in range(lo, hi):
if p < 1.0 / (2 ** i):
index = i
return int(index)
"""
# allow passing a specific distribution of RVQ levels
pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
if not pool:
lo, hi = quant_level_range[0], quant_level_range[1]
for i in range( lo, hi ):
rep = hi - i
pool += [i] * rep
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
# these two are techinically equivalent if the audio embeddings handle things properly
"""
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
stop_sequence = torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16)
"""
# input RVQ levels
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
# trim resps to only contain all levels below the target level
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)] resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
# tensor to cat for RVQ level 0
stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16) stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list): for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
# cap quant_level if it exceeds its corresponding resp/prom # cap quant_level if it exceeds its corresponding resp/prom
if quant_level >= resps.shape[-1]: if quant_level >= resps.shape[-1]:
@ -213,7 +198,6 @@ class AR_NAR(Base):
# only apply stop token for RVQ level 0 # only apply stop token for RVQ level 0
if quant_level <= 0: if quant_level <= 0:
# append stop tokens for AR # append stop tokens for AR
# could technically do it in the .inputs call
resps_list[i] = torch.cat([ resps, stop_sequence ]) resps_list[i] = torch.cat([ resps, stop_sequence ])

View File

@ -57,7 +57,30 @@ class Model(LlmArchClass):
hf_attention = config.attention if config is not None else None hf_attention = config.attention if config is not None else None
gradient_checkpointing = config.gradient_checkpointing if config is not None else True gradient_checkpointing = config.gradient_checkpointing if config is not None else True
# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop # text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1 # vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
text_start = 0
text_end = text_start + config.text_tokens
lang_start = text_end
lang_end = lang_start + config.langs
rvq_start = lang_end
rvq_end = rvq_start + config.resp_levels
prom_start = rvq_end
prom_end = prom_start + config.audio_tokens * config.resp_levels
task_start = prom_end
task_end = task_start + config.tasks
tone_start = task_end
tone_end = tone_start + config.tones
resp_start = tone_end
resp_end = resp_start + config.audio_tokens * config.resp_levels
vocab_size = resp_end
if cfg.model.arch_type == "llama": if cfg.model.arch_type == "llama":
super().__init__(config=LlamaConfig( super().__init__(config=LlamaConfig(
@ -148,11 +171,94 @@ class Model(LlmArchClass):
*args, *args,
**kwargs, **kwargs,
): ):
if cfg.model.arch_type in ["mamba","mamba2"]: config = self.hyper_config
if "text_list" in kwargs:
text_list = kwargs.pop("text_list", None)
proms_list = kwargs.pop("proms_list", None)
resps_list = kwargs.pop("resps_list", None)
lang_list = kwargs.pop("lang_list", None)
tone_list = kwargs.pop("tone_list", None)
training = kwargs.pop("training", False)
steps = kwargs.pop("steps", 500)
batch_size = len(text_list)
if training:
quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]
input_ids, attention_mask = fold_inputs(
text_list=text_list,
prom_list=proms_list,
resp_list=resps_list,
targ_list=resps_list,
quant_levels=quant_levels,
)
target_ids, target_attention_mask = fold_inputs(
text_list=text_list,
prom_list=proms_list,
resp_list=resps_list,
targ_list=resps_list,
quant_levels=quant_levels,
ignore_index=-100
)
return self.forward(
input_ids=input_ids,
labels=target_ids,
)
if config.experimental.interleave:
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False)
return unfold_outputs( output )["resp_list"]
resps_list = [ [] for _ in range(batch_size) ]
for l in range(config.max_levels):
quant_levels = [ l for _ in range(batch_size) ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )
output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length,
max_length=min_length+steps*2,
eos_token_id=3,
do_sample=False
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
if l == 0:
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1]
# store length
if l == 0:
steps = max( steps, length )
# pad
else:
resp = resp[:steps]
if length < steps:
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
resps_list[batch].append( resp )
for i, resp in enumerate( resps_list ):
resps_list[i] = torch.stack( resp ).t()
return resps_list
if config.arch_type in ["mamba","mamba2"]:
if "attention_mask" in kwargs: if "attention_mask" in kwargs:
kwargs.pop("attention_mask") kwargs.pop("attention_mask")
labels = kwargs.pop("labels") if "labels" in kwargs else None labels = kwargs.pop("labels", None)
output = super().forward(*args, **kwargs) output = super().forward(*args, **kwargs)
logits = output.logits logits = output.logits
@ -322,53 +428,8 @@ def example_usage():
@torch.inference_mode() @torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ): def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
engine.eval() engine.eval()
batch_size = len(text_list)
resp_list = None
if cfg.model.experimental.interleave:
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
unfolded = unfold_outputs( output ) resp_list = model( text_list=text_list, proms_list=prom_list )
resp_list = unfolded["resp_list"]
else:
resp_list = [ [] for _ in range(batch_size) ]
for l in range(cfg.model.max_levels):
quant_levels = [ l for _ in range(batch_size) ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length,
max_length=min_length+steps*2,
eos_token_id=3,
do_sample=False
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
if l == 0:
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1]
# store length
if l == 0:
steps = max( steps, length )
# pad
else:
resp = resp[:steps]
if length < steps:
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
resp_list[batch].append( resp )
for i, resp in enumerate( resp_list ):
resp_list[i] = torch.stack( resp ).t()
for i, batch in enumerate(resp_list): for i, batch in enumerate(resp_list):
_ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device) _ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
@ -381,18 +442,7 @@ def example_usage():
for i in t: for i in t:
stats = {"step": i} stats = {"step": i}
batch_size = len(text_list) stats |= engine.traverse(text_list=text_list, proms_list=prom_list, resps_list=resp_list, training=True)
quant_levels = None if cfg.model.experimental.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
if quant_levels is not None:
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
else:
resps_list = [ resp for resp in resp_list ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resps_list, targ_list=resp_list, quant_levels=quant_levels)
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
stats |= engine.gather_attribute("stats") stats |= engine.gather_attribute("stats")
stats |= {"grad_norm": engine.get_global_grad_norm()} stats |= {"grad_norm": engine.get_global_grad_norm()}

View File

@ -133,46 +133,62 @@ class NAR(Base):
# generate task list to train against # generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ] task_list = [ sample_task() for _ in range(batch_size) ]
# determines which RVQ level to target per batch # specifies how to sample probabilities of which RVQ levels to train against
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal" p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
# determines which RVQ level to target per batch
if p_rvq_levels == "equal": quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
token_dropout_error = self.config.experimental.token_dropout_error
# RVQ levels to apply token dropout on
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
if not p_rvq_levels:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] if p_rvq_levels == "equal":
else: # if p_rvq_levels == "auto": p_rvq_levels = [ i for i in range( lo, hi ) ]
# makes higher levels less likely else:
""" # yuck
def generate( lo=0, hi=8 ): p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
index = lo
p = random.random()
for i in range(lo, hi):
if p < 1.0 / (2 ** i):
index = i
return int(index)
"""
# allow passing a specific distribution of RVQ levels # input RVQ levels
pool = p_rvq_levels if isinstance(p_rvq_levels, list) else [] quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
if not pool: # trim resps to only contain all levels below the target level
lo, hi = quant_level_range[0], quant_level_range[1] resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
for i in range( lo, hi ):
rep = hi - i
pool += [i] * rep
quant_levels = [ random.choice( pool ) for i in range(batch_size) ] # I hate python's value/reference semantics so much
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
for i in range(batch_size):
# cap quant_level if it exceeds its corresponding resp/prom # cap quant_level if it exceeds its corresponding resp/prom
if quant_levels[i] >= resps_list[i].shape[-1]: if quant_level >= resps.shape[-1]:
quant_levels[i] = resps_list[i].shape[-1] - 1 quant_levels[i] = resps.shape[-1] - 1
if quant_levels[i] >= proms_list[i].shape[-1]: # proms could be a Tensor, list[Tensor], or None
quant_levels[i] = proms_list[i].shape[-1] - 1 if isinstance( proms, torch.Tensor ):
if quant_level >= proms.shape[-1]:
quant_levels[i] = proms.shape[-1] - 1
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] elif isinstance( proms, list ):
for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ):
continue
if quant_level >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
# apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
steps = resps.shape[0]
for l in range( quant_level ):
for t in range( steps ):
token = resps[t, l].item()
if random.random() < token_dropout_error:
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
inputs = self.inputs( inputs = self.inputs(
text_list=text_list, text_list=text_list,

View File

@ -31,44 +31,16 @@ def train_feeder(engine, batch):
batch_size = len(batch["text"]) batch_size = len(batch["text"])
engine.current_batch_size = batch_size engine.current_batch_size = batch_size
if engine.hyper_config.experimental.hf: engine(
if engine.hyper_config.experimental.interleave: text_list=batch["text"],
quant_levels = 0 proms_list=batch["proms"],
resps_list = [ resp for resp in batch["resps"] ] resps_list=batch["resps"],
else: lang_list=batch["lang"],
quant_levels = [ random.randint( 0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels) for _ in range(batch_size) ] tone_list=batch["tone"],
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ] task_list=batch["task"],
input_ids, attention_mask = fold_inputs( training=True,
text_list=batch["text"], )
prom_list=batch["proms"],
resp_list=resps_list,
targ_list=batch["resps"],
quant_levels=quant_levels,
)
target_ids, target_attention_mask = fold_inputs(
text_list=batch["text"],
prom_list=batch["proms"],
resp_list=resps_list,
targ_list=batch["resps"],
quant_levels=quant_levels,
ignore_index=-100
)
engine(
input_ids=input_ids,
labels=target_ids,
)
else:
engine(
text_list=batch["text"],
proms_list=batch["proms"],
resps_list=batch["resps"],
lang_list=batch["lang"],
tone_list=batch["tone"],
task_list=batch["task"],
training=True,
)
losses = engine.gather_attribute("loss") losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats") stat = engine.gather_attribute("stats")
@ -137,66 +109,18 @@ def run_eval(engines, eval_name, dl):
engine = engines[name] engine = engines[name]
if engine.hyper_config.experimental.hf: if engine.hyper_config.experimental.hf:
if engine.hyper_config.experimental.interleave: resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] )
input_ids, attention_mask = fold_inputs( elif "len" in engine.hyper_config.capabilities:
text_list=batch["text"], len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
prom_list=batch["proms"], resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
)
output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
resps_list = unfold_outputs( output )["resp_list"]
else:
steps = cfg.evaluation.steps
resps_list = [ [] for _ in range(len(text_list)) ]
for l in range(cfg.model.max_levels):
quant_levels = [ [ l ] for _ in range(len(text_list)) ]
input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length,
max_length=min_length+steps*(2 if l > 0 else 1),
eos_token_id=3,
do_sample=False
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
if l == 0:
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1]
# store length
if l == 0:
steps = max( steps, length )
# pad
else:
resp = resp[:steps]
if length < steps:
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
resps_list[batch].append( resp )
for i, resp in enumerate( resps_list ):
resps_list[i] = torch.stack( resp ).t()
else: else:
if "len" in engine.hyper_config.capabilities: if "ar" in engine.hyper_config.capabilities:
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
else: else:
if "ar" in engine.hyper_config.capabilities: resps_list = [ resp[:, 0] for resp in batch["resps"] ]
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
else:
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
if "nar" in engine.hyper_config.capabilities: if "nar" in engine.hyper_config.capabilities:
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels ) resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
process( name, batch, resps_list ) process( name, batch, resps_list )

View File

@ -178,3 +178,6 @@ def to_device(x: T | None, *args, **kwargs) -> T:
return return
return tree_map(lambda t: t.to(*args, **kwargs), x) return tree_map(lambda t: t.to(*args, **kwargs), x)
def coalese( *arg, return_last=True ):
return [ x for x in arg if x is not None ][-1 if return_last else 0]