(finally) added parallel AR for cfg.model.version >= 7 (nvidia/audio-codec-44khz is being a pain and it might require training purely AR first......)

mrq 2025-02-23 08:31:03 -06:00
parent 15b3c20e19
commit 67a6009555
2 changed files with 189 additions and 14 deletions
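
Roughly, the scheme this commit adds, as a minimal sketch with made-up names (the real implementation is forward_ar_parallel in the diff below): rather than decoding RVQ level 0 autoregressively and filling the remaining levels with NAR passes, each autoregressive step samples one token per codebook level and stacks them, so the sequence grows as a (t, 8) tensor:

    import torch

    n_levels   = 8     # RVQ codebook levels decoded in parallel
    n_vocab    = 1024  # dummy vocabulary size
    stop_token = 1024  # dummy stop-token id, one past the vocabulary
    max_steps  = 16

    sequence = torch.zeros((0, n_levels), dtype=torch.int16)
    for _ in range(max_steps):
        # stand-in for the model: one row of logits per codebook level
        logits = torch.randn(n_levels, n_vocab + 1)
        step = logits.argmax(dim=-1).to(torch.int16)          # (n_levels,)
        sequence = torch.cat([sequence, step.unsqueeze(0)])   # (t, n_levels)
        if (step == stop_token).any():
            break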


@@ -757,21 +757,20 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
     if len(metadata) == 0:
         return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )

+    # this might be slow
+    def _exists( id, entry ):
+        if not cfg.dataset.strict_validate:
+            return True
+
+        if cfg.dataset.use_hdf5:
+            return key(id, entry) in cfg.hdf5
+
+        return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
+
     def _validate( id, entry ):
         phones = entry['phones'] if "phones" in entry else 0
         duration = entry['duration'] if "duration" in entry else 0
-        k = key(id, entry)
-
-        # double check if in HDF5
-        # this might be slow
-        if cfg.dataset.strict_validate:
-            if cfg.dataset.use_hdf5:
-                if k not in cfg.hdf5:
-                    return False
-            elif not (data_dir / id).with_suffix(_get_artifact_extension()).exists():
-                return False

         # add to duration bucket
         if type not in _durations_map:
             _durations_map[type] = {}

@@ -780,7 +779,11 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
         if not validate:
             return True

-        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
+        in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
+        if in_bounds and not _exists( id, entry ):
+            return False
+
+        return in_bounds

     return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
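
The effect of the data-loader change above, restated as a standalone sketch (simplified; `exists` stands in for the HDF5/filesystem probe): the potentially slow existence check now runs only for entries whose duration already falls within the configured bounds, instead of for every metadata entry.

    def validate(entry, min_duration, max_duration, exists):
        # duration bounds are cheap to test, so test them first
        in_bounds = min_duration <= entry["duration"] <= max_duration
        # the expensive check (HDF5 key lookup or filesystem stat) is gated on it
        if in_bounds and not exists(entry):
            return False
        return in_bounds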


@@ -1042,6 +1042,158 @@ class AR_NAR(Base):
         return sequence_list

+    def forward_ar_parallel(
+        self,
+        task_list: list[Tensor],
+        text_list: list[Tensor] | None = None,
+        raw_text_list: list[Tensor] | None = None,
+        proms_list: list[Tensor] | None = None,
+        resps_list: list[Tensor] | None = None,
+        lang_list: list[Tensor] | None = None,
+        tone_list: list[Tensor] | None = None,
+        len_list: list[Tensor] | None = None,
+        disable_tqdm=False,
+        use_lora=None,
+        **sampling_kwargs,
+    ):
+        # deduce batch_size
+        if text_list:
+            device = text_list[0].device
+            batch_size = len(text_list)
+        elif raw_text_list:
+            device = raw_text_list[0].device
+            batch_size = len(raw_text_list)
+        elif proms_list:
+            device = proms_list[0].device
+            batch_size = len(proms_list)
+        elif resps_list:
+            device = resps_list[0].device
+            batch_size = len(resps_list)
+
+        if cfg.lora is not None:
+            enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
+
+        # convert AR specific args
+        sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
+
+        temperature = sampling_kwargs.get("temperature", 1.0)
+        cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+        cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
+        min_temperature = sampling_kwargs.get("min_temperature", -1.0)
+        max_duration = sampling_kwargs.get("max_duration", 500)
+        beam_width = sampling_kwargs.get("beam_width", 0)
+        entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
+        refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
+        input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
+        layer_skip = sampling_kwargs.get("layer_skip", False)
+        prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
+        mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
+        mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
+
+        start_slice = [ 0 for _ in range(batch_size) ]
+        sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ]
+        stopped = torch.zeros(batch_size, device=device).bool()
+
+        audio_stop_token = self.stop_token
+        text_stop_token = 2
+
+        state = None
+        mirostat = [
+            {"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
+        ] * batch_size if mirostat_tau > 0.0 else None
+
+        scores = [ 1.0 ] * beam_width
+        metrics = []
+
+        null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
+        null_prom = [ None for _ in range(batch_size) ]
+
+        # get next in sequence
+        iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
+        for n in iterator:
+            if raw_text_list is not None:
+                raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
+            else:
+                text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+            resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+
+            quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
+
+            inputs = self.inputs(
+                task_list=task_list,
+                text_list=text_list,
+                proms_list=proms_list,
+                resps_list=resps_list,
+                lang_list=lang_list,
+                tone_list=tone_list,
+                len_list=len_list,
+                raw_text_list=raw_text_list,
+                quant_levels=quant_levels,
+            )
+
+            # to-do: find an elegant way to write this
+            output = super().forward(
+                inputs=inputs,
+                state=state,
+                #layer_skip_variables=sampling_layer_skip_variables,
+                output_attentions=entropix_sampling,
+            )
+            # grab the raw logits/state first so the CFG-adjusted logits below are not clobbered
+            logits, state = output.logits, output.state
+
+            if cfg_strength > 0:
+                null_inputs = super().inputs(
+                    text_list=null_text,
+                    proms_list=null_prom,
+                    resps_list=resps_list,
+                    lang_list=lang_list,
+                    tone_list=tone_list,
+                    quant_levels=quant_levels,
+                )
+                null_output = super().forward(
+                    inputs=null_inputs,
+                    quant_levels=quant_levels,
+                    #layer_skip_variables=sampling_layer_skip_variables,
+                )
+                logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )
+
+            l_resps_list = [ [] for _ in range(batch_size) ]
+            for l in range(self.n_resp_levels):
+                sampled = super().sample(
+                    logits=[ logit[l] for logit in logits ],
+                    prev_list=[ resp[..., l] for resp in resps_list ],
+                    **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
+                )
+
+                ids = sampled.ids
+
+                # append tokens
+                for i, token in enumerate(ids):
+                    if audio_stop_token in token:
+                        stopped[i] = True
+                    l_resps_list[i].append(token.to(device))
+
+            for i, l in enumerate(l_resps_list):
+                sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)])
+
+            # stop token found
+            # stopped |= r == stop_token
+            if stopped.all().item():
+                iterator.close()
+                break
+
+        for i, l in enumerate( sequence_list ):
+            index = (l == audio_stop_token).nonzero()[:, 0].min()
+            sequence_list[i] = sequence_list[i][:index]
+
+        return sequence_list
+
     def forward(
         self,
         task_list: list[Tensor] | None = None,
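
One detail of forward_ar_parallel worth restating standalone: the final trim cuts each sequence at the earliest step where any level emitted the stop token. A minimal reproduction (note that `.min()` presumes a stop token was actually sampled; an empty match would raise):

    import torch

    stop_token = 1024
    seq = torch.tensor([
        [1, 2],           # step 0
        [3, 4],           # step 1
        [stop_token, 5],  # step 2: level 0 emits the stop token
        [6, 7],           # step 3
    ])
    index = (seq == stop_token).nonzero()[:, 0].min()  # earliest stopping step
    trimmed = seq[:index]                              # keeps steps 0 and 1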
@@ -1169,6 +1321,25 @@ class AR_NAR(Base):
                 **sampling_kwargs,
             )

+        if self.version >= 7:
+            if task_list is None or task_list[0] != "len":
+                return self.forward_ar_parallel(
+                    task_list=task_list,
+                    text_list=text_list,
+                    proms_list=proms_list,
+                    resps_list=resps_list,
+                    lang_list=lang_list,
+                    tone_list=tone_list,
+                    len_list=len_list,
+                    raw_text_list=raw_text_list,
+                    disable_tqdm=disable_tqdm,
+                    use_lora=use_lora,
+                    **sampling_kwargs,
+                )
+
         # is AR
         return self.forward_ar(
             task_list=task_list,
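
The new dispatch in forward(), reduced to its routing rule (a sketch, not the repo's API): version 7+ models send every task except duration inference ("len") through the parallel decoder, and everything else falls through to the old AR path.

    def route(version: int, task_list: list | None) -> str:
        # mirrors the dispatch above: v7+ uses the parallel decoder
        # for everything except "len" inference
        if version >= 7 and (task_list is None or task_list[0] != "len"):
            return "forward_ar_parallel"
        return "forward_ar"

    assert route(7, ["tts"]) == "forward_ar_parallel"
    assert route(7, ["len"]) == "forward_ar"
    assert route(6, ["tts"]) == "forward_ar"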
@@ -1407,7 +1578,8 @@ def example_usage():
             resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
         else:
             resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
-            resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
+            if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1:
+                resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )

         for i, o in enumerate(resps_list):
             print( o.shape, o )
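
My reading of the guard added above: the parallel-AR path already returns full (t, 8) sequences, so the second engine call (the NAR refinement pass) is only needed when the first pass produced a single RVQ level.

    import torch

    resp = torch.zeros((100, 8), dtype=torch.int16)  # parallel-AR output: all levels present
    needs_nar = resp.dim() == 1 or resp.shape[-1] == 1
    print(needs_nar)  # False, so the refinement call is skipped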
@@ -1444,7 +1616,7 @@ def example_usage():
     """

     for task in available_tasks:
-        sample("final", task="tts-nar")
+        sample("final", task=task)

     engines.quit()