sanity cleanups with weird off-by-one-ness, cleaned up and validated vall_e.models.experimental works again
This commit is contained in:
parent
06e948aec1
commit
ce8bb1e4f7
276
vall_e/data.py
276
vall_e/data.py
|
@ -34,6 +34,9 @@ _logger = logging.getLogger(__name__)
|
|||
# fold into a typical LLM sequence (one embedding rather than split embeddings)
|
||||
def fold_inputs(
|
||||
text_list = [],
|
||||
lang_list = [],
|
||||
task_list = [],
|
||||
tone_list = [],
|
||||
prom_list = [],
|
||||
resp_list = [],
|
||||
targ_list = [],
|
||||
|
@ -42,12 +45,13 @@ def fold_inputs(
|
|||
|
||||
sep = 3,
|
||||
stop = 3,
|
||||
config = None,
|
||||
|
||||
text_tokens = 256,
|
||||
audio_tokens = 1024,
|
||||
audio_rvq_levels = cfg.model.max_levels,
|
||||
quant_levels = None,
|
||||
):
|
||||
if config is None:
|
||||
config = cfg.model
|
||||
|
||||
def _create_mask(l, device):
|
||||
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
|
||||
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
|
||||
|
@ -61,108 +65,176 @@ def fold_inputs(
|
|||
m = m.to(x)
|
||||
return x, m
|
||||
|
||||
def process_prom_or_task(i, prom):
|
||||
if prom is None:
|
||||
return
|
||||
|
||||
if isinstance(prom, str):
|
||||
task = get_task_symmap()[f'<{input}>']
|
||||
seq = torch.Tensor([task_start + task]).to(device=device, dtype=dtype)
|
||||
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
return
|
||||
|
||||
# deinterleaved
|
||||
if quant_levels is not None:
|
||||
quant_level = quant_levels[i]
|
||||
if ignore_index is not None:
|
||||
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to(device=device, dtype=dtype)
|
||||
else:
|
||||
seq = prom[:, quant_level].to(device=device, dtype=dtype).clone()
|
||||
for idx, token in enumerate( seq ):
|
||||
token += prom_start + ( config.audio_tokens * quant_level )
|
||||
# interleaved
|
||||
else:
|
||||
if ignore_index is not None:
|
||||
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to(device=device, dtype=dtype)
|
||||
else:
|
||||
seq = prom.flatten().to(device=device, dtype=dtype)
|
||||
for idx, token in enumerate( seq ):
|
||||
token += prom_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
|
||||
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
"""
|
||||
if quant_levels is not None:
|
||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
|
||||
"""
|
||||
|
||||
device = text_list[0].device
|
||||
dtype = torch.int64
|
||||
|
||||
batch_size = len(text_list)
|
||||
input_ids = [ [] for _ in range(batch_size) ]
|
||||
|
||||
offset = 0
|
||||
|
||||
sep = torch.Tensor([ sep ])
|
||||
stop = torch.Tensor([ stop ])
|
||||
sep = torch.Tensor([ sep ]).to(device=device, dtype=dtype)
|
||||
stop = torch.Tensor([ stop ]).to(device=device, dtype=dtype)
|
||||
|
||||
text_start = 0
|
||||
text_end = text_start + config.text_tokens
|
||||
|
||||
lang_start = text_end
|
||||
lang_end = lang_start + config.langs
|
||||
|
||||
rvq_start = lang_end
|
||||
rvq_end = rvq_start + config.resp_levels
|
||||
|
||||
prom_start = rvq_end
|
||||
prom_end = prom_start + config.audio_tokens * config.resp_levels
|
||||
|
||||
task_start = prom_end
|
||||
task_end = task_start + config.tasks
|
||||
|
||||
tone_start = task_end
|
||||
tone_end = tone_start + config.tones
|
||||
|
||||
resp_start = tone_end
|
||||
resp_end = resp_start + config.audio_tokens * config.resp_levels
|
||||
|
||||
# text tokens
|
||||
for i, text in enumerate(text_list):
|
||||
seq = text.to("cpu", dtype=torch.int64)
|
||||
if isinstance(text, torch.Tensor):
|
||||
seq = text + text_start
|
||||
else:
|
||||
seq = torch.Tensor([text_start + text]).to(device=device, dtype=dtype)
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
# lang tokens
|
||||
for i, lang in enumerate(lang_list):
|
||||
if isinstance(lang, torch.Tensor):
|
||||
seq = lang + lang_start
|
||||
else:
|
||||
seq = torch.Tensor([lang_start + lang]).to(device=device, dtype=dtype)
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
offset = text_tokens
|
||||
# inject target quant_level
|
||||
if quant_levels is not None:
|
||||
for i, rvq in enumerate( quant_levels ):
|
||||
seq = torch.Tensor([offset + rvq]).to("cpu", dtype=torch.int64)
|
||||
if isinstance(rvq, torch.Tensor):
|
||||
seq = rvq + rvq_start
|
||||
else:
|
||||
seq = torch.Tensor([rvq_start + rvq]).to(device=device, dtype=dtype)
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
offset = text_tokens + audio_rvq_levels
|
||||
# prom / task tokens
|
||||
for i, prom in enumerate(prom_list):
|
||||
# deinterleaved
|
||||
if quant_levels is not None:
|
||||
quant_level = quant_levels[i]
|
||||
if ignore_index is not None:
|
||||
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64)
|
||||
else:
|
||||
seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
|
||||
for idx, token in enumerate( seq ):
|
||||
token += offset + ( audio_tokens * quant_level )
|
||||
# interleaved
|
||||
# list of proms with a possible task token
|
||||
if isinstance(prom, list):
|
||||
for p in prom:
|
||||
process_prom_or_task(i, p)
|
||||
# raw tensor
|
||||
else:
|
||||
if ignore_index is not None:
|
||||
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
|
||||
else:
|
||||
seq = prom.flatten().to("cpu", dtype=torch.int64)
|
||||
for idx, token in enumerate( seq ):
|
||||
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
|
||||
process_prom_or_task(i, prom)
|
||||
|
||||
# tone tokens
|
||||
for i, tone in enumerate(tone_list):
|
||||
if isinstance(tone, torch.Tensor):
|
||||
seq = tone + tone_start
|
||||
else:
|
||||
seq = torch.Tensor([tone_start + tone]).to(device=device, dtype=dtype)
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( sep )
|
||||
|
||||
offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels)
|
||||
|
||||
# resp tokens
|
||||
for i, resp in enumerate(resp_list):
|
||||
# deinterleaved
|
||||
if quant_levels is not None:
|
||||
# grab the previous rvq level
|
||||
quant_level = quant_levels[i] - 1
|
||||
# way to signal we want to inference for rvq level 0
|
||||
# without it, it's a random chance for any level to be selected again
|
||||
|
||||
# without it, it's a random chance for any level to be selected again
|
||||
if quant_level < 0:
|
||||
continue
|
||||
|
||||
seq = sep
|
||||
else:
|
||||
# my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
|
||||
if isinstance(resp, list):
|
||||
seq = resp[quant_level].to("cpu", dtype=torch.int64)
|
||||
seq = resp[quant_level].to(device=device, dtype=dtype).clone()
|
||||
else:
|
||||
seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
|
||||
seq = resp[:, quant_level].to(device=device, dtype=dtype).clone()
|
||||
|
||||
for idx, token in enumerate( seq ):
|
||||
token += offset + ( audio_tokens * quant_level )
|
||||
|
||||
token += resp_start + ( config.audio_tokens * quant_level )
|
||||
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( stop )
|
||||
# interleaved
|
||||
else:
|
||||
seq = resp.flatten().to("cpu", dtype=torch.int64)
|
||||
seq = resp.flatten().to(device=device, dtype=dtype)
|
||||
for idx, token in enumerate( seq ):
|
||||
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
|
||||
token += resp_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
|
||||
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( stop )
|
||||
|
||||
# targ list
|
||||
for i, resp in enumerate(targ_list):
|
||||
# deinterleaved
|
||||
if quant_levels is not None:
|
||||
quant_level = quant_levels[i]
|
||||
seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
|
||||
seq = resp[:, quant_level].to(device=device, dtype=dtype)
|
||||
for idx, token in enumerate( seq ):
|
||||
token += offset + ( audio_tokens * quant_level )
|
||||
token += resp_start + ( config.audio_tokens * quant_level )
|
||||
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( stop )
|
||||
# interleaved
|
||||
else:
|
||||
seq = resp.flatten().to("cpu", dtype=torch.int64)
|
||||
seq = resp.flatten().to(device=device, dtype=dtype)
|
||||
for idx, token in enumerate( seq ):
|
||||
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
|
||||
token += resp_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
|
||||
|
||||
input_ids[i].append( seq )
|
||||
input_ids[i].append( stop )
|
||||
|
||||
for i, batch in enumerate(input_ids):
|
||||
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
|
||||
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype)
|
||||
|
||||
return list_to_tensor(input_ids)
|
||||
|
||||
|
@ -174,20 +246,62 @@ def unfold_outputs(
|
|||
sep = 3,
|
||||
stop = 3,
|
||||
|
||||
text_tokens = 256,
|
||||
audio_tokens = 1024,
|
||||
audio_rvq_levels = cfg.model.max_levels,
|
||||
config = None,
|
||||
quant_levels = None,
|
||||
):
|
||||
def bin_to_rvqs( tokens ):
|
||||
length = len(tokens)
|
||||
"""
|
||||
if length % config.resp_levels == 0:
|
||||
tokens = torch.Tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t()
|
||||
"""
|
||||
bins = [ [] for _ in range(config.resp_levels) ]
|
||||
for pos in range( length ):
|
||||
rvq = pos % config.resp_levels
|
||||
bins[rvq].append( tokens[pos] )
|
||||
nearest = ( len(bins) // config.resp_levels ) * config.resp_levels
|
||||
bins = bins[:nearest]
|
||||
return torch.Tensor(bins).t().to(device=device, dtype=dtype)
|
||||
|
||||
if config is None:
|
||||
config = cfg.model
|
||||
|
||||
device = output_ids.device
|
||||
dtype = torch.int64
|
||||
|
||||
batch_size = output_ids.shape[0]
|
||||
|
||||
text_list = [ [] for _ in range(batch_size) ]
|
||||
rvq_list = [ [] for _ in range(batch_size) ]
|
||||
lang_list = [ [] for _ in range(batch_size) ]
|
||||
task_list = [ [] for _ in range(batch_size) ]
|
||||
tone_list = [ [] for _ in range(batch_size) ]
|
||||
prom_list = [ [] for _ in range(batch_size) ]
|
||||
resp_list = [ [] for _ in range(batch_size) ]
|
||||
|
||||
text_start = 0
|
||||
text_end = text_start + config.text_tokens
|
||||
|
||||
lang_start = text_end
|
||||
lang_end = lang_start + config.langs
|
||||
|
||||
rvq_start = lang_end
|
||||
rvq_end = rvq_start + config.resp_levels
|
||||
|
||||
prom_start = rvq_end
|
||||
prom_end = prom_start + config.audio_tokens * config.resp_levels
|
||||
|
||||
task_start = prom_end
|
||||
task_end = task_start + config.tasks
|
||||
|
||||
tone_start = task_end
|
||||
tone_end = tone_start + config.tones
|
||||
|
||||
resp_start = tone_end
|
||||
resp_end = resp_start + config.audio_tokens * config.resp_levels
|
||||
|
||||
for i, batch in enumerate( output_ids ):
|
||||
# crigne logic to handle prefix resp for rvq levels > 0
|
||||
# cringe logic to handle prefix resp for rvq levels > 0
|
||||
# a better way is to observe if the rvq level increased
|
||||
should_flush = False
|
||||
flushed = False
|
||||
|
@ -201,49 +315,51 @@ def unfold_outputs(
|
|||
|
||||
continue
|
||||
|
||||
if 0 <= id and id < text_tokens:
|
||||
text_list[i].append( id )
|
||||
elif text_tokens + audio_rvq_levels <= id and id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels):
|
||||
prom_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
|
||||
elif text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) <= id:
|
||||
resp_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
|
||||
# text tokens
|
||||
if text_start <= id and id < text_end:
|
||||
text_list[i].append( (id - text_start) % config.text_tokens )
|
||||
# lang tokens
|
||||
elif lang_start <= id and id < lang_end:
|
||||
lang_list[i].append( (id - lang_start) % config.langs )
|
||||
# rvq levels
|
||||
elif rvq_start <= id and id < rvq_end:
|
||||
rvq_list[i].append( (id - rvq_start) % config.resp_levels )
|
||||
# prom tokens
|
||||
elif prom_start <= id and id < prom_end:
|
||||
prom_list[i].append( (id - prom_start) % config.audio_tokens )
|
||||
# task tokens
|
||||
elif task_start <= id and id < task_end:
|
||||
task_list[i].append( (id - task_start) % config.tasks )
|
||||
# lang tokens
|
||||
elif tone_start <= id and id < tone_end:
|
||||
tone_list[i].append( (id - tone_start) % config.tones )
|
||||
# resp tokens
|
||||
elif resp_start <= id and id < resp_end:
|
||||
resp_list[i].append( (id - resp_start) % config.audio_tokens )
|
||||
|
||||
if not flushed:
|
||||
should_flush = True
|
||||
|
||||
if quant_levels is not None:
|
||||
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64)
|
||||
resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64)
|
||||
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=dtype)
|
||||
resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=dtype)
|
||||
else:
|
||||
prom_len = len(prom_list[i])
|
||||
if prom_len % audio_rvq_levels == 0 and False:
|
||||
prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
|
||||
else:
|
||||
bins = [ [] for _ in range(audio_rvq_levels) ]
|
||||
for pos in range( prom_len ):
|
||||
rvq = pos % audio_rvq_levels
|
||||
bins[rvq].append( prom_list[i][pos] )
|
||||
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
|
||||
bins = bins[:nearest]
|
||||
prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
|
||||
prom_list[i] = bin_to_rvqs( prom_list[i] )
|
||||
resp_list[i] = bin_to_rvqs( resp_list[i] )
|
||||
|
||||
resp_len = len(resp_list[i])
|
||||
if len(resp_list[i]) % audio_rvq_levels == 0 and False:
|
||||
resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
|
||||
else:
|
||||
bins = [ [] for _ in range(audio_rvq_levels) ]
|
||||
for pos in range( resp_len ):
|
||||
rvq = pos % audio_rvq_levels
|
||||
bins[rvq].append( resp_list[i][pos] )
|
||||
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
|
||||
bins = bins[:nearest]
|
||||
resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
|
||||
|
||||
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)
|
||||
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=dtype)
|
||||
task_list[i] = torch.Tensor( task_list[i] ).to(device=device, dtype=dtype)
|
||||
lang_list[i] = torch.Tensor( lang_list[i] ).to(device=device, dtype=dtype)
|
||||
tone_list[i] = torch.Tensor( tone_list[i] ).to(device=device, dtype=dtype)
|
||||
|
||||
return dict(
|
||||
text_list=text_list,
|
||||
prom_list=prom_list,
|
||||
resp_list=resp_list
|
||||
resp_list=resp_list,
|
||||
|
||||
task_list=task_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
)
|
||||
|
||||
# to-do: clean up this symmap mess
|
||||
|
@ -1072,7 +1188,7 @@ def _create_dataloader(dataset, training):
|
|||
shuffle=False,
|
||||
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
||||
drop_last=training,
|
||||
sampler=dataset.sampler,
|
||||
sampler=dataset.sampler if training else None,
|
||||
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
|
||||
batch_sampler=dataset.sampler,
|
||||
)
|
||||
|
|
|
@ -137,50 +137,35 @@ class AR_NAR(Base):
|
|||
|
||||
# is training
|
||||
if training:
|
||||
# specifies how to sample probabilities of which RVQ levels to train against
|
||||
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||
|
||||
# determines which RVQ level to target per batch
|
||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||
|
||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
|
||||
# rate to perform token dropout errors
|
||||
token_dropout_error = self.config.experimental.token_dropout_error
|
||||
# RVQ levels to apply token dropout on
|
||||
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
||||
# implicitly set it to all levels
|
||||
if not token_dropout_rvq_levels:
|
||||
token_dropout_rvq_levels = [0, self.resp_levels]
|
||||
|
||||
if p_rvq_levels == "equal":
|
||||
token_dropout_rvq_levels = [0, self.resp_levels - 1]
|
||||
# allow passing a specific distribution of RVQ levels
|
||||
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
||||
if not p_rvq_levels:
|
||||
lo, hi = quant_level_range[0], quant_level_range[1] + 1
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
||||
else: # if p_rvq_levels == "auto":
|
||||
# makes higher levels less likely
|
||||
"""
|
||||
def generate( lo=0, hi=8 ):
|
||||
index = lo
|
||||
p = random.random()
|
||||
for i in range(lo, hi):
|
||||
if p < 1.0 / (2 ** i):
|
||||
index = i
|
||||
return int(index)
|
||||
"""
|
||||
|
||||
# allow passing a specific distribution of RVQ levels
|
||||
pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
||||
if not pool:
|
||||
lo, hi = quant_level_range[0], quant_level_range[1]
|
||||
for i in range( lo, hi ):
|
||||
rep = hi - i
|
||||
pool += [i] * rep
|
||||
|
||||
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
|
||||
|
||||
# these two are techinically equivalent if the audio embeddings handle things properly
|
||||
"""
|
||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
stop_sequence = torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16)
|
||||
"""
|
||||
if p_rvq_levels == "equal":
|
||||
p_rvq_levels = [ i for i in range( lo, hi ) ]
|
||||
else:
|
||||
# yuck
|
||||
p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
|
||||
|
||||
# input RVQ levels
|
||||
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
|
||||
# trim resps to only contain all levels below the target level
|
||||
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
# tensor to cat for RVQ level 0
|
||||
stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
|
||||
|
||||
# I hate python's value/reference semantics so much
|
||||
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
|
||||
# cap quant_level if it exceeds its corresponding resp/prom
|
||||
if quant_level >= resps.shape[-1]:
|
||||
|
@ -213,7 +198,6 @@ class AR_NAR(Base):
|
|||
# only apply stop token for RVQ level 0
|
||||
if quant_level <= 0:
|
||||
# append stop tokens for AR
|
||||
# could technically do it in the .inputs call
|
||||
resps_list[i] = torch.cat([ resps, stop_sequence ])
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,30 @@ class Model(LlmArchClass):
|
|||
hf_attention = config.attention if config is not None else None
|
||||
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
|
||||
# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
|
||||
vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
|
||||
# vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
|
||||
|
||||
text_start = 0
|
||||
text_end = text_start + config.text_tokens
|
||||
|
||||
lang_start = text_end
|
||||
lang_end = lang_start + config.langs
|
||||
|
||||
rvq_start = lang_end
|
||||
rvq_end = rvq_start + config.resp_levels
|
||||
|
||||
prom_start = rvq_end
|
||||
prom_end = prom_start + config.audio_tokens * config.resp_levels
|
||||
|
||||
task_start = prom_end
|
||||
task_end = task_start + config.tasks
|
||||
|
||||
tone_start = task_end
|
||||
tone_end = tone_start + config.tones
|
||||
|
||||
resp_start = tone_end
|
||||
resp_end = resp_start + config.audio_tokens * config.resp_levels
|
||||
|
||||
vocab_size = resp_end
|
||||
|
||||
if cfg.model.arch_type == "llama":
|
||||
super().__init__(config=LlamaConfig(
|
||||
|
@ -148,11 +171,94 @@ class Model(LlmArchClass):
|
|||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if cfg.model.arch_type in ["mamba","mamba2"]:
|
||||
config = self.hyper_config
|
||||
|
||||
if "text_list" in kwargs:
|
||||
text_list = kwargs.pop("text_list", None)
|
||||
proms_list = kwargs.pop("proms_list", None)
|
||||
resps_list = kwargs.pop("resps_list", None)
|
||||
lang_list = kwargs.pop("lang_list", None)
|
||||
tone_list = kwargs.pop("tone_list", None)
|
||||
|
||||
training = kwargs.pop("training", False)
|
||||
steps = kwargs.pop("steps", 500)
|
||||
|
||||
batch_size = len(text_list)
|
||||
|
||||
if training:
|
||||
quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(
|
||||
text_list=text_list,
|
||||
prom_list=proms_list,
|
||||
resp_list=resps_list,
|
||||
targ_list=resps_list,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
target_ids, target_attention_mask = fold_inputs(
|
||||
text_list=text_list,
|
||||
prom_list=proms_list,
|
||||
resp_list=resps_list,
|
||||
targ_list=resps_list,
|
||||
quant_levels=quant_levels,
|
||||
ignore_index=-100
|
||||
)
|
||||
return self.forward(
|
||||
input_ids=input_ids,
|
||||
labels=target_ids,
|
||||
)
|
||||
|
||||
if config.experimental.interleave:
|
||||
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
|
||||
output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False)
|
||||
return unfold_outputs( output )["resp_list"]
|
||||
|
||||
resps_list = [ [] for _ in range(batch_size) ]
|
||||
for l in range(config.max_levels):
|
||||
quant_levels = [ l for _ in range(batch_size) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
|
||||
min_length = 1
|
||||
for batch in input_ids:
|
||||
min_length = max( min_length, batch.shape[0] + 1 )
|
||||
|
||||
output = self.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
min_length=min_length,
|
||||
max_length=min_length+steps*2,
|
||||
eos_token_id=3,
|
||||
do_sample=False
|
||||
)
|
||||
|
||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
||||
|
||||
if l == 0:
|
||||
steps = 0
|
||||
|
||||
for batch, resp in enumerate(unfolded["resp_list"]):
|
||||
length = resp.shape[-1]
|
||||
|
||||
# store length
|
||||
if l == 0:
|
||||
steps = max( steps, length )
|
||||
# pad
|
||||
else:
|
||||
resp = resp[:steps]
|
||||
if length < steps:
|
||||
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
|
||||
resps_list[batch].append( resp )
|
||||
|
||||
for i, resp in enumerate( resps_list ):
|
||||
resps_list[i] = torch.stack( resp ).t()
|
||||
|
||||
return resps_list
|
||||
|
||||
if config.arch_type in ["mamba","mamba2"]:
|
||||
if "attention_mask" in kwargs:
|
||||
kwargs.pop("attention_mask")
|
||||
|
||||
labels = kwargs.pop("labels") if "labels" in kwargs else None
|
||||
labels = kwargs.pop("labels", None)
|
||||
|
||||
output = super().forward(*args, **kwargs)
|
||||
logits = output.logits
|
||||
|
@ -322,53 +428,8 @@ def example_usage():
|
|||
@torch.inference_mode()
|
||||
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
|
||||
engine.eval()
|
||||
batch_size = len(text_list)
|
||||
resp_list = None
|
||||
if cfg.model.experimental.interleave:
|
||||
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
|
||||
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
|
||||
|
||||
unfolded = unfold_outputs( output )
|
||||
resp_list = unfolded["resp_list"]
|
||||
else:
|
||||
resp_list = [ [] for _ in range(batch_size) ]
|
||||
for l in range(cfg.model.max_levels):
|
||||
quant_levels = [ l for _ in range(batch_size) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels)
|
||||
min_length = 1
|
||||
for batch in input_ids:
|
||||
min_length = max( min_length, batch.shape[0] + 1 )
|
||||
|
||||
output = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
min_length=min_length,
|
||||
max_length=min_length+steps*2,
|
||||
eos_token_id=3,
|
||||
do_sample=False
|
||||
)
|
||||
|
||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
||||
|
||||
if l == 0:
|
||||
steps = 0
|
||||
|
||||
for batch, resp in enumerate(unfolded["resp_list"]):
|
||||
length = resp.shape[-1]
|
||||
|
||||
# store length
|
||||
if l == 0:
|
||||
steps = max( steps, length )
|
||||
# pad
|
||||
else:
|
||||
resp = resp[:steps]
|
||||
if length < steps:
|
||||
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
|
||||
resp_list[batch].append( resp )
|
||||
|
||||
for i, resp in enumerate( resp_list ):
|
||||
resp_list[i] = torch.stack( resp ).t()
|
||||
|
||||
resp_list = model( text_list=text_list, proms_list=prom_list )
|
||||
|
||||
for i, batch in enumerate(resp_list):
|
||||
_ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
|
||||
|
@ -380,19 +441,8 @@ def example_usage():
|
|||
t = trange(steps)
|
||||
for i in t:
|
||||
stats = {"step": i}
|
||||
|
||||
batch_size = len(text_list)
|
||||
quant_levels = None if cfg.model.experimental.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
|
||||
if quant_levels is not None:
|
||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
|
||||
else:
|
||||
resps_list = [ resp for resp in resp_list ]
|
||||
|
||||
|
||||
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resps_list, targ_list=resp_list, quant_levels=quant_levels)
|
||||
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
|
||||
|
||||
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
|
||||
stats |= engine.traverse(text_list=text_list, proms_list=prom_list, resps_list=resp_list, training=True)
|
||||
stats |= engine.gather_attribute("stats")
|
||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||
|
||||
|
|
|
@ -133,46 +133,62 @@ class NAR(Base):
|
|||
# generate task list to train against
|
||||
task_list = [ sample_task() for _ in range(batch_size) ]
|
||||
|
||||
# determines which RVQ level to target per batch
|
||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||
|
||||
# specifies how to sample probabilities of which RVQ levels to train against
|
||||
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||
|
||||
if p_rvq_levels == "equal":
|
||||
# determines which RVQ level to target per batch
|
||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
|
||||
# rate to perform token dropout errors
|
||||
token_dropout_error = self.config.experimental.token_dropout_error
|
||||
# RVQ levels to apply token dropout on
|
||||
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
||||
# implicitly set it to all levels
|
||||
if not token_dropout_rvq_levels:
|
||||
token_dropout_rvq_levels = [0, self.resp_levels - 1]
|
||||
# allow passing a specific distribution of RVQ levels
|
||||
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
||||
if not p_rvq_levels:
|
||||
lo, hi = quant_level_range[0], quant_level_range[1] + 1
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
||||
else: # if p_rvq_levels == "auto":
|
||||
# makes higher levels less likely
|
||||
"""
|
||||
def generate( lo=0, hi=8 ):
|
||||
index = lo
|
||||
p = random.random()
|
||||
for i in range(lo, hi):
|
||||
if p < 1.0 / (2 ** i):
|
||||
index = i
|
||||
return int(index)
|
||||
"""
|
||||
if p_rvq_levels == "equal":
|
||||
p_rvq_levels = [ i for i in range( lo, hi ) ]
|
||||
else:
|
||||
# yuck
|
||||
p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
|
||||
|
||||
# allow passing a specific distribution of RVQ levels
|
||||
pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
||||
if not pool:
|
||||
lo, hi = quant_level_range[0], quant_level_range[1]
|
||||
for i in range( lo, hi ):
|
||||
rep = hi - i
|
||||
pool += [i] * rep
|
||||
# input RVQ levels
|
||||
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
|
||||
# trim resps to only contain all levels below the target level
|
||||
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
|
||||
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
|
||||
|
||||
# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
|
||||
for i in range(batch_size):
|
||||
# I hate python's value/reference semantics so much
|
||||
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
|
||||
# cap quant_level if it exceeds its corresponding resp/prom
|
||||
if quant_levels[i] >= resps_list[i].shape[-1]:
|
||||
quant_levels[i] = resps_list[i].shape[-1] - 1
|
||||
if quant_level >= resps.shape[-1]:
|
||||
quant_levels[i] = resps.shape[-1] - 1
|
||||
|
||||
if quant_levels[i] >= proms_list[i].shape[-1]:
|
||||
quant_levels[i] = proms_list[i].shape[-1] - 1
|
||||
# proms could be a Tensor, list[Tensor], or None
|
||||
if isinstance( proms, torch.Tensor ):
|
||||
if quant_level >= proms.shape[-1]:
|
||||
quant_levels[i] = proms.shape[-1] - 1
|
||||
|
||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
elif isinstance( proms, list ):
|
||||
for j, prom in enumerate( proms ):
|
||||
if not isinstance( prom, torch.Tensor ):
|
||||
continue
|
||||
|
||||
if quant_level >= prom.shape[-1]:
|
||||
quant_levels[i] = prom.shape[-1] - 1
|
||||
|
||||
# apply token dropout error compensation
|
||||
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
|
||||
steps = resps.shape[0]
|
||||
for l in range( quant_level ):
|
||||
for t in range( steps ):
|
||||
token = resps[t, l].item()
|
||||
|
||||
if random.random() < token_dropout_error:
|
||||
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
|
||||
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
|
|
112
vall_e/train.py
112
vall_e/train.py
|
@ -31,44 +31,16 @@ def train_feeder(engine, batch):
|
|||
batch_size = len(batch["text"])
|
||||
engine.current_batch_size = batch_size
|
||||
|
||||
if engine.hyper_config.experimental.hf:
|
||||
if engine.hyper_config.experimental.interleave:
|
||||
quant_levels = 0
|
||||
resps_list = [ resp for resp in batch["resps"] ]
|
||||
else:
|
||||
quant_levels = [ random.randint( 0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels) for _ in range(batch_size) ]
|
||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
|
||||
engine(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
resps_list=batch["resps"],
|
||||
lang_list=batch["lang"],
|
||||
tone_list=batch["tone"],
|
||||
task_list=batch["task"],
|
||||
|
||||
input_ids, attention_mask = fold_inputs(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
resp_list=resps_list,
|
||||
targ_list=batch["resps"],
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
target_ids, target_attention_mask = fold_inputs(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
resp_list=resps_list,
|
||||
targ_list=batch["resps"],
|
||||
quant_levels=quant_levels,
|
||||
ignore_index=-100
|
||||
)
|
||||
engine(
|
||||
input_ids=input_ids,
|
||||
labels=target_ids,
|
||||
)
|
||||
else:
|
||||
engine(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
resps_list=batch["resps"],
|
||||
lang_list=batch["lang"],
|
||||
tone_list=batch["tone"],
|
||||
task_list=batch["task"],
|
||||
|
||||
training=True,
|
||||
)
|
||||
training=True,
|
||||
)
|
||||
|
||||
losses = engine.gather_attribute("loss")
|
||||
stat = engine.gather_attribute("stats")
|
||||
|
@ -137,66 +109,18 @@ def run_eval(engines, eval_name, dl):
|
|||
engine = engines[name]
|
||||
|
||||
if engine.hyper_config.experimental.hf:
|
||||
if engine.hyper_config.experimental.interleave:
|
||||
input_ids, attention_mask = fold_inputs(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
)
|
||||
output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
|
||||
resps_list = unfold_outputs( output )["resp_list"]
|
||||
else:
|
||||
steps = cfg.evaluation.steps
|
||||
resps_list = [ [] for _ in range(len(text_list)) ]
|
||||
for l in range(cfg.model.max_levels):
|
||||
quant_levels = [ [ l ] for _ in range(len(text_list)) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
|
||||
min_length = 1
|
||||
for batch in input_ids:
|
||||
min_length = max( min_length, batch.shape[0] + 1 )
|
||||
|
||||
output = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
min_length=min_length,
|
||||
max_length=min_length+steps*(2 if l > 0 else 1),
|
||||
eos_token_id=3,
|
||||
do_sample=False
|
||||
)
|
||||
|
||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
||||
|
||||
if l == 0:
|
||||
steps = 0
|
||||
|
||||
for batch, resp in enumerate(unfolded["resp_list"]):
|
||||
length = resp.shape[-1]
|
||||
|
||||
# store length
|
||||
if l == 0:
|
||||
steps = max( steps, length )
|
||||
# pad
|
||||
else:
|
||||
resp = resp[:steps]
|
||||
if length < steps:
|
||||
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
|
||||
|
||||
resps_list[batch].append( resp )
|
||||
|
||||
for i, resp in enumerate( resps_list ):
|
||||
resps_list[i] = torch.stack( resp ).t()
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] )
|
||||
elif "len" in engine.hyper_config.capabilities:
|
||||
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
|
||||
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
||||
else:
|
||||
if "len" in engine.hyper_config.capabilities:
|
||||
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
|
||||
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
||||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
|
|
|
@ -178,3 +178,6 @@ def to_device(x: T | None, *args, **kwargs) -> T:
|
|||
return
|
||||
|
||||
return tree_map(lambda t: t.to(*args, **kwargs), x)
|
||||
|
||||
def coalese( *arg, return_last=True ):
|
||||
return [ x for x in arg if x is not None ][-1 if return_last else 0]
|
Loading…
Reference in New Issue
Block a user