From d07c63b9d8be3a6d88d2489d59881cde0545bef8 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 12 Sep 2023 15:54:41 -0500
Subject: [PATCH] unified more things with training the AR+NAR monolithic model

---
 vall_e/config.py        |   1 +
 vall_e/data.py          | 130 ++++++++++++++++++++++------------------
 vall_e/models/ar_nar.py |  19 ++----
 vall_e/models/base.py   |  16 ++---
 4 files changed, 81 insertions(+), 85 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index e2d4e04..404d130 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -131,6 +131,7 @@ class Dataset:
 	phones_range: list[int] = field(default_factory=lambda: [4, 256])
 	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
 
+	min_utterances: int = 0
 	random_utterance: float = 1.0
 	max_prompts: int = 3
 
diff --git a/vall_e/data.py b/vall_e/data.py
index e66b352..bc95e41 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -59,24 +59,30 @@ def _get_quant_path(path):
 def _get_phone_path(path):
 	return _replace_file_extension(path, ".phn.txt")
 
+_total_durations = {}
+
+@cfg.diskcache()
+def _calculate_durations( type="training" ):
+	if type in _total_durations:
+		return _total_durations[type]
+	return 0
+
 @cfg.diskcache()
 def _load_paths(dataset, type="training"):
 	return { cfg.get_spkr( data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
 
-"""
-def _load_paths_from_hdf5(dataset, type="training"):
-	return { cfg.get_spkr( data_dir / "dummy" ): _get_hdf5_paths( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
-
-def _load_paths_from_disk(dataset, type="training"):
-	return { cfg.get_spkr( data_dir / "dummy" ): _get_paths_of_extensions( data_dir, ".qnt.pt", validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}") }
-"""
-
 def _load_paths_from_metadata(data_dir, type="training", validate=False):
 	_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
 
 	def _validate( entry ):
+		if "phones" not in entry or "duration" not in entry:
+			return False
 		phones = entry['phones']
 		duration = entry['duration']
+		if type not in _total_durations:
+			_total_durations[type] = 0
+		_total_durations[type] += entry['duration']
+
 		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
 
 	metadata_path = data_dir / "metadata.json"
@@ -107,6 +113,9 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
 	def _validate(child):
 		phones = child.attrs['phonemes']
 		duration = child.attrs['duration']
+		if type not in _total_durations:
+			_total_durations[type] = 0
+		_total_durations[type] += duration
 		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
 
 	key = f"/{type}{_get_hdf5_path(data_dir)}"
@@ -172,6 +181,14 @@ class Dataset(_Dataset):
 			self.dataset = cfg.dataset.training
 
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
+
+		# cull speakers if they do not have enough utterances
+		if cfg.dataset.min_utterances > 0:
+			keys = list(self.paths_by_spkr_name.keys())
+			for key in keys:
+				if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
+					del self.paths_by_spkr_name[key]
+
 		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
 
 		self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
@@ -192,13 +209,8 @@ class Dataset(_Dataset):
 		if len(self.paths) == 0 and training:
 			raise ValueError("No valid path is found for training.")
 
-		# would be a better cost saving if we could fetch the duration during the validation pass but oh well
-		self.duration = 0
-		"""
-		if cfg.dataset.use_hdf5:
-			for path in tqdm(self.paths, desc="Calculating duration"):
-				self.duration += cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
-		"""
+		#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
+		self.duration = _calculate_durations(self.dataset_type)
 
 	@cached_property
 	def phones(self):
@@ -663,57 +675,59 @@ def create_dataset_hdf5( skip_existing=True ):
 		# grab IDs for every file
 		ids = { ".".join(file.split(".")[:-2]) for file in files }
 		for id in tqdm(ids, desc=f"Processing {name}"):
-			audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
-			text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
+			try:
+				audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
+				text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
 
-			if not audio_exists or not text_exists:
-				continue
-
-			key = f'{type}/{name}/{id}'
-			if key in hf:
-				if skip_existing:
+				if not audio_exists or not text_exists:
 					continue
-				del hf[key]
 
-			group = hf.create_group(key)
-			group.attrs['id'] = id
-			group.attrs['type'] = type
-			group.attrs['speaker'] = name
+				key = f'{type}/{name}/{id}'
+				if key in hf:
+					if skip_existing:
+						continue
+					del hf[key]
 
-			metadata[id] = {}
+				group = hf.create_group(key)
+				group.attrs['id'] = id
+				group.attrs['type'] = type
+				group.attrs['speaker'] = name
 
-			# audio
-			if audios:
-				qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+				metadata[id] = {}
 
-				if "audio" in group:
-					del group["audio"]
-				group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
-				group.attrs['duration'] = qnt.shape[0] / 75
-				metadata[id]["duration"] = qnt.shape[0] / 75
-			else:
-				group.attrs['duration'] = 0
-				metadata[id]["duration"] = 0
-
-			# text
-			if texts:
-				with open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") as f:
-					content = f.read().split(" ")
-				phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
-				for s in set(phones):
-					if s not in symmap:
-						symmap[s] = len(symmap.keys())
+				# audio
+				if audios:
+					qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
 
-				phn = [ symmap[s] for s in phones ]
+					if "audio" in group:
+						del group["audio"]
+					group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
+					group.attrs['duration'] = qnt.shape[0] / 75
+					metadata[id]["duration"] = qnt.shape[0] / 75
+				else:
+					group.attrs['duration'] = 0
+					metadata[id]["duration"] = 0
+
+				# text
+				if texts:
+					content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8").read().split(" ")
+					phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
+					for s in set(phones):
+						if s not in symmap:
+							symmap[s] = len(symmap.keys())
 
-				if "text" in group:
-					del group["text"]
-				group.create_dataset('text', data=phn, compression='lzf', chunks=True)
-				group.attrs['phonemes'] = len(phn)
-				metadata[id]["phones"] = len(phn)
-			else:
-				group.attrs['phonemes'] = 0
-				metadata[id]["phones"] = 0
+					phn = [ symmap[s] for s in phones ]
+
+					if "text" in group:
+						del group["text"]
+					group.create_dataset('text', data=phn, compression='lzf', chunks=True)
+					group.attrs['phonemes'] = len(phn)
+					metadata[id]["phones"] = len(phn)
+				else:
+					group.attrs['phonemes'] = 0
+					metadata[id]["phones"] = 0
+			except Exception as e:
+				pass # skip malformed entries rather than aborting the whole export
 
 	with open(dir / "metadata.json", "w", encoding="utf-8") as f:
 		f.write( json.dumps( metadata ) )
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 6257256..1d45694 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -94,19 +94,10 @@ class AR_NAR(Base):
 
 		# is training
 		if n_levels == self.n_resp_levels:
-			if random.random() < cfg.models.ar_nar.p_ar_nar:
-				quant_levels = None
-
-				targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
-				resps_list = self._unsqueeze_list(targ_list)
-			else:
-				quant_levels = torch.randint(1, self.n_resp_levels, (batch_size,))
-
-				targ_list = [o[..., l] for o, l in zip(resps_list, quant_levels)]
-				resps_list = [o[..., : l] for o, l in zip(resps_list, quant_levels)]
-
-			if quant_levels is not None:
-				quant_levels.to(device=device)
+			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,))
+			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
+			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]
+			quant_levels = quant_levels.to(device=device) # .to() is not in-place; assign the result back
 
 			return super().forward(
 				text_list=text_list,
@@ -246,8 +237,6 @@ def example_usage():
 	engine = Engine(model=model, optimizer=optimizer)
 
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
-
-	print([ name for name, _ in model.named_parameters()])
 
 	@torch.inference_mode()
 	def sample( name, steps=600 ):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 36ec6bd..1bd13d6 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -392,7 +392,6 @@ class Base(nn.Module):
 		# compute loss if the target is given
 		if targ_list is not None:
 			ignore_sep = torch.tensor(self.ignore_index, device=device)
-			# create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
 			prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
 			# remake input sequence
@@ -401,23 +400,16 @@ class Base(nn.Module):
 			# process each batch
 			for i in range(len(text_prom_list)):
 				# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
-				if quant_levels is None:
+				if quant_levels is None or quant_levels[i] == 0:
 					text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
+					targ_list[i] = targ_list[i].clone().roll(-1, dims=0)
+					text_prom_list[i][-1] = self.ignore_index
+					targ_list[i][-1] = self.stop_token
 				# for the NAR, ignore completely computing the loss against the text prompt
 				else:
 					text_prom_list[i][:] = self.ignore_index
 
-			# adjust the target sequence if needed for the AR
-			if quant_levels is None:
-				# creates a copy because this is aliased against input response sequence
-				targ_list = [*targ_list]
-				# shift the target response into the future by 1, and mark the rolled back token / last token as a stop token
-				# this prepares the AR to actually generate autoregressive sequences
-				for i in range(len(targ_list)):
-					targ_list[i] = targ_list[i].roll(-1, dims=0)
-					targ_list[i][-1] = self.stop_token
-
 			# create the new target sequence to compute the loss against
 			target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
 			inputs = torch.cat( logits )
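
Note: the ar_nar.py hunk above collapses the two training paths into one. Instead of coin-flipping with `p_ar_nar`, a quantizer level is drawn per sample from [0, n_resp_levels), with level 0 standing in for the AR task and levels >= 1 for the NAR. Below is a minimal, self-contained sketch of just that sampling scheme; the tensor shapes and the code vocabulary size (1024) are illustrative assumptions, not the repo's API:

    import torch

    n_resp_levels = 4  # number of RVQ levels (illustrative)
    batch_size = 2

    # fake EnCodec-style responses: one (timesteps, n_resp_levels) code tensor per sample
    resps_list = [
        torch.randint(0, 1024, (31, n_resp_levels)),
        torch.randint(0, 1024, (45, n_resp_levels)),
    ]

    # draw one level per sample; 0 doubles as the AR task, >= 1 are NAR levels
    quant_levels = torch.randint(0, n_resp_levels, (batch_size,))

    # target = the sampled level's codes; input = every level below it
    # (level 0 keeps the full tensor, mirroring `r if l == 0 else r[..., :l]` in the patch)
    targ_list  = [r[..., l] for r, l in zip(resps_list, quant_levels)]
    input_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]

    for l, t, i in zip(quant_levels, targ_list, input_list):
        print(int(l), tuple(t.shape), tuple(i.shape))

With per-sample levels in place, base.py only needs the `quant_levels[i] == 0` check added above to decide whether a sample gets the AR's shift-by-one and stop-token treatment.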