(finally) added parallel AR for cfg.model.version >= 7 (nvidia/audio-codec-44khz is being a pain and it might require training purely AR first......)
This commit is contained in:
parent
15b3c20e19
commit
67a6009555
|
@ -757,21 +757,20 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
|
||||||
if len(metadata) == 0:
|
if len(metadata) == 0:
|
||||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
|
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
|
||||||
|
|
||||||
|
# this might be slow
|
||||||
|
def _exists( id, entry ):
|
||||||
|
if not cfg.dataset.strict_validate:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if cfg.dataset.use_hdf5:
|
||||||
|
return key(id, entry) in cfg.hdf5
|
||||||
|
|
||||||
|
return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
|
||||||
|
|
||||||
def _validate( id, entry ):
|
def _validate( id, entry ):
|
||||||
phones = entry['phones'] if "phones" in entry else 0
|
phones = entry['phones'] if "phones" in entry else 0
|
||||||
duration = entry['duration'] if "duration" in entry else 0
|
duration = entry['duration'] if "duration" in entry else 0
|
||||||
|
|
||||||
k = key(id, entry)
|
|
||||||
|
|
||||||
# double check if in HDF5
|
|
||||||
# this might be slow
|
|
||||||
if cfg.dataset.strict_validate:
|
|
||||||
if cfg.dataset.use_hdf5:
|
|
||||||
if k not in cfg.hdf5:
|
|
||||||
return False
|
|
||||||
elif not (data_dir / id).with_suffix(_get_artifact_extension()).exists():
|
|
||||||
return False
|
|
||||||
|
|
||||||
# add to duration bucket
|
# add to duration bucket
|
||||||
if type not in _durations_map:
|
if type not in _durations_map:
|
||||||
_durations_map[type] = {}
|
_durations_map[type] = {}
|
||||||
|
@ -780,7 +779,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
|
||||||
if not validate:
|
if not validate:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
|
in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
|
||||||
|
if in_bounds and not _exists( id, entry ):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return in_bounds
|
||||||
|
|
||||||
return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
|
return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
|
||||||
|
|
||||||
|
|
|
@ -1042,6 +1042,158 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
return sequence_list
|
return sequence_list
|
||||||
|
|
||||||
|
def forward_ar_parallel(
|
||||||
|
self,
|
||||||
|
|
||||||
|
task_list: list[Tensor],
|
||||||
|
|
||||||
|
text_list: list[Tensor] | None = None,
|
||||||
|
raw_text_list: list[Tensor] | None = None,
|
||||||
|
proms_list: list[Tensor] | None = None,
|
||||||
|
resps_list: list[Tensor] | None = None,
|
||||||
|
lang_list: list[Tensor] | None = None,
|
||||||
|
tone_list: list[Tensor] | None = None,
|
||||||
|
len_list: list[Tensor] | None = None,
|
||||||
|
|
||||||
|
disable_tqdm=False,
|
||||||
|
use_lora=None,
|
||||||
|
**sampling_kwargs,
|
||||||
|
):
|
||||||
|
# deduce batch_size
|
||||||
|
if text_list:
|
||||||
|
device = text_list[0].device
|
||||||
|
batch_size = len(text_list)
|
||||||
|
elif raw_text_list:
|
||||||
|
device = raw_text_list[0].device
|
||||||
|
batch_size = len(raw_text_list)
|
||||||
|
elif proms_list:
|
||||||
|
device = proms_list[0].device
|
||||||
|
batch_size = len(proms_list)
|
||||||
|
elif resps_list:
|
||||||
|
device = resps_list[0].device
|
||||||
|
batch_size = len(resps_list)
|
||||||
|
|
||||||
|
if cfg.lora is not None:
|
||||||
|
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
|
||||||
|
|
||||||
|
# convert AR specific args
|
||||||
|
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
|
||||||
|
|
||||||
|
temperature = sampling_kwargs.get("temperature", 1.0)
|
||||||
|
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
|
||||||
|
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
|
||||||
|
min_temperature = sampling_kwargs.get("min_temperature", -1.0)
|
||||||
|
max_duration = sampling_kwargs.get("max_duration", 500)
|
||||||
|
beam_width = sampling_kwargs.get("beam_width", 0)
|
||||||
|
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
|
||||||
|
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
|
||||||
|
input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
|
||||||
|
layer_skip = sampling_kwargs.get("layer_skip", False)
|
||||||
|
prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
|
||||||
|
mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
|
||||||
|
mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
|
||||||
|
|
||||||
|
start_slice = [ 0 for _ in range(batch_size) ]
|
||||||
|
sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ]
|
||||||
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
|
audio_stop_token = self.stop_token
|
||||||
|
text_stop_token = 2
|
||||||
|
|
||||||
|
state = None
|
||||||
|
mirostat = [
|
||||||
|
{"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
|
||||||
|
] * batch_size if mirostat_tau > 0.0 else None
|
||||||
|
|
||||||
|
scores = [ 1.0 ] * beam_width
|
||||||
|
metrics = []
|
||||||
|
|
||||||
|
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
|
||||||
|
null_prom = [ None for _ in range(batch_size) ]
|
||||||
|
|
||||||
|
# get next in sequence
|
||||||
|
iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
|
||||||
|
for n in iterator:
|
||||||
|
if raw_text_list is not None:
|
||||||
|
raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
|
||||||
|
else:
|
||||||
|
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
|
||||||
|
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
|
||||||
|
|
||||||
|
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
|
||||||
|
|
||||||
|
inputs = self.inputs(
|
||||||
|
task_list=task_list,
|
||||||
|
|
||||||
|
text_list=text_list,
|
||||||
|
proms_list=proms_list,
|
||||||
|
resps_list=resps_list,
|
||||||
|
|
||||||
|
lang_list=lang_list,
|
||||||
|
tone_list=tone_list,
|
||||||
|
len_list=len_list,
|
||||||
|
raw_text_list=raw_text_list,
|
||||||
|
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
)
|
||||||
|
|
||||||
|
# to-do: find an elegant way to write this
|
||||||
|
output = super().forward(
|
||||||
|
inputs=inputs,
|
||||||
|
state=state,
|
||||||
|
#layer_skip_variables=sampling_layer_skip_variables,
|
||||||
|
output_attentions=entropix_sampling,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg_strength > 0:
|
||||||
|
null_inputs = super().inputs(
|
||||||
|
text_list=null_text,
|
||||||
|
proms_list=null_prom,
|
||||||
|
resps_list=resps_list,
|
||||||
|
lang_list=lang_list,
|
||||||
|
tone_list=tone_list,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
)
|
||||||
|
null_output = super().forward(
|
||||||
|
inputs=null_inputs,
|
||||||
|
quant_levels=quant_levels,
|
||||||
|
#layer_skip_variables=sampling_layer_skip_variables,
|
||||||
|
)
|
||||||
|
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )
|
||||||
|
|
||||||
|
logits, state = output.logits, output.state
|
||||||
|
|
||||||
|
l_resps_list = [ [] for _ in range(batch_size) ]
|
||||||
|
for l in range(self.n_resp_levels):
|
||||||
|
sampled = super().sample(
|
||||||
|
logits=[ logit[l] for logit in logits ],
|
||||||
|
prev_list=[ resp[..., l] for resp in resps_list ],
|
||||||
|
**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
|
||||||
|
)
|
||||||
|
|
||||||
|
ids = sampled.ids
|
||||||
|
|
||||||
|
# append tokens
|
||||||
|
for i, token in enumerate(ids):
|
||||||
|
if audio_stop_token in token:
|
||||||
|
stopped[i] = True
|
||||||
|
l_resps_list[i].append(token.to(device))
|
||||||
|
|
||||||
|
for i, l in enumerate(l_resps_list):
|
||||||
|
sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)])
|
||||||
|
|
||||||
|
# stop token found
|
||||||
|
# stopped |= r == stop_token
|
||||||
|
if stopped.all().item():
|
||||||
|
iterator.close()
|
||||||
|
break
|
||||||
|
|
||||||
|
for i, l in enumerate( sequence_list ):
|
||||||
|
index = (l == audio_stop_token).nonzero()[:, 0].min()
|
||||||
|
sequence_list[i] = sequence_list[i][:index]
|
||||||
|
|
||||||
|
return sequence_list
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
task_list: list[Tensor] | None = None,
|
task_list: list[Tensor] | None = None,
|
||||||
|
@ -1169,6 +1321,25 @@ class AR_NAR(Base):
|
||||||
**sampling_kwargs,
|
**sampling_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.version >= 7:
|
||||||
|
if task_list is None or task_list[0] != "len":
|
||||||
|
return self.forward_ar_parallel(
|
||||||
|
task_list=task_list,
|
||||||
|
|
||||||
|
text_list=text_list,
|
||||||
|
proms_list=proms_list,
|
||||||
|
resps_list=resps_list,
|
||||||
|
|
||||||
|
lang_list=lang_list,
|
||||||
|
tone_list=tone_list,
|
||||||
|
len_list=len_list,
|
||||||
|
raw_text_list=raw_text_list,
|
||||||
|
|
||||||
|
disable_tqdm=disable_tqdm,
|
||||||
|
use_lora=use_lora,
|
||||||
|
**sampling_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# is AR
|
# is AR
|
||||||
return self.forward_ar(
|
return self.forward_ar(
|
||||||
task_list=task_list,
|
task_list=task_list,
|
||||||
|
@ -1407,7 +1578,8 @@ def example_usage():
|
||||||
resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
|
resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
|
||||||
else:
|
else:
|
||||||
resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
|
resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
|
||||||
resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
|
if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1:
|
||||||
|
resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
|
||||||
|
|
||||||
for i, o in enumerate(resps_list):
|
for i, o in enumerate(resps_list):
|
||||||
print( o.shape, o )
|
print( o.shape, o )
|
||||||
|
@ -1444,7 +1616,7 @@ def example_usage():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for task in available_tasks:
|
for task in available_tasks:
|
||||||
sample("final", task="tts-nar")
|
sample("final", task=task)
|
||||||
|
|
||||||
engines.quit()
|
engines.quit()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user