(finally) added parallel AR for cfg.model.version >= 7 (nvidia/audio-codec-44khz is being a pain and it might require training purely AR first......)
parent 15b3c20e19
commit 67a6009555

@@ -757,21 +757,20 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
     if len(metadata) == 0:
         return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
 
+    # this might be slow
+    def _exists( id, entry ):
+        if not cfg.dataset.strict_validate:
+            return True
+
+        if cfg.dataset.use_hdf5:
+            return key(id, entry) in cfg.hdf5
+
+        return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
+
     def _validate( id, entry ):
         phones = entry['phones'] if "phones" in entry else 0
         duration = entry['duration'] if "duration" in entry else 0
-
         k = key(id, entry)
-
-        # double check if in HDF5
-        # this might be slow
-        if cfg.dataset.strict_validate:
-            if cfg.dataset.use_hdf5:
-                if k not in cfg.hdf5:
-                    return False
-            elif not (data_dir / id).with_suffix(_get_artifact_extension()).exists():
-                return False
-
         # add to duration bucket
         if type not in _durations_map:
             _durations_map[type] = {}

@@ -780,7 +779,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
         if not validate:
             return True
 
-        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
+        in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
+        if in_bounds and not _exists( id, entry ):
+            return False
+
+        return in_bounds
 
     return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
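
Net effect of the two hunks above: the cheap duration-bounds check runs first, and the slow existence check (a filesystem stat or an HDF5 key lookup) only runs for entries already in bounds, and only when cfg.dataset.strict_validate is set. A minimal standalone sketch of that filter pattern — filter_metadata, the .enc extension, and the dict shape are illustrative stand-ins, not the repo's API:

from pathlib import Path

def filter_metadata( metadata: dict, data_dir: Path, min_duration: float, max_duration: float, extension: str = ".enc" ):
    def _exists( id ):
        # the slow part: touches the filesystem (or an hdf5 handle in the real code)
        return (data_dir / id).with_suffix(extension).exists()

    def _validate( id, entry ):
        duration = entry.get("duration", 0)
        in_bounds = min_duration <= duration <= max_duration
        # cheap bounds check first; only in-bounds entries pay for the existence check
        return in_bounds and _exists( id )

    return [ id for id, entry in metadata.items() if _validate( id, entry ) ]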
@@ -1042,6 +1042,158 @@ class AR_NAR(Base):
 
         return sequence_list
 
+    def forward_ar_parallel(
+        self,
+
+        task_list: list[Tensor],
+
+        text_list: list[Tensor] | None = None,
+        raw_text_list: list[Tensor] | None = None,
+        proms_list: list[Tensor] | None = None,
+        resps_list: list[Tensor] | None = None,
+        lang_list: list[Tensor] | None = None,
+        tone_list: list[Tensor] | None = None,
+        len_list: list[Tensor] | None = None,
+
+        disable_tqdm=False,
+        use_lora=None,
+        **sampling_kwargs,
+    ):
+        # deduce batch_size
+        if text_list:
+            device = text_list[0].device
+            batch_size = len(text_list)
+        elif raw_text_list:
+            device = raw_text_list[0].device
+            batch_size = len(raw_text_list)
+        elif proms_list:
+            device = proms_list[0].device
+            batch_size = len(proms_list)
+        elif resps_list:
+            device = resps_list[0].device
+            batch_size = len(resps_list)
+
+        if cfg.lora is not None:
+            enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
+
+        # convert AR specific args
+        sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
+
+        temperature = sampling_kwargs.get("temperature", 1.0)
+        cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+        cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
+        min_temperature = sampling_kwargs.get("min_temperature", -1.0)
+        max_duration = sampling_kwargs.get("max_duration", 500)
+        beam_width = sampling_kwargs.get("beam_width", 0)
+        entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
+        refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
+        input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
+        layer_skip = sampling_kwargs.get("layer_skip", False)
+        prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
+        mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
+        mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
+
+        start_slice = [ 0 for _ in range(batch_size) ]
+        sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ]
+        stopped = torch.zeros(batch_size, device=device).bool()
+
+        audio_stop_token = self.stop_token
+        text_stop_token = 2
+
+        state = None
+        # one state dict per sample; [ {...} ] * batch_size would alias a single dict
+        mirostat = [
+            {"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
+            for _ in range(batch_size)
+        ] if mirostat_tau > 0.0 else None
+
+        scores = [ 1.0 ] * beam_width
+        metrics = []
+
+        null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+        null_prom = [ None for _ in range(batch_size) ]
+
+        # get next in sequence
+        iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
+        for n in iterator:
+            if raw_text_list is not None:
+                raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
+            else:
+                text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+            resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+
+            quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
+
+            inputs = self.inputs(
+                task_list=task_list,
+
+                text_list=text_list,
+                proms_list=proms_list,
+                resps_list=resps_list,
+
+                lang_list=lang_list,
+                tone_list=tone_list,
+                len_list=len_list,
+                raw_text_list=raw_text_list,
+
+                quant_levels=quant_levels,
+            )
+
+            # to-do: find an elegant way to write this
+            output = super().forward(
+                inputs=inputs,
+                state=state,
+                #layer_skip_variables=sampling_layer_skip_variables,
+                output_attentions=entropix_sampling,
+            )
+
+            logits, state = output.logits, output.state
+
+            if cfg_strength > 0:
+                null_inputs = super().inputs(
+                    text_list=null_text,
+                    proms_list=null_prom,
+                    resps_list=resps_list,
+                    lang_list=lang_list,
+                    tone_list=tone_list,
+                    quant_levels=quant_levels,
+                )
+                null_output = super().forward(
+                    inputs=null_inputs,
+                    quant_levels=quant_levels,
+                    #layer_skip_variables=sampling_layer_skip_variables,
+                )
+                # override the raw logits with the guided mix (output.state is read above so it isn't lost)
+                logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )
+
+            # sample every codebook level for this timestep
+            l_resps_list = [ [] for _ in range(batch_size) ]
+            for l in range(self.n_resp_levels):
+                sampled = super().sample(
+                    logits=[ logit[l] for logit in logits ],
+                    prev_list=[ resp[..., l] for resp in resps_list ],
+                    **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
+                )
+
+                ids = sampled.ids
+
+                # append tokens
+                for i, token in enumerate(ids):
+                    if audio_stop_token in token:
+                        stopped[i] = True
+                    l_resps_list[i].append(token.to(device))
+
+            for i, l in enumerate(l_resps_list):
+                sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)])
+
+            # stop token found
+            # stopped |= r == stop_token
+            if stopped.all().item():
+                iterator.close()
+                break
+
+        for i, l in enumerate( sequence_list ):
+            # guard: the loop may exhaust max_duration without emitting a stop token
+            index = (l == audio_stop_token).nonzero()
+            if index.numel() > 0:
+                sequence_list[i] = sequence_list[i][:index[:, 0].min()]
+
+        return sequence_list
+
     def forward(
         self,
         task_list: list[Tensor] | None = None,
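
For reference, convert_kwargs( sampling_kwargs, "ar_" ) lets callers namespace sampler settings per pass (ar_temperature, nar_temperature, ...). A sketch of the assumed behavior — the repo's helper may treat key collisions differently:

def convert_kwargs( kwargs: dict, prefix: str ) -> dict:
    # unprefixed keys pass through; "ar_temperature" becomes "temperature", overriding any clash
    out = { k: v for k, v in kwargs.items() if not k.startswith(prefix) }
    out.update({ k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix) })
    return out

assert convert_kwargs({ "ar_temperature": 0.9, "top_k": 5 }, "ar_") == { "temperature": 0.9, "top_k": 5 }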
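The core difference from forward_ar: each step samples a token for every codebook level at once and appends a whole (1, n_levels) frame, instead of decoding one RVQ level at a time. A toy sketch of that loop — step_logits_fn, the greedy argmax, and the defaults are illustrative only (the method above samples per level with full sampling kwargs):

import torch

def parallel_ar_decode( step_logits_fn, n_levels=8, stop_token=1024, max_duration=500 ):
    frames = torch.zeros((0, n_levels), dtype=torch.int16)
    for _ in range(max_duration):
        logits = step_logits_fn( frames )       # (n_levels, vocab) for the next timestep
        tokens = logits.argmax(dim=-1)          # greedy here, for brevity
        frames = torch.cat([frames, tokens.to(torch.int16).unsqueeze(0)])
        if (tokens == stop_token).any():
            break
    # trim at the earliest frame containing a stop token, if any was emitted
    hits = (frames == stop_token).nonzero()
    return frames[:hits[:, 0].min()] if hits.numel() else frames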
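The cfg_strength path runs a second, null-conditioned forward pass and mixes it into the conditional logits (classifier-free guidance), with cfg_rescale damping the over-sharpening the mix causes. A hedged sketch of one common formulation — the repo's cfg_logits also takes per-sequence lens, and some implementations use (1 + strength) as the multiplier:

import torch

def cfg_mix( cond: torch.Tensor, null: torch.Tensor, strength: float, rescale: float = 0.7 ):
    # plain CFG: push the conditional logits away from the unconditional ones
    mixed = null + (cond - null) * strength
    if rescale > 0:
        # pull the mixed logits' spread back toward the conditional distribution's
        ratio = cond.std() / mixed.std().clamp(min=1e-6)
        mixed = rescale * (mixed * ratio) + (1.0 - rescale) * mixed
    return mixed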
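The mirostat dicts are mutable per-sample sampler state (hence the note about not aliasing a single dict). A mirostat-v2-style step that reads and updates those exact keys, assuming the sampler follows the usual update rule — the repo's implementation may differ:

import torch

def mirostat_step( logits: torch.Tensor, state: dict ) -> int:
    probs = torch.softmax(logits, dim=-1)
    surprises = -torch.log2(probs)
    # keep only candidates no more surprising than the current ceiling
    allowed = surprises <= state["max_surprise"]
    if allowed.any():
        probs = probs * allowed
    token = int(torch.multinomial(probs / probs.sum(), 1))
    observed = float(surprises[token])
    # nudge the ceiling toward the target surprise tau
    state["error_surprise"] = observed - state["tau"]
    state["running_total_surprise"] += observed
    state["max_surprise"] -= state["eta"] * state["error_surprise"]
    return token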
@@ -1169,6 +1321,25 @@ class AR_NAR(Base):
             **sampling_kwargs,
         )
 
+        if self.version >= 7:
+            if task_list is None or task_list[0] != "len":
+                return self.forward_ar_parallel(
+                    task_list=task_list,
+
+                    text_list=text_list,
+                    proms_list=proms_list,
+                    resps_list=resps_list,
+
+                    lang_list=lang_list,
+                    tone_list=tone_list,
+                    len_list=len_list,
+                    raw_text_list=raw_text_list,
+
+                    disable_tqdm=disable_tqdm,
+                    use_lora=use_lora,
+                    **sampling_kwargs,
+                )
+
         # is AR
         return self.forward_ar(
             task_list=task_list,
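
Condensed, the dispatch this hunk adds: the "len" (duration-prediction) task keeps its old path, and everything else on version >= 7 models goes to the parallel decoder, with the earlier NAR branches in forward() still taking precedence:

def route( version: int, task: str ) -> str:
    if version >= 7 and task != "len":
        return "forward_ar_parallel"
    return "forward_ar"  # "len" (and the NAR branches handled before this check) keep the old paths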
@@ -1407,7 +1578,8 @@ def example_usage():
             resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
         else:
             resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
-            resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
+            if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1:
+                resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
 
         for i, o in enumerate(resps_list):
             print( o.shape, o )
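
Why the new dim()/shape[-1] guard: the sequential AR emits one level per step, so its output is (T,) (or (T, 1)) and still needs the follow-up refinement call; the parallel AR already returns every level as (T, 8) and skips it. For example:

import torch

seq_ar = torch.zeros(120, dtype=torch.int16)      # sequential AR output: (T,)
par_ar = torch.zeros(120, 8, dtype=torch.int16)   # parallel AR output: (T, 8)
for o in (seq_ar, par_ar):
    needs_refine = o.dim() == 1 or o.shape[-1] == 1
    print( o.shape, "needs the refinement pass" if needs_refine else "already full-depth" )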
@@ -1444,7 +1616,7 @@ def example_usage():
     """
 
     for task in available_tasks:
-        sample("final", task="tts-nar")
+        sample("final", task=task)
 
     engines.quit()