cleanup for AR_NAR inferencing to allow both TTS and STT tasks simultaneously (need to have training eval do this too, though)
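
In short: instead of assuming one task for the whole batch, inference now tracks a stop token and an output list per batch item. A minimal sketch of the idea (the token values here are assumptions for illustration; see the diff below for the actual wiring):

    import torch

    text_task = ["stt"]
    audio_stop_token = 1024  # assumption: stop token for the audio codebook
    text_stop_token = 2      # assumption: tokenizer <eos>; to-do: derive from tokenizer

    def append_and_check_stop( sequence_list, r, task_list, stopped ):
        # each batch item checks the stop token for *its* task,
        # rather than one batch-wide stop token
        for i, ri in enumerate( r ):
            stop_token = text_stop_token if task_list[i] in text_task else audio_stop_token
            if stop_token in ri:
                stopped[i] = True
            sequence_list[i] = torch.cat([sequence_list[i], ri])
        return stopped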

mrq 2024-09-06 14:30:12 -05:00
parent 341e19162b
commit d33a906119
2 changed files with 93 additions and 39 deletions

View File

@@ -76,8 +76,15 @@ class AR_NAR(Base):
 if task_list is None:
     task_list = [ default_task for _ in range(batch_size) ]
+has_none = resps_list is None or text_list is None
+if not has_none:
+    for i, task in enumerate( task_list ):
+        if resps_list[i] is None or text_list[i] is None:
+            has_none = True
+            break
 # is training or NAR
-if resps_list is not None and text_list is not None:
+if not has_none:
     n_levels_set = {r.shape[-1] for r in resps_list}
     n_levels = next(iter(n_levels_set))
@@ -241,7 +248,8 @@ class AR_NAR(Base):
 sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 stopped = torch.zeros(batch_size, device=device).bool()
-stop_token = self.stop_token if task_list[0] != "stt" else 2 # to-do: derive from tokenizer
+audio_stop_token = self.stop_token
+text_stop_token = 2
 state = None
 mirostat = [
@@ -257,10 +265,15 @@ class AR_NAR(Base):
 # get next in sequence
 for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-    if task_list[0] in text_task:
-        text_list = [x for x in sequence_list]
-    else:
-        resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
+    #
+    text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+    resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+    """
+    print( "task_list:", task_list )
+    print( "text_list:", text_list )
+    print( "resps_list:", resps_list )
+    """
     inputs = self.inputs(
         text_list=text_list,
@@ -286,7 +299,7 @@ class AR_NAR(Base):
     r = super().sample(
         logits=logits,
-        prev_list=resps_list,
+        prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
         temperature=sampling_temperature,
         min_temperature=sampling_min_temperature,
@@ -325,12 +338,14 @@ class AR_NAR(Base):
 # append tokens
 for i, ri in enumerate(r):
+    task = task_list[i]
+    stop_token = audio_stop_token if task not in text_task else text_stop_token
     if stop_token in ri:
         stopped[i] = True
     sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
 # stop token found
-stopped |= r == stop_token
+# stopped |= r == stop_token
 if stopped.all().item():
     break
@@ -339,7 +354,10 @@ class AR_NAR(Base):
 if sampling_beam_width:
     sequence_list = [ sequence_list[0] ]
-sequence_list = [self._prune(r, stop_token) for r in sequence_list]
+# remove stop token
+sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
+# remove <bos>
+sequence_list = [ sequence_list[i] if task not in text_task else sequence_list[i][1:] for i, task in enumerate( task_list ) ]
 return sequence_list
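
Review note: with the above, a single forward call can mix tasks per batch item. Roughly, each item supplies the modality it conditions on and leaves the one it generates as None; a hypothetical call shape based on the example changes below, not the exact API:

    # one batch, two tasks: item 0 does TTS, item 1 does STT
    task_list = ["tts", "stt"]
    text_list = [ text, None ]        # the STT item generates its text
    proms_list = [ prom, [ "stt" ] ]  # the STT "prompt" is just the task token
    resps_list = [ None, resp ]       # the TTS item generates its audio codes
    output = model( text_list=text_list, proms_list=proms_list, resps_list=resps_list, task_list=task_list, max_steps=steps )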
@@ -426,7 +444,8 @@ def example_usage():
 """
 bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
-available_tasks = cfg.dataset.tasks_list
+#available_tasks = cfg.dataset.tasks_list
+available_tasks = ["tts", "stt"]
 model = AR_NAR(**kwargs).to(device)
 steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
@@ -515,6 +534,14 @@ def example_usage():
 @torch.no_grad()
 def sample_data(t=None):
+    if isinstance(t, list):
+        tasks = t
+        texts = [ text_list[0].to(device) if task != "stt" else None for i, task in enumerate( tasks ) ]
+        proms = [ proms_list[0].to(device) if task != "stt" else [ "stt" ] for i, task in enumerate( tasks ) ]
+        resps = [ None if task != "stt" else resps_list[0].to(device) for i, task in enumerate( tasks ) ]
+        return texts, proms, resps, tasks
     texts = []
     proms = []
     resps = []
@@ -523,25 +550,32 @@ def example_usage():
 for i in range(batch_size):
     task = random.choice(available_tasks) if t is None else t
-    text = text_list[i]
-    prom = proms_list[i]
-    resp = resps_list[i]
+    text = text_list[i].to(device)
+    prom = proms_list[i].to(device)
+    resp = resps_list[i].to(device)
     # do nothing
     if task == "tts":
         ...
     elif task == "stt":
-        ...
+        prom = [
+            task
+        ]
+    # to-do: reimplement this from data.py
+    """
     elif task == "tts-c":
         trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
         prom = resp[:trim_length]
         resp = resp[trim_length:]
-        prom = prom.to(device)
     elif task == "ns" or task == "sr":
         # extend the noise to fill the target audio
         noise_ext = repeat_extend_audio( noise, resp.shape[0] )
         # create the input prompt by merging the target audio with the noise
         prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
-        prom = prom.to(device)
         # set the target to just be the noise if <sr>
         if task == "sr":
             resp = noise_ext
@@ -550,9 +584,15 @@ def example_usage():
     if random.random() < 0.5:
         text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
-    texts.append( text.to(device) )
-    proms.append( prom.to(device) )
-    resps.append( resp.to(device) )
+    prom = [
+        task,
+        prom,
+    ]
+    """
+    texts.append( text )
+    proms.append( prom )
+    resps.append( resp )
     tasks.append( task )
 return texts, proms, resps, tasks
@@ -563,25 +603,25 @@ def example_usage():
 texts, proms, resps, tasks = sample_data( task )
-if tasks[0] == "stt":
-    text = engine( None, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
-    """
-    # to-do: STT for NAR
-    text = engine( text, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
-    """
-    text = [ cfg.tokenizer.decode( t ) for t in text ]
+if "ar" in cfg.model.capabilities:
+    output = engine( texts, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
-    print( text )
+    text = [ cfg.tokenizer.decode( output[i] ) for i, task in enumerate( tasks ) if task == "stt" ]
+    texts = [ texts[i] for i, task in enumerate( tasks ) if task != "stt" ]
+    proms = [ proms[i] for i, task in enumerate( tasks ) if task != "stt" ]
+    resps = [ output[i] for i, task in enumerate( tasks ) if task != "stt" ]
+    tasks = [ tasks[i] for i, task in enumerate( tasks ) if task != "stt" ]
+    print( "STT:", text )
 else:
-    if "ar" in cfg.model.capabilities:
-        resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
-    else:
-        resps = [ resp[:, 0] for resp in resps ]
+    resps = [ resp[:, 0] for resp in resps ]
-    if "nar" in cfg.model.capabilities:
-        resps = engine( texts, proms, resps, sampling_temperature=0.2 )
-    for i, o in enumerate(resps):
-        _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
+if "nar" in cfg.model.capabilities:
+    resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 )
+for i, o in enumerate(resps):
+    _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
@@ -611,8 +651,11 @@ def example_usage():
     model = ml.compile_model(model, backend=cfg.optimizations.compile)
 """
+"""
 for task in available_tasks:
     sample("final", task=task)
+"""
+sample("final", task=available_tasks)
 engines.quit()
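
For reference, the updated example drives one mixed batch through a single AR pass, then routes outputs by task: STT items decode to text, TTS items continue into the NAR pass. Condensed from the hunks above, assuming the same names:

    output = engine( texts, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )

    # STT items come back as text tokens; decode them
    stt_text = [ cfg.tokenizer.decode( output[i] ) for i, task in enumerate( tasks ) if task == "stt" ]

    # TTS items come back as audio codes; keep only those for the NAR pass
    resps = [ output[i] for i, task in enumerate( tasks ) if task != "stt" ]
    tasks = [ task for task in tasks if task != "stt" ]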

View File

@@ -108,24 +108,35 @@ def run_eval(engines, eval_name, dl):
 for name in engines:
     engine = engines[name]
+    # to-do: eval for text tasks
+    for i, task in enumerate( batch["task"] ):
+        if task == "stt":
+            batch["task"][i] = "tts"
+    kwargs = dict(
+        text_list=batch["text"],
+        proms_list=batch["proms"],
+        lang_list=batch["lang"],
+        task_list=batch["task"],
+    )
     if engine.hyper_config.experimental.hf:
-        resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] )
+        resps_list = engine( **kwargs )
     elif "len" in engine.hyper_config.capabilities:
-        len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
+        len_list = engine( **kwargs, max_steps=10 ) # don't need more than that
         len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
-        resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
+        resps_list = engine( **kwargs, len_list=len_list, max_levels=cfg.evaluation.nar_levels )
     else:
         if "ar" in engine.hyper_config.capabilities:
-            resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
+            resps_list = engine( **kwargs, max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
         else:
             resps_list = [ resp[:, 0] for resp in batch["resps"] ]
         if "nar" in engine.hyper_config.capabilities:
-            resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
+            resps_list = engine( **kwargs, resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
     process( name, batch, resps_list )
 stats = {k: sum(v) / len(v) for k, v in stats.items()}
 engines_stats = {
     f'{name}.{eval_name}': stats,