cleanup for AR_NAR inferencing to allow both TTS and STT tasks simultaneously (need to have training eval do this to though)
This commit is contained in:
parent
341e19162b
commit
d33a906119
|
@ -76,8 +76,15 @@ class AR_NAR(Base):
|
|||
if task_list is None:
|
||||
task_list = [ default_task for _ in range(batch_size) ]
|
||||
|
||||
has_none = resps_list is None or text_list is None
|
||||
if not has_none:
|
||||
for i, task in enumerate( task_list ):
|
||||
if resps_list[i] is None or text_list[i] is None:
|
||||
has_none = True
|
||||
break
|
||||
|
||||
# is training or NAR
|
||||
if resps_list is not None and text_list is not None:
|
||||
if not has_none:
|
||||
n_levels_set = {r.shape[-1] for r in resps_list}
|
||||
n_levels = next(iter(n_levels_set))
|
||||
|
||||
|
@ -241,7 +248,8 @@ class AR_NAR(Base):
|
|||
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
stop_token = self.stop_token if task_list[0] != "stt" else 2 # to-do: derive from tokenizer
|
||||
audio_stop_token = self.stop_token
|
||||
text_stop_token = 2
|
||||
|
||||
state = None
|
||||
mirostat = [
|
||||
|
@ -257,10 +265,15 @@ class AR_NAR(Base):
|
|||
|
||||
# get next in sequence
|
||||
for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
|
||||
if task_list[0] in text_task:
|
||||
text_list = [x for x in sequence_list]
|
||||
else:
|
||||
resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
|
||||
#
|
||||
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
|
||||
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
|
||||
|
||||
"""
|
||||
print( "task_list:", task_list )
|
||||
print( "text_list:", text_list )
|
||||
print( "resps_list:", resps_list )
|
||||
"""
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
|
@ -286,7 +299,7 @@ class AR_NAR(Base):
|
|||
|
||||
r = super().sample(
|
||||
logits=logits,
|
||||
prev_list=resps_list,
|
||||
prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
|
||||
|
||||
temperature=sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
|
@ -325,12 +338,14 @@ class AR_NAR(Base):
|
|||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
task = task_list[i]
|
||||
stop_token = audio_stop_token if task not in text_task else text_stop_token
|
||||
if stop_token in ri:
|
||||
stopped[i] = True
|
||||
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
|
||||
|
||||
# stop token found
|
||||
stopped |= r == stop_token
|
||||
# stopped |= r == stop_token
|
||||
if stopped.all().item():
|
||||
break
|
||||
|
||||
|
@ -339,7 +354,10 @@ class AR_NAR(Base):
|
|||
if sampling_beam_width:
|
||||
sequence_list = [ sequence_list[0] ]
|
||||
|
||||
sequence_list = [self._prune(r, stop_token) for r in sequence_list]
|
||||
# remove stop token
|
||||
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
|
||||
# remove <bos>
|
||||
sequence_list = [ sequence_list[i] if task not in text_task else sequence_list[i][1:] for i, task in enumerate( task_list ) ]
|
||||
return sequence_list
|
||||
|
||||
|
||||
|
@ -426,7 +444,8 @@ def example_usage():
|
|||
"""
|
||||
|
||||
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
|
||||
available_tasks = cfg.dataset.tasks_list
|
||||
#available_tasks = cfg.dataset.tasks_list
|
||||
available_tasks = ["tts", "stt"]
|
||||
|
||||
model = AR_NAR(**kwargs).to(device)
|
||||
steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
|
||||
|
@ -515,6 +534,14 @@ def example_usage():
|
|||
|
||||
@torch.no_grad()
|
||||
def sample_data(t=None):
|
||||
if isinstance(t, list):
|
||||
tasks = t
|
||||
texts = [ text_list[0].to(device) if task != "stt" else None for i, task in enumerate( tasks ) ]
|
||||
proms = [ proms_list[0].to(device) if task != "stt" else [ "stt" ] for i, task in enumerate( tasks ) ]
|
||||
resps = [ None if task != "stt" else resps_list[0].to(device) for i, task in enumerate( tasks ) ]
|
||||
|
||||
return texts, proms, resps, tasks
|
||||
|
||||
texts = []
|
||||
proms = []
|
||||
resps = []
|
||||
|
@ -523,25 +550,32 @@ def example_usage():
|
|||
for i in range(batch_size):
|
||||
task = random.choice(available_tasks) if t is None else t
|
||||
|
||||
text = text_list[i]
|
||||
prom = proms_list[i]
|
||||
resp = resps_list[i]
|
||||
text = text_list[i].to(device)
|
||||
prom = proms_list[i].to(device)
|
||||
resp = resps_list[i].to(device)
|
||||
|
||||
# do nothing
|
||||
if task == "tts":
|
||||
...
|
||||
elif task == "stt":
|
||||
...
|
||||
prom = [
|
||||
task
|
||||
]
|
||||
# to-do: reimplement this from data.py
|
||||
"""
|
||||
elif task == "tts-c":
|
||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
||||
|
||||
prom = resp[:trim_length]
|
||||
resp = resp[trim_length:]
|
||||
|
||||
prom = prom.to(device)
|
||||
elif task == "ns" or task == "sr":
|
||||
# extend the noise to fill the target audio
|
||||
noise_ext = repeat_extend_audio( noise, resp.shape[0] )
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
|
||||
prom = prom.to(device)
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resp = noise_ext
|
||||
|
@ -550,9 +584,15 @@ def example_usage():
|
|||
if random.random() < 0.5:
|
||||
text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
|
||||
|
||||
texts.append( text.to(device) )
|
||||
proms.append( prom.to(device) )
|
||||
resps.append( resp.to(device) )
|
||||
prom = [
|
||||
task,
|
||||
prom,
|
||||
]
|
||||
"""
|
||||
|
||||
texts.append( text )
|
||||
proms.append( prom )
|
||||
resps.append( resp )
|
||||
tasks.append( task )
|
||||
|
||||
return texts, proms, resps, tasks
|
||||
|
@ -563,25 +603,25 @@ def example_usage():
|
|||
|
||||
texts, proms, resps, tasks = sample_data( task )
|
||||
|
||||
if tasks[0] == "stt":
|
||||
text = engine( None, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
|
||||
"""
|
||||
# to-do: STT for NAR
|
||||
text = engine( text, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
|
||||
"""
|
||||
text = [ cfg.tokenizer.decode( t ) for t in text ]
|
||||
if "ar" in cfg.model.capabilities:
|
||||
output = engine( texts, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
|
||||
|
||||
print( text )
|
||||
text = [ cfg.tokenizer.decode( output[i] ) for i, task in enumerate( tasks ) if task == "stt" ]
|
||||
|
||||
texts = [ texts[i] for i, task in enumerate( tasks ) if task != "stt" ]
|
||||
proms = [ proms[i] for i, task in enumerate( tasks ) if task != "stt" ]
|
||||
resps = [ output[i] for i, task in enumerate( tasks ) if task != "stt" ]
|
||||
tasks = [ tasks[i] for i, task in enumerate( tasks ) if task != "stt" ]
|
||||
|
||||
print( "STT:", text )
|
||||
else:
|
||||
if "ar" in cfg.model.capabilities:
|
||||
resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
|
||||
else:
|
||||
resps = [ resp[:, 0] for resp in resps ]
|
||||
resps = [ resp[:, 0] for resp in resps ]
|
||||
|
||||
if "nar" in cfg.model.capabilities:
|
||||
resps = engine( texts, proms, resps, sampling_temperature=0.2 )
|
||||
for i, o in enumerate(resps):
|
||||
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
|
||||
if "nar" in cfg.model.capabilities:
|
||||
resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 )
|
||||
|
||||
for i, o in enumerate(resps):
|
||||
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
|
||||
|
||||
unload_model()
|
||||
|
||||
|
@ -611,8 +651,11 @@ def example_usage():
|
|||
model = ml.compile_model(model, backend=cfg.optimizations.compile)
|
||||
"""
|
||||
|
||||
"""
|
||||
for task in available_tasks:
|
||||
sample("final", task=task)
|
||||
"""
|
||||
sample("final", task=available_tasks)
|
||||
|
||||
engines.quit()
|
||||
|
||||
|
|
|
@ -108,24 +108,35 @@ def run_eval(engines, eval_name, dl):
|
|||
for name in engines:
|
||||
engine = engines[name]
|
||||
|
||||
# to-do: eval for text tasks
|
||||
for i, task in batch["task"]:
|
||||
if task == "stt":
|
||||
batch["task"][i] = "tts"
|
||||
|
||||
kwargs = dict(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
lang_list=batch["lang"],
|
||||
task_list=batch["task"],
|
||||
)
|
||||
|
||||
if engine.hyper_config.experimental.hf:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] )
|
||||
resps_list = engine( **kwargs )
|
||||
elif "len" in engine.hyper_config.capabilities:
|
||||
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
|
||||
len_list = engine( **kwargs, max_steps=10 ) # don't need more than that
|
||||
len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
|
||||
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
||||
resps_list = engine( **kwargs, len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
||||
else:
|
||||
if "ar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
resps_list = engine( **kwargs, max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
else:
|
||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||
|
||||
if "nar" in engine.hyper_config.capabilities:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
||||
resps_list = engine( **kwargs, resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
||||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
engines_stats = {
|
||||
f'{name}.{eval_name}': stats,
|
||||
|
|
Loading…
Reference in New Issue
Block a user