cleanup for AR_NAR inferencing to allow both TTS and STT tasks simultaneously (need to have training eval do this too, though)
This commit is contained in:
parent
341e19162b
commit
d33a906119
|
@ -76,8 +76,15 @@ class AR_NAR(Base):
|
||||||
if task_list is None:
|
if task_list is None:
|
||||||
task_list = [ default_task for _ in range(batch_size) ]
|
task_list = [ default_task for _ in range(batch_size) ]
|
||||||
|
|
||||||
|
has_none = resps_list is None or text_list is None
|
||||||
|
if not has_none:
|
||||||
|
for i, task in enumerate( task_list ):
|
||||||
|
if resps_list[i] is None or text_list[i] is None:
|
||||||
|
has_none = True
|
||||||
|
break
|
||||||
|
|
||||||
# is training or NAR
|
# is training or NAR
|
||||||
if resps_list is not None and text_list is not None:
|
if not has_none:
|
||||||
n_levels_set = {r.shape[-1] for r in resps_list}
|
n_levels_set = {r.shape[-1] for r in resps_list}
|
||||||
n_levels = next(iter(n_levels_set))
|
n_levels = next(iter(n_levels_set))
|
||||||
|
|
||||||
|
@ -241,7 +248,8 @@ class AR_NAR(Base):
|
||||||
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
|
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
|
||||||
stopped = torch.zeros(batch_size, device=device).bool()
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
stop_token = self.stop_token if task_list[0] != "stt" else 2 # to-do: derive from tokenizer
|
audio_stop_token = self.stop_token
|
||||||
|
text_stop_token = 2
|
||||||
|
|
||||||
state = None
|
state = None
|
||||||
mirostat = [
|
mirostat = [
|
||||||
|
@ -257,10 +265,15 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
# get next in sequence
|
# get next in sequence
|
||||||
for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
|
for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
|
||||||
if task_list[0] in text_task:
|
#
|
||||||
text_list = [x for x in sequence_list]
|
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
|
||||||
else:
|
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
|
||||||
resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
|
|
||||||
|
"""
|
||||||
|
print( "task_list:", task_list )
|
||||||
|
print( "text_list:", text_list )
|
||||||
|
print( "resps_list:", resps_list )
|
||||||
|
"""
|
||||||
|
|
||||||
inputs = self.inputs(
|
inputs = self.inputs(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
|
@ -286,7 +299,7 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
r = super().sample(
|
r = super().sample(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
prev_list=resps_list,
|
prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
|
||||||
|
|
||||||
temperature=sampling_temperature,
|
temperature=sampling_temperature,
|
||||||
min_temperature=sampling_min_temperature,
|
min_temperature=sampling_min_temperature,
|
||||||
|
@ -325,12 +338,14 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
# append tokens
|
# append tokens
|
||||||
for i, ri in enumerate(r):
|
for i, ri in enumerate(r):
|
||||||
|
task = task_list[i]
|
||||||
|
stop_token = audio_stop_token if task not in text_task else text_stop_token
|
||||||
if stop_token in ri:
|
if stop_token in ri:
|
||||||
stopped[i] = True
|
stopped[i] = True
|
||||||
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
|
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
|
||||||
|
|
||||||
# stop token found
|
# stop token found
|
||||||
stopped |= r == stop_token
|
# stopped |= r == stop_token
|
||||||
if stopped.all().item():
|
if stopped.all().item():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -339,7 +354,10 @@ class AR_NAR(Base):
|
||||||
if sampling_beam_width:
|
if sampling_beam_width:
|
||||||
sequence_list = [ sequence_list[0] ]
|
sequence_list = [ sequence_list[0] ]
|
||||||
|
|
||||||
sequence_list = [self._prune(r, stop_token) for r in sequence_list]
|
# remove stop token
|
||||||
|
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
|
||||||
|
# remove <bos>
|
||||||
|
sequence_list = [ sequence_list[i] if task not in text_task else sequence_list[i][1:] for i, task in enumerate( task_list ) ]
|
||||||
return sequence_list
|
return sequence_list
|
||||||
|
|
||||||
|
|
||||||
|
@ -426,7 +444,8 @@ def example_usage():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
|
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
|
||||||
available_tasks = cfg.dataset.tasks_list
|
#available_tasks = cfg.dataset.tasks_list
|
||||||
|
available_tasks = ["tts", "stt"]
|
||||||
|
|
||||||
model = AR_NAR(**kwargs).to(device)
|
model = AR_NAR(**kwargs).to(device)
|
||||||
steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
|
steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
|
||||||
|
@ -515,6 +534,14 @@ def example_usage():
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_data(t=None):
|
def sample_data(t=None):
|
||||||
|
if isinstance(t, list):
|
||||||
|
tasks = t
|
||||||
|
texts = [ text_list[0].to(device) if task != "stt" else None for i, task in enumerate( tasks ) ]
|
||||||
|
proms = [ proms_list[0].to(device) if task != "stt" else [ "stt" ] for i, task in enumerate( tasks ) ]
|
||||||
|
resps = [ None if task != "stt" else resps_list[0].to(device) for i, task in enumerate( tasks ) ]
|
||||||
|
|
||||||
|
return texts, proms, resps, tasks
|
||||||
|
|
||||||
texts = []
|
texts = []
|
||||||
proms = []
|
proms = []
|
||||||
resps = []
|
resps = []
|
||||||
|
@ -523,25 +550,32 @@ def example_usage():
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
task = random.choice(available_tasks) if t is None else t
|
task = random.choice(available_tasks) if t is None else t
|
||||||
|
|
||||||
text = text_list[i]
|
text = text_list[i].to(device)
|
||||||
prom = proms_list[i]
|
prom = proms_list[i].to(device)
|
||||||
resp = resps_list[i]
|
resp = resps_list[i].to(device)
|
||||||
|
|
||||||
# do nothing
|
# do nothing
|
||||||
if task == "tts":
|
if task == "tts":
|
||||||
...
|
...
|
||||||
elif task == "stt":
|
elif task == "stt":
|
||||||
...
|
prom = [
|
||||||
|
task
|
||||||
|
]
|
||||||
|
# to-do: reimplement this from data.py
|
||||||
|
"""
|
||||||
elif task == "tts-c":
|
elif task == "tts-c":
|
||||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
||||||
|
|
||||||
prom = resp[:trim_length]
|
prom = resp[:trim_length]
|
||||||
resp = resp[trim_length:]
|
resp = resp[trim_length:]
|
||||||
|
|
||||||
|
prom = prom.to(device)
|
||||||
elif task == "ns" or task == "sr":
|
elif task == "ns" or task == "sr":
|
||||||
# extend the noise to fill the target audio
|
# extend the noise to fill the target audio
|
||||||
noise_ext = repeat_extend_audio( noise, resp.shape[0] )
|
noise_ext = repeat_extend_audio( noise, resp.shape[0] )
|
||||||
# create the input prompt by merging the target audio with the noise
|
# create the input prompt by merging the target audio with the noise
|
||||||
prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
|
prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
|
||||||
|
prom = prom.to(device)
|
||||||
# set the target to just be the noise if <sr>
|
# set the target to just be the noise if <sr>
|
||||||
if task == "sr":
|
if task == "sr":
|
||||||
resp = noise_ext
|
resp = noise_ext
|
||||||
|
@ -550,9 +584,15 @@ def example_usage():
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
|
text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
|
||||||
|
|
||||||
texts.append( text.to(device) )
|
prom = [
|
||||||
proms.append( prom.to(device) )
|
task,
|
||||||
resps.append( resp.to(device) )
|
prom,
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
texts.append( text )
|
||||||
|
proms.append( prom )
|
||||||
|
resps.append( resp )
|
||||||
tasks.append( task )
|
tasks.append( task )
|
||||||
|
|
||||||
return texts, proms, resps, tasks
|
return texts, proms, resps, tasks
|
||||||
|
@ -563,23 +603,23 @@ def example_usage():
|
||||||
|
|
||||||
texts, proms, resps, tasks = sample_data( task )
|
texts, proms, resps, tasks = sample_data( task )
|
||||||
|
|
||||||
if tasks[0] == "stt":
|
|
||||||
text = engine( None, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
|
|
||||||
"""
|
|
||||||
# to-do: STT for NAR
|
|
||||||
text = engine( text, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
|
|
||||||
"""
|
|
||||||
text = [ cfg.tokenizer.decode( t ) for t in text ]
|
|
||||||
|
|
||||||
print( text )
|
|
||||||
else:
|
|
||||||
if "ar" in cfg.model.capabilities:
|
if "ar" in cfg.model.capabilities:
|
||||||
resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
|
output = engine( texts, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
|
||||||
|
|
||||||
|
text = [ cfg.tokenizer.decode( output[i] ) for i, task in enumerate( tasks ) if task == "stt" ]
|
||||||
|
|
||||||
|
texts = [ texts[i] for i, task in enumerate( tasks ) if task != "stt" ]
|
||||||
|
proms = [ proms[i] for i, task in enumerate( tasks ) if task != "stt" ]
|
||||||
|
resps = [ output[i] for i, task in enumerate( tasks ) if task != "stt" ]
|
||||||
|
tasks = [ tasks[i] for i, task in enumerate( tasks ) if task != "stt" ]
|
||||||
|
|
||||||
|
print( "STT:", text )
|
||||||
else:
|
else:
|
||||||
resps = [ resp[:, 0] for resp in resps ]
|
resps = [ resp[:, 0] for resp in resps ]
|
||||||
|
|
||||||
if "nar" in cfg.model.capabilities:
|
if "nar" in cfg.model.capabilities:
|
||||||
resps = engine( texts, proms, resps, sampling_temperature=0.2 )
|
resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 )
|
||||||
|
|
||||||
for i, o in enumerate(resps):
|
for i, o in enumerate(resps):
|
||||||
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
|
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
|
||||||
|
|
||||||
|
@ -611,8 +651,11 @@ def example_usage():
|
||||||
model = ml.compile_model(model, backend=cfg.optimizations.compile)
|
model = ml.compile_model(model, backend=cfg.optimizations.compile)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
for task in available_tasks:
|
for task in available_tasks:
|
||||||
sample("final", task=task)
|
sample("final", task=task)
|
||||||
|
"""
|
||||||
|
sample("final", task=available_tasks)
|
||||||
|
|
||||||
engines.quit()
|
engines.quit()
|
||||||
|
|
||||||
|
|
|
@ -108,24 +108,35 @@ def run_eval(engines, eval_name, dl):
|
||||||
for name in engines:
|
for name in engines:
|
||||||
engine = engines[name]
|
engine = engines[name]
|
||||||
|
|
||||||
|
# to-do: eval for text tasks
|
||||||
|
for i, task in batch["task"]:
|
||||||
|
if task == "stt":
|
||||||
|
batch["task"][i] = "tts"
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
text_list=batch["text"],
|
||||||
|
prom_list=batch["proms"],
|
||||||
|
lang_list=batch["lang"],
|
||||||
|
task_list=batch["task"],
|
||||||
|
)
|
||||||
|
|
||||||
if engine.hyper_config.experimental.hf:
|
if engine.hyper_config.experimental.hf:
|
||||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] )
|
resps_list = engine( **kwargs )
|
||||||
elif "len" in engine.hyper_config.capabilities:
|
elif "len" in engine.hyper_config.capabilities:
|
||||||
len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that
|
len_list = engine( **kwargs, max_steps=10 ) # don't need more than that
|
||||||
len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
|
len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ]
|
||||||
resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
resps_list = engine( **kwargs, len_list=len_list, max_levels=cfg.evaluation.nar_levels )
|
||||||
else:
|
else:
|
||||||
if "ar" in engine.hyper_config.capabilities:
|
if "ar" in engine.hyper_config.capabilities:
|
||||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
resps_list = engine( **kwargs, max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||||
else:
|
else:
|
||||||
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
resps_list = [ resp[:, 0] for resp in batch["resps"] ]
|
||||||
|
|
||||||
if "nar" in engine.hyper_config.capabilities:
|
if "nar" in engine.hyper_config.capabilities:
|
||||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
resps_list = engine( **kwargs, resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels )
|
||||||
|
|
||||||
process( name, batch, resps_list )
|
process( name, batch, resps_list )
|
||||||
|
|
||||||
|
|
||||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||||
engines_stats = {
|
engines_stats = {
|
||||||
f'{name}.{eval_name}': stats,
|
f'{name}.{eval_name}': stats,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user