sanity cleanups with weird off-by-one-ness, cleaned up and validated vall_e.models.experimental works again
This commit is contained in:
parent
06e948aec1
commit
ce8bb1e4f7
274
vall_e/data.py
274
vall_e/data.py
|
@ -34,6 +34,9 @@ _logger = logging.getLogger(__name__)
|
||||||
# fold into a typical LLM sequence (one embedding rather than split embeddings)
|
# fold into a typical LLM sequence (one embedding rather than split embeddings)
|
||||||
def fold_inputs(
|
def fold_inputs(
|
||||||
text_list = [],
|
text_list = [],
|
||||||
|
lang_list = [],
|
||||||
|
task_list = [],
|
||||||
|
tone_list = [],
|
||||||
prom_list = [],
|
prom_list = [],
|
||||||
resp_list = [],
|
resp_list = [],
|
||||||
targ_list = [],
|
targ_list = [],
|
||||||
|
@ -42,12 +45,13 @@ def fold_inputs(
|
||||||
|
|
||||||
sep = 3,
|
sep = 3,
|
||||||
stop = 3,
|
stop = 3,
|
||||||
|
config = None,
|
||||||
|
|
||||||
text_tokens = 256,
|
|
||||||
audio_tokens = 1024,
|
|
||||||
audio_rvq_levels = cfg.model.max_levels,
|
|
||||||
quant_levels = None,
|
quant_levels = None,
|
||||||
):
|
):
|
||||||
|
if config is None:
|
||||||
|
config = cfg.model
|
||||||
|
|
||||||
def _create_mask(l, device):
|
def _create_mask(l, device):
|
||||||
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
|
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
|
||||||
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
|
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
|
||||||
|
@ -61,53 +65,124 @@ def fold_inputs(
|
||||||
m = m.to(x)
|
m = m.to(x)
|
||||||
return x, m
|
return x, m
|
||||||
|
|
||||||
|
def process_prom_or_task(i, prom):
|
||||||
|
if prom is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(prom, str):
|
||||||
|
task = get_task_symmap()[f'<{input}>']
|
||||||
|
seq = torch.Tensor([task_start + task]).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
input_ids[i].append( seq )
|
||||||
|
input_ids[i].append( sep )
|
||||||
|
return
|
||||||
|
|
||||||
|
# deinterleaved
|
||||||
|
if quant_levels is not None:
|
||||||
|
quant_level = quant_levels[i]
|
||||||
|
if ignore_index is not None:
|
||||||
|
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
seq = prom[:, quant_level].to(device=device, dtype=dtype).clone()
|
||||||
|
for idx, token in enumerate( seq ):
|
||||||
|
token += prom_start + ( config.audio_tokens * quant_level )
|
||||||
|
# interleaved
|
||||||
|
else:
|
||||||
|
if ignore_index is not None:
|
||||||
|
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
seq = prom.flatten().to(device=device, dtype=dtype)
|
||||||
|
for idx, token in enumerate( seq ):
|
||||||
|
token += prom_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
|
||||||
|
|
||||||
|
input_ids[i].append( seq )
|
||||||
|
input_ids[i].append( sep )
|
||||||
|
|
||||||
|
"""
|
||||||
|
if quant_levels is not None:
|
||||||
|
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
|
||||||
|
"""
|
||||||
|
|
||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
|
dtype = torch.int64
|
||||||
|
|
||||||
batch_size = len(text_list)
|
batch_size = len(text_list)
|
||||||
input_ids = [ [] for _ in range(batch_size) ]
|
input_ids = [ [] for _ in range(batch_size) ]
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
|
|
||||||
sep = torch.Tensor([ sep ])
|
sep = torch.Tensor([ sep ]).to(device=device, dtype=dtype)
|
||||||
stop = torch.Tensor([ stop ])
|
stop = torch.Tensor([ stop ]).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
text_start = 0
|
||||||
|
text_end = text_start + config.text_tokens
|
||||||
|
|
||||||
|
lang_start = text_end
|
||||||
|
lang_end = lang_start + config.langs
|
||||||
|
|
||||||
|
rvq_start = lang_end
|
||||||
|
rvq_end = rvq_start + config.resp_levels
|
||||||
|
|
||||||
|
prom_start = rvq_end
|
||||||
|
prom_end = prom_start + config.audio_tokens * config.resp_levels
|
||||||
|
|
||||||
|
task_start = prom_end
|
||||||
|
task_end = task_start + config.tasks
|
||||||
|
|
||||||
|
tone_start = task_end
|
||||||
|
tone_end = tone_start + config.tones
|
||||||
|
|
||||||
|
resp_start = tone_end
|
||||||
|
resp_end = resp_start + config.audio_tokens * config.resp_levels
|
||||||
|
|
||||||
|
# text tokens
|
||||||
for i, text in enumerate(text_list):
|
for i, text in enumerate(text_list):
|
||||||
seq = text.to("cpu", dtype=torch.int64)
|
if isinstance(text, torch.Tensor):
|
||||||
|
seq = text + text_start
|
||||||
|
else:
|
||||||
|
seq = torch.Tensor([text_start + text]).to(device=device, dtype=dtype)
|
||||||
|
input_ids[i].append( seq )
|
||||||
|
input_ids[i].append( sep )
|
||||||
|
|
||||||
|
# lang tokens
|
||||||
|
for i, lang in enumerate(lang_list):
|
||||||
|
if isinstance(lang, torch.Tensor):
|
||||||
|
seq = lang + lang_start
|
||||||
|
else:
|
||||||
|
seq = torch.Tensor([lang_start + lang]).to(device=device, dtype=dtype)
|
||||||
input_ids[i].append( seq )
|
input_ids[i].append( seq )
|
||||||
input_ids[i].append( sep )
|
input_ids[i].append( sep )
|
||||||
|
|
||||||
offset = text_tokens
|
|
||||||
# inject target quant_level
|
# inject target quant_level
|
||||||
if quant_levels is not None:
|
if quant_levels is not None:
|
||||||
for i, rvq in enumerate( quant_levels ):
|
for i, rvq in enumerate( quant_levels ):
|
||||||
seq = torch.Tensor([offset + rvq]).to("cpu", dtype=torch.int64)
|
if isinstance(rvq, torch.Tensor):
|
||||||
|
seq = rvq + rvq_start
|
||||||
|
else:
|
||||||
|
seq = torch.Tensor([rvq_start + rvq]).to(device=device, dtype=dtype)
|
||||||
input_ids[i].append( seq )
|
input_ids[i].append( seq )
|
||||||
input_ids[i].append( sep )
|
input_ids[i].append( sep )
|
||||||
|
|
||||||
offset = text_tokens + audio_rvq_levels
|
# prom / task tokens
|
||||||
for i, prom in enumerate(prom_list):
|
for i, prom in enumerate(prom_list):
|
||||||
# deinterleaved
|
# list of proms with a possible task token
|
||||||
if quant_levels is not None:
|
if isinstance(prom, list):
|
||||||
quant_level = quant_levels[i]
|
for p in prom:
|
||||||
if ignore_index is not None:
|
process_prom_or_task(i, p)
|
||||||
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64)
|
# raw tensor
|
||||||
else:
|
|
||||||
seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
|
|
||||||
for idx, token in enumerate( seq ):
|
|
||||||
token += offset + ( audio_tokens * quant_level )
|
|
||||||
# interleaved
|
|
||||||
else:
|
else:
|
||||||
if ignore_index is not None:
|
process_prom_or_task(i, prom)
|
||||||
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
|
|
||||||
else:
|
|
||||||
seq = prom.flatten().to("cpu", dtype=torch.int64)
|
|
||||||
for idx, token in enumerate( seq ):
|
|
||||||
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
|
|
||||||
|
|
||||||
|
# tone tokens
|
||||||
|
for i, tone in enumerate(tone_list):
|
||||||
|
if isinstance(tone, torch.Tensor):
|
||||||
|
seq = tone + tone_start
|
||||||
|
else:
|
||||||
|
seq = torch.Tensor([tone_start + tone]).to(device=device, dtype=dtype)
|
||||||
input_ids[i].append( seq )
|
input_ids[i].append( seq )
|
||||||
input_ids[i].append( sep )
|
input_ids[i].append( sep )
|
||||||
|
|
||||||
offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels)
|
# resp tokens
|
||||||
|
|
||||||
for i, resp in enumerate(resp_list):
|
for i, resp in enumerate(resp_list):
|
||||||
# deinterleaved
|
# deinterleaved
|
||||||
if quant_levels is not None:
|
if quant_levels is not None:
|
||||||
|
@ -115,54 +190,51 @@ def fold_inputs(
|
||||||
quant_level = quant_levels[i] - 1
|
quant_level = quant_levels[i] - 1
|
||||||
# way to signal we want to inference for rvq level 0
|
# way to signal we want to inference for rvq level 0
|
||||||
# without it, it's a random chance for any level to be selected again
|
# without it, it's a random chance for any level to be selected again
|
||||||
|
|
||||||
if quant_level < 0:
|
if quant_level < 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
seq = sep
|
|
||||||
else:
|
else:
|
||||||
# my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
|
# my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
|
||||||
if isinstance(resp, list):
|
if isinstance(resp, list):
|
||||||
seq = resp[quant_level].to("cpu", dtype=torch.int64)
|
seq = resp[quant_level].to(device=device, dtype=dtype).clone()
|
||||||
else:
|
else:
|
||||||
seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
|
seq = resp[:, quant_level].to(device=device, dtype=dtype).clone()
|
||||||
|
|
||||||
for idx, token in enumerate( seq ):
|
for idx, token in enumerate( seq ):
|
||||||
token += offset + ( audio_tokens * quant_level )
|
token += resp_start + ( config.audio_tokens * quant_level )
|
||||||
|
|
||||||
|
|
||||||
input_ids[i].append( seq )
|
input_ids[i].append( seq )
|
||||||
input_ids[i].append( stop )
|
input_ids[i].append( stop )
|
||||||
# interleaved
|
# interleaved
|
||||||
else:
|
else:
|
||||||
seq = resp.flatten().to("cpu", dtype=torch.int64)
|
seq = resp.flatten().to(device=device, dtype=dtype)
|
||||||
for idx, token in enumerate( seq ):
|
for idx, token in enumerate( seq ):
|
||||||
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
|
token += resp_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
|
||||||
|
|
||||||
input_ids[i].append( seq )
|
input_ids[i].append( seq )
|
||||||
input_ids[i].append( stop )
|
input_ids[i].append( stop )
|
||||||
|
|
||||||
|
# targ list
|
||||||
for i, resp in enumerate(targ_list):
|
for i, resp in enumerate(targ_list):
|
||||||
# deinterleaved
|
# deinterleaved
|
||||||
if quant_levels is not None:
|
if quant_levels is not None:
|
||||||
quant_level = quant_levels[i]
|
quant_level = quant_levels[i]
|
||||||
seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
|
seq = resp[:, quant_level].to(device=device, dtype=dtype)
|
||||||
for idx, token in enumerate( seq ):
|
for idx, token in enumerate( seq ):
|
||||||
token += offset + ( audio_tokens * quant_level )
|
token += resp_start + ( config.audio_tokens * quant_level )
|
||||||
|
|
||||||
input_ids[i].append( seq )
|
input_ids[i].append( seq )
|
||||||
input_ids[i].append( stop )
|
input_ids[i].append( stop )
|
||||||
# interleaved
|
# interleaved
|
||||||
else:
|
else:
|
||||||
seq = resp.flatten().to("cpu", dtype=torch.int64)
|
seq = resp.flatten().to(device=device, dtype=dtype)
|
||||||
for idx, token in enumerate( seq ):
|
for idx, token in enumerate( seq ):
|
||||||
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
|
token += resp_start + ( config.audio_tokens * ( idx % config.resp_levels ) )
|
||||||
|
|
||||||
input_ids[i].append( seq )
|
input_ids[i].append( seq )
|
||||||
input_ids[i].append( stop )
|
input_ids[i].append( stop )
|
||||||
|
|
||||||
for i, batch in enumerate(input_ids):
|
for i, batch in enumerate(input_ids):
|
||||||
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
|
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
return list_to_tensor(input_ids)
|
return list_to_tensor(input_ids)
|
||||||
|
|
||||||
|
@ -174,20 +246,62 @@ def unfold_outputs(
|
||||||
sep = 3,
|
sep = 3,
|
||||||
stop = 3,
|
stop = 3,
|
||||||
|
|
||||||
text_tokens = 256,
|
config = None,
|
||||||
audio_tokens = 1024,
|
|
||||||
audio_rvq_levels = cfg.model.max_levels,
|
|
||||||
quant_levels = None,
|
quant_levels = None,
|
||||||
):
|
):
|
||||||
|
def bin_to_rvqs( tokens ):
|
||||||
|
length = len(tokens)
|
||||||
|
"""
|
||||||
|
if length % config.resp_levels == 0:
|
||||||
|
tokens = torch.Tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t()
|
||||||
|
"""
|
||||||
|
bins = [ [] for _ in range(config.resp_levels) ]
|
||||||
|
for pos in range( length ):
|
||||||
|
rvq = pos % config.resp_levels
|
||||||
|
bins[rvq].append( tokens[pos] )
|
||||||
|
nearest = ( len(bins) // config.resp_levels ) * config.resp_levels
|
||||||
|
bins = bins[:nearest]
|
||||||
|
return torch.Tensor(bins).t().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = cfg.model
|
||||||
|
|
||||||
device = output_ids.device
|
device = output_ids.device
|
||||||
|
dtype = torch.int64
|
||||||
|
|
||||||
batch_size = output_ids.shape[0]
|
batch_size = output_ids.shape[0]
|
||||||
|
|
||||||
text_list = [ [] for _ in range(batch_size) ]
|
text_list = [ [] for _ in range(batch_size) ]
|
||||||
|
rvq_list = [ [] for _ in range(batch_size) ]
|
||||||
|
lang_list = [ [] for _ in range(batch_size) ]
|
||||||
|
task_list = [ [] for _ in range(batch_size) ]
|
||||||
|
tone_list = [ [] for _ in range(batch_size) ]
|
||||||
prom_list = [ [] for _ in range(batch_size) ]
|
prom_list = [ [] for _ in range(batch_size) ]
|
||||||
resp_list = [ [] for _ in range(batch_size) ]
|
resp_list = [ [] for _ in range(batch_size) ]
|
||||||
|
|
||||||
|
text_start = 0
|
||||||
|
text_end = text_start + config.text_tokens
|
||||||
|
|
||||||
|
lang_start = text_end
|
||||||
|
lang_end = lang_start + config.langs
|
||||||
|
|
||||||
|
rvq_start = lang_end
|
||||||
|
rvq_end = rvq_start + config.resp_levels
|
||||||
|
|
||||||
|
prom_start = rvq_end
|
||||||
|
prom_end = prom_start + config.audio_tokens * config.resp_levels
|
||||||
|
|
||||||
|
task_start = prom_end
|
||||||
|
task_end = task_start + config.tasks
|
||||||
|
|
||||||
|
tone_start = task_end
|
||||||
|
tone_end = tone_start + config.tones
|
||||||
|
|
||||||
|
resp_start = tone_end
|
||||||
|
resp_end = resp_start + config.audio_tokens * config.resp_levels
|
||||||
|
|
||||||
for i, batch in enumerate( output_ids ):
|
for i, batch in enumerate( output_ids ):
|
||||||
# crigne logic to handle prefix resp for rvq levels > 0
|
# cringe logic to handle prefix resp for rvq levels > 0
|
||||||
# a better way is to observe if the rvq level increased
|
# a better way is to observe if the rvq level increased
|
||||||
should_flush = False
|
should_flush = False
|
||||||
flushed = False
|
flushed = False
|
||||||
|
@ -201,49 +315,51 @@ def unfold_outputs(
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if 0 <= id and id < text_tokens:
|
# text tokens
|
||||||
text_list[i].append( id )
|
if text_start <= id and id < text_end:
|
||||||
elif text_tokens + audio_rvq_levels <= id and id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels):
|
text_list[i].append( (id - text_start) % config.text_tokens )
|
||||||
prom_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
|
# lang tokens
|
||||||
elif text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) <= id:
|
elif lang_start <= id and id < lang_end:
|
||||||
resp_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
|
lang_list[i].append( (id - lang_start) % config.langs )
|
||||||
|
# rvq levels
|
||||||
|
elif rvq_start <= id and id < rvq_end:
|
||||||
|
rvq_list[i].append( (id - rvq_start) % config.resp_levels )
|
||||||
|
# prom tokens
|
||||||
|
elif prom_start <= id and id < prom_end:
|
||||||
|
prom_list[i].append( (id - prom_start) % config.audio_tokens )
|
||||||
|
# task tokens
|
||||||
|
elif task_start <= id and id < task_end:
|
||||||
|
task_list[i].append( (id - task_start) % config.tasks )
|
||||||
|
# lang tokens
|
||||||
|
elif tone_start <= id and id < tone_end:
|
||||||
|
tone_list[i].append( (id - tone_start) % config.tones )
|
||||||
|
# resp tokens
|
||||||
|
elif resp_start <= id and id < resp_end:
|
||||||
|
resp_list[i].append( (id - resp_start) % config.audio_tokens )
|
||||||
|
|
||||||
if not flushed:
|
if not flushed:
|
||||||
should_flush = True
|
should_flush = True
|
||||||
|
|
||||||
if quant_levels is not None:
|
if quant_levels is not None:
|
||||||
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64)
|
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=dtype)
|
||||||
resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64)
|
resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
prom_len = len(prom_list[i])
|
prom_list[i] = bin_to_rvqs( prom_list[i] )
|
||||||
if prom_len % audio_rvq_levels == 0 and False:
|
resp_list[i] = bin_to_rvqs( resp_list[i] )
|
||||||
prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
|
|
||||||
else:
|
|
||||||
bins = [ [] for _ in range(audio_rvq_levels) ]
|
|
||||||
for pos in range( prom_len ):
|
|
||||||
rvq = pos % audio_rvq_levels
|
|
||||||
bins[rvq].append( prom_list[i][pos] )
|
|
||||||
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
|
|
||||||
bins = bins[:nearest]
|
|
||||||
prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
|
|
||||||
|
|
||||||
resp_len = len(resp_list[i])
|
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=dtype)
|
||||||
if len(resp_list[i]) % audio_rvq_levels == 0 and False:
|
task_list[i] = torch.Tensor( task_list[i] ).to(device=device, dtype=dtype)
|
||||||
resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
|
lang_list[i] = torch.Tensor( lang_list[i] ).to(device=device, dtype=dtype)
|
||||||
else:
|
tone_list[i] = torch.Tensor( tone_list[i] ).to(device=device, dtype=dtype)
|
||||||
bins = [ [] for _ in range(audio_rvq_levels) ]
|
|
||||||
for pos in range( resp_len ):
|
|
||||||
rvq = pos % audio_rvq_levels
|
|
||||||
bins[rvq].append( resp_list[i][pos] )
|
|
||||||
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
|
|
||||||
bins = bins[:nearest]
|
|
||||||
resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
|
|
||||||
|
|
||||||
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)
|
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
prom_list=prom_list,
|
prom_list=prom_list,
|
||||||
resp_list=resp_list
|
resp_list=resp_list,
|
||||||
|
|
||||||
|
task_list=task_list,
|
||||||
|
lang_list=lang_list,
|
||||||
|
tone_list=tone_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
# to-do: clean up this symmap mess
|
# to-do: clean up this symmap mess
|
||||||
|
@ -1072,7 +1188,7 @@ def _create_dataloader(dataset, training):
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
|
||||||
drop_last=training,
|
drop_last=training,
|
||||||
sampler=dataset.sampler,
|
sampler=dataset.sampler if training else None,
|
||||||
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
|
) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict(
|
||||||
batch_sampler=dataset.sampler,
|
batch_sampler=dataset.sampler,
|
||||||
)
|
)
|
||||||
|
|
|
@ -137,50 +137,35 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
# is training
|
# is training
|
||||||
if training:
|
if training:
|
||||||
|
# specifies how to sample probabilities of which RVQ levels to train against
|
||||||
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||||
|
|
||||||
# determines which RVQ level to target per batch
|
# determines which RVQ level to target per batch
|
||||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
|
||||||
|
# rate to perform token dropout errors
|
||||||
token_dropout_error = self.config.experimental.token_dropout_error
|
token_dropout_error = self.config.experimental.token_dropout_error
|
||||||
|
# RVQ levels to apply token dropout on
|
||||||
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
||||||
|
# implicitly set it to all levels
|
||||||
if not token_dropout_rvq_levels:
|
if not token_dropout_rvq_levels:
|
||||||
token_dropout_rvq_levels = [0, self.resp_levels]
|
token_dropout_rvq_levels = [0, self.resp_levels - 1]
|
||||||
|
# allow passing a specific distribution of RVQ levels
|
||||||
if p_rvq_levels == "equal":
|
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
||||||
|
if not p_rvq_levels:
|
||||||
|
lo, hi = quant_level_range[0], quant_level_range[1] + 1
|
||||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
if p_rvq_levels == "equal":
|
||||||
else: # if p_rvq_levels == "auto":
|
p_rvq_levels = [ i for i in range( lo, hi ) ]
|
||||||
# makes higher levels less likely
|
else:
|
||||||
"""
|
# yuck
|
||||||
def generate( lo=0, hi=8 ):
|
p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
|
||||||
index = lo
|
|
||||||
p = random.random()
|
|
||||||
for i in range(lo, hi):
|
|
||||||
if p < 1.0 / (2 ** i):
|
|
||||||
index = i
|
|
||||||
return int(index)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# allow passing a specific distribution of RVQ levels
|
|
||||||
pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
|
||||||
if not pool:
|
|
||||||
lo, hi = quant_level_range[0], quant_level_range[1]
|
|
||||||
for i in range( lo, hi ):
|
|
||||||
rep = hi - i
|
|
||||||
pool += [i] * rep
|
|
||||||
|
|
||||||
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
|
|
||||||
|
|
||||||
# these two are techinically equivalent if the audio embeddings handle things properly
|
|
||||||
"""
|
|
||||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
|
||||||
stop_sequence = torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
# input RVQ levels
|
||||||
|
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
|
||||||
|
# trim resps to only contain all levels below the target level
|
||||||
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||||
|
# tensor to cat for RVQ level 0
|
||||||
stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
|
stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
|
||||||
|
# I hate python's value/reference semantics so much
|
||||||
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
|
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
|
||||||
# cap quant_level if it exceeds its corresponding resp/prom
|
# cap quant_level if it exceeds its corresponding resp/prom
|
||||||
if quant_level >= resps.shape[-1]:
|
if quant_level >= resps.shape[-1]:
|
||||||
|
@ -213,7 +198,6 @@ class AR_NAR(Base):
|
||||||
# only apply stop token for RVQ level 0
|
# only apply stop token for RVQ level 0
|
||||||
if quant_level <= 0:
|
if quant_level <= 0:
|
||||||
# append stop tokens for AR
|
# append stop tokens for AR
|
||||||
# could technically do it in the .inputs call
|
|
||||||
resps_list[i] = torch.cat([ resps, stop_sequence ])
|
resps_list[i] = torch.cat([ resps, stop_sequence ])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,30 @@ class Model(LlmArchClass):
|
||||||
hf_attention = config.attention if config is not None else None
|
hf_attention = config.attention if config is not None else None
|
||||||
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
|
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
|
||||||
# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
|
# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
|
||||||
vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
|
# vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
|
||||||
|
|
||||||
|
text_start = 0
|
||||||
|
text_end = text_start + config.text_tokens
|
||||||
|
|
||||||
|
lang_start = text_end
|
||||||
|
lang_end = lang_start + config.langs
|
||||||
|
|
||||||
|
rvq_start = lang_end
|
||||||
|
rvq_end = rvq_start + config.resp_levels
|
||||||
|
|
||||||
|
prom_start = rvq_end
|
||||||
|
prom_end = prom_start + config.audio_tokens * config.resp_levels
|
||||||
|
|
||||||
|
task_start = prom_end
|
||||||
|
task_end = task_start + config.tasks
|
||||||
|
|
||||||
|
tone_start = task_end
|
||||||
|
tone_end = tone_start + config.tones
|
||||||
|
|
||||||
|
resp_start = tone_end
|
||||||
|
resp_end = resp_start + config.audio_tokens * config.resp_levels
|
||||||
|
|
||||||
|
vocab_size = resp_end
|
||||||
|
|
||||||
if cfg.model.arch_type == "llama":
|
if cfg.model.arch_type == "llama":
|
||||||
super().__init__(config=LlamaConfig(
|
super().__init__(config=LlamaConfig(
|
||||||
|
@ -148,11 +171,94 @@ class Model(LlmArchClass):
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if cfg.model.arch_type in ["mamba","mamba2"]:
|
config = self.hyper_config
|
||||||
|
|
||||||
|
if "text_list" in kwargs:
|
||||||
|
text_list = kwargs.pop("text_list", None)
|
||||||
|
proms_list = kwargs.pop("proms_list", None)
|
||||||
|
resps_list = kwargs.pop("resps_list", None)
|
||||||
|
lang_list = kwargs.pop("lang_list", None)
|
||||||
|
tone_list = kwargs.pop("tone_list", None)
|
||||||
|
|
||||||
|
training = kwargs.pop("training", False)
|
||||||
|
steps = kwargs.pop("steps", 500)
|
||||||
|
|
||||||
|
batch_size = len(text_list)
|
||||||
|
|
||||||
|
if training:
|
||||||
|
quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]
|
||||||
|
|
||||||
|
input_ids, attention_mask = fold_inputs(
|
||||||
|
text_list=text_list,
|
||||||
|
prom_list=proms_list,
|
||||||
|
resp_list=resps_list,
|
||||||
|
targ_list=resps_list,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
)
|
||||||
|
target_ids, target_attention_mask = fold_inputs(
|
||||||
|
text_list=text_list,
|
||||||
|
prom_list=proms_list,
|
||||||
|
resp_list=resps_list,
|
||||||
|
targ_list=resps_list,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
ignore_index=-100
|
||||||
|
)
|
||||||
|
return self.forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
labels=target_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.experimental.interleave:
|
||||||
|
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
|
||||||
|
output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False)
|
||||||
|
return unfold_outputs( output )["resp_list"]
|
||||||
|
|
||||||
|
resps_list = [ [] for _ in range(batch_size) ]
|
||||||
|
for l in range(config.max_levels):
|
||||||
|
quant_levels = [ l for _ in range(batch_size) ]
|
||||||
|
|
||||||
|
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
|
||||||
|
min_length = 1
|
||||||
|
for batch in input_ids:
|
||||||
|
min_length = max( min_length, batch.shape[0] + 1 )
|
||||||
|
|
||||||
|
output = self.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=min_length+steps*2,
|
||||||
|
eos_token_id=3,
|
||||||
|
do_sample=False
|
||||||
|
)
|
||||||
|
|
||||||
|
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
||||||
|
|
||||||
|
if l == 0:
|
||||||
|
steps = 0
|
||||||
|
|
||||||
|
for batch, resp in enumerate(unfolded["resp_list"]):
|
||||||
|
length = resp.shape[-1]
|
||||||
|
|
||||||
|
# store length
|
||||||
|
if l == 0:
|
||||||
|
steps = max( steps, length )
|
||||||
|
# pad
|
||||||
|
else:
|
||||||
|
resp = resp[:steps]
|
||||||
|
if length < steps:
|
||||||
|
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
|
||||||
|
resps_list[batch].append( resp )
|
||||||
|
|
||||||
|
for i, resp in enumerate( resps_list ):
|
||||||
|
resps_list[i] = torch.stack( resp ).t()
|
||||||
|
|
||||||
|
return resps_list
|
||||||
|
|
||||||
|
if config.arch_type in ["mamba","mamba2"]:
|
||||||
if "attention_mask" in kwargs:
|
if "attention_mask" in kwargs:
|
||||||
kwargs.pop("attention_mask")
|
kwargs.pop("attention_mask")
|
||||||
|
|
||||||
labels = kwargs.pop("labels") if "labels" in kwargs else None
|
labels = kwargs.pop("labels", None)
|
||||||
|
|
||||||
output = super().forward(*args, **kwargs)
|
output = super().forward(*args, **kwargs)
|
||||||
logits = output.logits
|
logits = output.logits
|
||||||
|
@ -322,53 +428,8 @@ def example_usage():
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
|
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
batch_size = len(text_list)
|
|
||||||
resp_list = None
|
|
||||||
if cfg.model.experimental.interleave:
|
|
||||||
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
|
|
||||||
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
|
|
||||||
|
|
||||||
unfolded = unfold_outputs( output )
|
resp_list = model( text_list=text_list, proms_list=prom_list )
|
||||||
resp_list = unfolded["resp_list"]
|
|
||||||
else:
|
|
||||||
resp_list = [ [] for _ in range(batch_size) ]
|
|
||||||
for l in range(cfg.model.max_levels):
|
|
||||||
quant_levels = [ l for _ in range(batch_size) ]
|
|
||||||
|
|
||||||
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels)
|
|
||||||
min_length = 1
|
|
||||||
for batch in input_ids:
|
|
||||||
min_length = max( min_length, batch.shape[0] + 1 )
|
|
||||||
|
|
||||||
output = model.generate(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=min_length+steps*2,
|
|
||||||
eos_token_id=3,
|
|
||||||
do_sample=False
|
|
||||||
)
|
|
||||||
|
|
||||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
|
||||||
|
|
||||||
if l == 0:
|
|
||||||
steps = 0
|
|
||||||
|
|
||||||
for batch, resp in enumerate(unfolded["resp_list"]):
|
|
||||||
length = resp.shape[-1]
|
|
||||||
|
|
||||||
# store length
|
|
||||||
if l == 0:
|
|
||||||
steps = max( steps, length )
|
|
||||||
# pad
|
|
||||||
else:
|
|
||||||
resp = resp[:steps]
|
|
||||||
if length < steps:
|
|
||||||
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
|
|
||||||
resp_list[batch].append( resp )
|
|
||||||
|
|
||||||
for i, resp in enumerate( resp_list ):
|
|
||||||
resp_list[i] = torch.stack( resp ).t()
|
|
||||||
|
|
||||||
for i, batch in enumerate(resp_list):
|
for i, batch in enumerate(resp_list):
|
||||||
_ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
|
_ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
|
||||||
|
@ -381,18 +442,7 @@ def example_usage():
|
||||||
for i in t:
|
for i in t:
|
||||||
stats = {"step": i}
|
stats = {"step": i}
|
||||||
|
|
||||||
batch_size = len(text_list)
|
stats |= engine.traverse(text_list=text_list, proms_list=prom_list, resps_list=resp_list, training=True)
|
||||||
quant_levels = None if cfg.model.experimental.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
|
|
||||||
if quant_levels is not None:
|
|
||||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
|
|
||||||
else:
|
|
||||||
resps_list = [ resp for resp in resp_list ]
|
|
||||||
|
|
||||||
|
|
||||||
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resps_list, targ_list=resp_list, quant_levels=quant_levels)
|
|
||||||
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
|
|
||||||
|
|
||||||
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
|
|
||||||
stats |= engine.gather_attribute("stats")
|
stats |= engine.gather_attribute("stats")
|
||||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||||
|
|
||||||
|
|
|
@ -133,46 +133,62 @@ class NAR(Base):
|
||||||
# generate task list to train against
|
# generate task list to train against
|
||||||
task_list = [ sample_task() for _ in range(batch_size) ]
|
task_list = [ sample_task() for _ in range(batch_size) ]
|
||||||
|
|
||||||
# determines which RVQ level to target per batch
|
# specifies how to sample probabilities of which RVQ levels to train against
|
||||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
|
||||||
|
|
||||||
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||||
|
# determines which RVQ level to target per batch
|
||||||
if p_rvq_levels == "equal":
|
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
|
||||||
|
# rate to perform token dropout errors
|
||||||
|
token_dropout_error = self.config.experimental.token_dropout_error
|
||||||
|
# RVQ levels to apply token dropout on
|
||||||
|
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
||||||
|
# implicitly set it to all levels
|
||||||
|
if not token_dropout_rvq_levels:
|
||||||
|
token_dropout_rvq_levels = [0, self.resp_levels - 1]
|
||||||
|
# allow passing a specific distribution of RVQ levels
|
||||||
|
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
||||||
|
if not p_rvq_levels:
|
||||||
|
lo, hi = quant_level_range[0], quant_level_range[1] + 1
|
||||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
if p_rvq_levels == "equal":
|
||||||
else: # if p_rvq_levels == "auto":
|
p_rvq_levels = [ i for i in range( lo, hi ) ]
|
||||||
# makes higher levels less likely
|
else:
|
||||||
"""
|
# yuck
|
||||||
def generate( lo=0, hi=8 ):
|
p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
|
||||||
index = lo
|
|
||||||
p = random.random()
|
|
||||||
for i in range(lo, hi):
|
|
||||||
if p < 1.0 / (2 ** i):
|
|
||||||
index = i
|
|
||||||
return int(index)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# allow passing a specific distribution of RVQ levels
|
# input RVQ levels
|
||||||
pool = p_rvq_levels if isinstance(p_rvq_levels, list) else []
|
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
|
||||||
if not pool:
|
# trim resps to only contain all levels below the target level
|
||||||
lo, hi = quant_level_range[0], quant_level_range[1]
|
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||||
for i in range( lo, hi ):
|
|
||||||
rep = hi - i
|
|
||||||
pool += [i] * rep
|
|
||||||
|
|
||||||
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
|
# I hate python's value/reference semantics so much
|
||||||
|
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
|
||||||
# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
|
|
||||||
for i in range(batch_size):
|
|
||||||
# cap quant_level if it exceeds its corresponding resp/prom
|
# cap quant_level if it exceeds its corresponding resp/prom
|
||||||
if quant_levels[i] >= resps_list[i].shape[-1]:
|
if quant_level >= resps.shape[-1]:
|
||||||
quant_levels[i] = resps_list[i].shape[-1] - 1
|
quant_levels[i] = resps.shape[-1] - 1
|
||||||
|
|
||||||
if quant_levels[i] >= proms_list[i].shape[-1]:
|
# proms could be a Tensor, list[Tensor], or None
|
||||||
quant_levels[i] = proms_list[i].shape[-1] - 1
|
if isinstance( proms, torch.Tensor ):
|
||||||
|
if quant_level >= proms.shape[-1]:
|
||||||
|
quant_levels[i] = proms.shape[-1] - 1
|
||||||
|
|
||||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
elif isinstance( proms, list ):
|
||||||
|
for j, prom in enumerate( proms ):
|
||||||
|
if not isinstance( prom, torch.Tensor ):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if quant_level >= prom.shape[-1]:
|
||||||
|
quant_levels[i] = prom.shape[-1] - 1
|
||||||
|
|
||||||
|
# apply token dropout error compensation
|
||||||
|
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
|
||||||
|
steps = resps.shape[0]
|
||||||
|
for l in range( quant_level ):
|
||||||
|
for t in range( steps ):
|
||||||
|
token = resps[t, l].item()
|
||||||
|
|
||||||
|
if random.random() < token_dropout_error:
|
||||||
|
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
|
||||||
|
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
||||||
|
|
||||||
inputs = self.inputs(
|
inputs = self.inputs(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
|
|
112
vall_e/train.py
112
vall_e/train.py
|
@ -31,44 +31,16 @@ def train_feeder(engine, batch):
|
||||||
batch_size = len(batch["text"])
|
batch_size = len(batch["text"])
|
||||||
engine.current_batch_size = batch_size
|
engine.current_batch_size = batch_size
|
||||||
|
|
||||||
if engine.hyper_config.experimental.hf:
|
engine(
|
||||||
if engine.hyper_config.experimental.interleave:
|
text_list=batch["text"],
|
||||||
quant_levels = 0
|
proms_list=batch["proms"],
|
||||||
resps_list = [ resp for resp in batch["resps"] ]
|
resps_list=batch["resps"],
|
||||||
else:
|
lang_list=batch["lang"],
|
||||||
quant_levels = [ random.randint( 0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels) for _ in range(batch_size) ]
|
tone_list=batch["tone"],
|
||||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
|
task_list=batch["task"],
|
||||||
|
|
||||||
input_ids, attention_mask = fold_inputs(
|
training=True,
|
||||||
text_list=batch["text"],
|
)
|
||||||
prom_list=batch["proms"],
|
|
||||||
resp_list=resps_list,
|
|
||||||
targ_list=batch["resps"],
|
|
||||||
quant_levels=quant_levels,
|
|
||||||
)
|
|
||||||
target_ids, target_attention_mask = fold_inputs(
|
|
||||||
text_list=batch["text"],
|
|
||||||
prom_list=batch["proms"],
|
|
||||||
resp_list=resps_list,
|
|
||||||
targ_list=batch["resps"],
|
|
||||||
quant_levels=quant_levels,
|
|
||||||
ignore_index=-100
|
|
||||||
)
|
|
||||||
engine(
|
|
||||||
input_ids=input_ids,
|
|
||||||
labels=target_ids,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
engine(
|
|
||||||
text_list=batch["text"],
|
|
||||||
proms_list=batch["proms"],
|
|
||||||
resps_list=batch["resps"],
|
|
||||||
lang_list=batch["lang"],
|
|
||||||
tone_list=batch["tone"],
|
|
||||||
task_list=batch["task"],
|
|
||||||
|
|
||||||
training=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
losses = engine.gather_attribute("loss")
|
losses = engine.gather_attribute("loss")
|
||||||
stat = engine.gather_attribute("stats")
|
stat = engine.gather_attribute("stats")
|
||||||
|
@ -137,66 +109,18 @@ def run_eval(engines, eval_name, dl):
|
||||||
engine = engines[name]
|
engine = engines[name]
|
||||||
|
|
||||||
if engine.hyper_config.experimental.hf:
|
if engine.hyper_config.experimental.hf:
|
||||||
if engine.hyper_config.experimental.interleave:
|
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] )
|
||||||
input_ids, attention_mask = fold_inputs(
|
elif "len" in engine.hyper_config.capabilities:
|
||||||
text_list=batch["text"],
|
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
|
||||||
prom_list=batch["proms"],
|
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
||||||
)
|
|
||||||
output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
|
|
||||||
resps_list = unfold_outputs( output )["resp_list"]
|
|
||||||
else:
|
|
||||||
steps = cfg.evaluation.steps
|
|
||||||
resps_list = [ [] for _ in range(len(text_list)) ]
|
|
||||||
for l in range(cfg.model.max_levels):
|
|
||||||
quant_levels = [ [ l ] for _ in range(len(text_list)) ]
|
|
||||||
|
|
||||||
input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
|
|
||||||
min_length = 1
|
|
||||||
for batch in input_ids:
|
|
||||||
min_length = max( min_length, batch.shape[0] + 1 )
|
|
||||||
|
|
||||||
output = model.generate(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=min_length+steps*(2 if l > 0 else 1),
|
|
||||||
eos_token_id=3,
|
|
||||||
do_sample=False
|
|
||||||
)
|
|
||||||
|
|
||||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
|
||||||
|
|
||||||
if l == 0:
|
|
||||||
steps = 0
|
|
||||||
|
|
||||||
for batch, resp in enumerate(unfolded["resp_list"]):
|
|
||||||
length = resp.shape[-1]
|
|
||||||
|
|
||||||
# store length
|
|
||||||
if l == 0:
|
|
||||||
steps = max( steps, length )
|
|
||||||
# pad
|
|
||||||
else:
|
|
||||||
resp = resp[:steps]
|
|
||||||
if length < steps:
|
|
||||||
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
|
|
||||||
|
|
||||||
resps_list[batch].append( resp )
|
|
||||||
|
|
||||||
for i, resp in enumerate( resps_list ):
|
|
||||||
resps_list[i] = torch.stack( resp ).t()
|
|
||||||
else:
|
else:
|
||||||
if "len" in engine.hyper_config.capabilities:
|
if "ar" in engine.hyper_config.capabilities:
|
||||||
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
|
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||||
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
|
||||||
else:
|
else:
|
||||||
if "ar" in engine.hyper_config.capabilities:
|
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
|
||||||
else:
|
|
||||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
|
||||||
|
|
||||||
if "nar" in engine.hyper_config.capabilities:
|
if "nar" in engine.hyper_config.capabilities:
|
||||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
||||||
|
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
|
|
||||||
|
|
|
@ -178,3 +178,6 @@ def to_device(x: T | None, *args, **kwargs) -> T:
|
||||||
return
|
return
|
||||||
|
|
||||||
return tree_map(lambda t: t.to(*args, **kwargs), x)
|
return tree_map(lambda t: t.to(*args, **kwargs), x)
|
||||||
|
|
||||||
|
def coalese( *arg, return_last=True ):
|
||||||
|
return [ x for x in arg if x is not None ][-1 if return_last else 0]
|
Loading…
Reference in New Issue
Block a user