unified more things with training the AR+NAR monolithic model

mrq 2023-09-12 15:54:41 -05:00
parent 40ef34e1ca
commit d07c63b9d8
4 changed files with 81 additions and 85 deletions

View File

@@ -131,6 +131,7 @@ class Dataset:
phones_range: list[int] = field(default_factory=lambda: [4, 256])
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
min_utterances: int = 0
random_utterance: float = 1.0
max_prompts: int = 3
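For orientation, here is a minimal sketch of what these dataset fields control, written as a standalone dataclass mirroring the hunk above; the class name, the example value, and the inline comments are my reading of the knobs, not the repo's documented config plumbing:

from dataclasses import dataclass, field

@dataclass
class DatasetConfig:
    phones_range: list[int] = field(default_factory=lambda: [4, 256])        # keep utterances with 4..256 phonemes
    duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # keep utterances of 1..12 seconds
    min_utterances: int = 0        # speakers with fewer utterances get culled (0 disables culling, see the dataset hunk below)
    random_utterance: float = 1.0  # chance of drawing a different utterance from the same speaker as the prompt
    max_prompts: int = 3           # cap on how many prompts get stitched together

cfg_dataset = DatasetConfig(min_utterances=32)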

View File

@@ -59,24 +59,30 @@ def _get_quant_path(path):
def _get_phone_path(path):
return _replace_file_extension(path, ".phn.txt")
_total_durations = {}
@cfg.diskcache()
def _calculate_durations( type="training" ):
if type in _total_durations:
return _total_durations[type]
return 0
@cfg.diskcache()
def _load_paths(dataset, type="training"):
return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
"""
def _load_paths_from_hdf5(dataset, type="training"):
return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
def _load_paths_from_disk(dataset, type="training"):
return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
"""
def _load_paths_from_metadata(data_dir, type="training", validate=False):
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
def _validate( entry ):
if "phones" not in entry or "duration" not in entry:
return False
phones = entry['phones']
duration = entry['duration']
if type not in _total_durations:
_total_durations[type] = 0
_total_durations[type] += entry['duration']
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
metadata_path = data_dir / "metadata.json"
@@ -107,6 +113,9 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
def _validate(child):
phones = child.attrs['phonemes']
duration = child.attrs['duration']
if type not in _total_durations:
_total_durations[type] = 0
_total_durations[type] += duration
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
key = f"/{type}{_get_hdf5_path(data_dir)}"
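The duration bookkeeping in the two hunks above works by side effect: every _validate call adds the entry's duration to the module-level _total_durations dict keyed by dataset type, and _calculate_durations, wrapped in cfg.diskcache(), just reads that total back so it persists after the first full parse (the Dataset later fetches it via _calculate_durations(self.dataset_type), as in the hunk further down). A self-contained sketch of the same pattern, with the disk cache and config thresholds stubbed out and made-up values:

_total_durations = {}

def _calculate_durations(type="training"):
    # stand-in for the @cfg.diskcache()-wrapped helper: report whatever the parse accumulated
    return _total_durations.get(type, 0)

def _validate(entry, type="training", min_duration=1.0, max_duration=12.0):
    duration = entry["duration"]
    # accumulate the total as a side effect, even for entries that end up rejected
    _total_durations[type] = _total_durations.get(type, 0) + duration
    return min_duration <= duration <= max_duration

entries = [{"duration": 3.5}, {"duration": 20.0}, {"duration": 6.0}]
kept = [e for e in entries if _validate(e)]
print(len(kept), _calculate_durations())  # 2 29.5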
@@ -172,6 +181,14 @@ class Dataset(_Dataset):
self.dataset = cfg.dataset.training
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
# cull speakers if they do not have enough utterances
if cfg.dataset.min_utterances > 0:
keys = list(self.paths_by_spkr_name.keys())
for key in keys:
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
del self.paths_by_spkr_name[key]
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
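A toy run of the culling loop above, with a made-up paths_by_spkr_name mapping, to show what min_utterances does:

min_utterances = 2

paths_by_spkr_name = {
    "speaker_a": ["a_0.qnt.pt", "a_1.qnt.pt", "a_2.qnt.pt"],
    "speaker_b": ["b_0.qnt.pt"],  # below the threshold, so it gets dropped
}

# iterate over a snapshot of the keys so deleting entries mid-loop is safe
for key in list(paths_by_spkr_name.keys()):
    if len(paths_by_spkr_name[key]) < min_utterances:
        del paths_by_spkr_name[key]

print(list(paths_by_spkr_name))  # ['speaker_a']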
@@ -192,13 +209,8 @@ class Dataset(_Dataset):
if len(self.paths) == 0 and training:
raise ValueError("No valid path is found for training.")
# would be a better cost saving if we could fetch the duration during the validation pass but oh well
self.duration = 0
"""
if cfg.dataset.use_hdf5:
for path in tqdm(self.paths, desc="Calculating duration"):
self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
"""
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
self.duration = _calculate_durations(self.dataset_type)
@cached_property
def phones(self):
@@ -663,6 +675,7 @@ def create_dataset_hdf5( skip_existing=True ):
# grab IDs for every file
ids = { ".".join(file.split(".")[:-2]) for file in files }
for id in tqdm(ids, desc=f"Processing {name}"):
try:
audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
@@ -697,8 +710,7 @@ def create_dataset_hdf5( skip_existing=True ):
# text
if texts:
with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
content = f.read().split(" ")
content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8").read().split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
for s in set(phones):
if s not in symmap:
@@ -714,6 +726,8 @@ def create_dataset_hdf5( skip_existing=True ):
else:
group.attrs['phonemes'] = 0
metadata[id]["phones"] = 0
except Exception as e:
pass # silently skip entries that fail to process
with open(dir / "metadata.json", "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
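For the text branch in the hunks above, the phoneme transcript is split on spaces, wrapped in <s>/</s> boundary tokens, and any unseen symbols are added to the symbol map. A quick sketch on a dummy transcript; the starting symmap and the id-assignment rule are assumptions for illustration:

content = "HH AH0 L OW1".split(" ")
phones = ["<s>"] + [" " if not p else p for p in content] + ["</s>"]

symmap = {"<s>": 1, "</s>": 2}  # phoneme -> id, grown as unseen symbols show up
for s in set(phones):
    if s not in symmap:
        symmap[s] = len(symmap) + 1

ids = [symmap[p] for p in phones]
print(phones)  # ['<s>', 'HH', 'AH0', 'L', 'OW1', '</s>']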

View File

@@ -94,18 +94,9 @@ class AR_NAR(Base):
# is training
if n_levels == self.n_resp_levels:
if random.random() < cfg.models.ar_nar.p_ar_nar:
quant_levels = None
targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
resps_list = self._unsqueeze_list(targ_list)
else:
quant_levels = torch.randint(1, self.n_resp_levels, (batch_size,))
targ_list = [o[..., l] for o, l in zip(resps_list, quant_levels)]
resps_list = [o[..., : l] for o, l in zip(resps_list, quant_levels)]
if quant_levels is not None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,))
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]
quant_levels = quant_levels.to(device=device)
return super().forward(
@@ -247,8 +238,6 @@ def example_usage():
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
print([ name for name, _ in model.named_parameters()])
@torch.inference_mode()
def sample( name, steps=600 ):
engine.eval()
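The first hunk for AR_NAR above is the heart of the unification: the separate p_ar_nar coin flip goes away, and instead a quantizer level is drawn per batch item from [0, n_resp_levels), with level 0 doubling as the AR case. A minimal sketch of that sampling on dummy response codes; shapes and names are illustrative, not the module's actual API:

import torch

n_resp_levels = 4
batch_size = 2
# each response is (sequence length, RVQ levels) of codebook indices
resps_list = [torch.randint(0, 1024, (75, n_resp_levels)) for _ in range(batch_size)]

# one level per batch item; 0 means "train this item as the AR", >0 as the NAR
quant_levels = torch.randint(0, n_resp_levels, (batch_size,))

# the target is the drawn level; the input keeps the response as-is for the AR
# case and only the levels below the drawn one for the NAR case
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]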

View File

@@ -392,7 +392,6 @@ class Base(nn.Module):
# compute loss if the target is given
if targ_list is not None:
ignore_sep = torch.tensor(self.ignore_index, device=device)
# create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not needed for computing the loss
prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
# remake input sequence
@@ -401,23 +400,16 @@ class Base(nn.Module):
# process each batch
for i in range(len(text_prom_list)):
# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
if quant_levels is None:
if quant_levels is None or quant_levels[i] == 0:
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
targ_list[i] = targ_list[i].clone().roll(-1, dims=0)
text_prom_list[i][-1] = self.ignore_index
targ_list[i][-1] = self.stop_token
# for the NAR, ignore completely computing the loss against the text prompt
else:
text_prom_list[i][:] = self.ignore_index
# adjust the target sequence if needed for the AR
if quant_levels is None:
# creates a copy because this is aliased against input response sequence
targ_list = [*targ_list]
# shift the target response into the future by 1, and mark the rolled back token / last token as a stop token
# this prepares the AR to actually generate autoregressive sequences
for i in range(len(targ_list)):
targ_list[i] = targ_list[i].roll(-1, dims=0)
targ_list[i][-1] = self.stop_token
# create the new target sequence to compute the loss against
target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
inputs = torch.cat( logits )
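To make the target construction above concrete, here is a hedged per-item sketch of the branching: when the drawn level is 0 (the AR case) the text+prompt and the target are shifted one step into the future and the last target position becomes the stop token, while for the NAR the whole text prompt is masked with the ignore index so it never contributes to the loss. Token values, ignore_index, and stop_token are made up:

import torch

ignore_index = -100
stop_token = 1024

def make_target(text_prom, targ, quant_level):
    text_prom = text_prom.clone()
    targ = targ.clone()
    if quant_level == 0:
        # AR: shift one step into the future and mark the final position as a stop token
        text_prom = text_prom.roll(-1, dims=0)
        targ = targ.roll(-1, dims=0)
        text_prom[-1] = ignore_index
        targ[-1] = stop_token
    else:
        # NAR: never compute loss against the text prompt
        text_prom[:] = ignore_index
    return torch.cat([text_prom, targ])

print(make_target(torch.tensor([5, 6, 7]), torch.tensor([10, 11, 12]), 0))
# tensor([   6,    7, -100,   11,   12, 1024])
print(make_target(torch.tensor([5, 6, 7]), torch.tensor([10, 11, 12]), 2))
# tensor([-100, -100, -100,   10,   11,   12])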