From 09cda7d3f9f14c1b2f6b3a81f2e55c29d06d5804 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 16 Oct 2023 19:30:38 -0500
Subject: [PATCH] added sampling by speaker group name (might be better to
 de-emphasize the LibriVox/Audiobooks that are in large numbers, and emphasize
 the smaller pools), log cleanup

---
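Note (editor's sketch, not part of the commit): with cfg.dataset.sample_type == "group", Dataset.__getitem__ indexes by speaker group, draws a speaker from that group's sampler, then draws an utterance from that speaker's sampler, so each group is hit equally often per epoch regardless of how many speakers or clips it holds. The snippet below is a minimal, self-contained illustration of that two-level draw; the group/speaker/path names are invented and random.choice stands in for the repository's Sampler class.

# Rough illustration only; random.choice stands in for vall_e.data's Sampler,
# and every group/speaker/path name below is made up for the example.
import random

spkrs_by_spkr_group = {
	"librivox": ["speaker_a", "speaker_b", "speaker_c"],
	"ppp": ["twilight", "fluttershy"],
}
paths_by_spkr_name = {
	"speaker_a": ["a_0.qnt.pt", "a_1.qnt.pt"],
	"speaker_b": ["b_0.qnt.pt"],
	"speaker_c": ["c_0.qnt.pt"],
	"twilight": ["t_0.qnt.pt", "t_1.qnt.pt"],
	"fluttershy": ["f_0.qnt.pt"],
}
spkr_groups = list(spkrs_by_spkr_group.keys())

def getitem(index):
	# __len__ becomes the number of groups, so each index maps to one group...
	spkr_group = spkr_groups[index]
	# ...then a speaker is drawn within the group, and an utterance within the
	# speaker, which keeps small pools from being drowned out by the large
	# audiobook sets.
	spkr_name = random.choice(spkrs_by_spkr_group[spkr_group])
	path = random.choice(paths_by_spkr_name[spkr_name])
	return spkr_group, spkr_name, path

print(getitem(1))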
 scripts/parse_ppp.py    | 96 +++++++++++++++++++++++++++++++++++++++++
 vall_e/data.py          | 62 ++++++++++++++++++++++----
 vall_e/emb/qnt.py       |  2 +-
 vall_e/engines/base.py  |  7 +--
 vall_e/models/ar_nar.py | 26 ++++-------
 vall_e/utils/trainer.py |  7 ++-
 6 files changed, 166 insertions(+), 34 deletions(-)
 create mode 100644 scripts/parse_ppp.py

diff --git a/scripts/parse_ppp.py b/scripts/parse_ppp.py
new file mode 100644
index 0000000..51f2300
--- /dev/null
+++ b/scripts/parse_ppp.py
@@ -0,0 +1,96 @@
+import os
+import json
+import torch
+
+from tqdm.auto import tqdm
+from pathlib import Path
+from vall_e.emb.g2p import encode as valle_phonemize
+from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
+
+device = "cuda"
+
+target = "in"
+
+audio_map = {}
+text_map = {}
+
+data = {}
+
+for season in os.listdir(f"./{target}/"):
+	if not os.path.isdir(f"./{target}/{season}/"):
+		continue
+
+	for episode in os.listdir(f"./{target}/{season}/"):
+		if not os.path.isdir(f"./{target}/{season}/{episode}/"):
+			continue
+
+		for filename in os.listdir(f"./{target}/{season}/{episode}/"):
+			path = f'./{target}/{season}/{episode}/{filename}'
+			attrs = filename.split("_")
+			timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'
+
+			key = f'{episode}_{timestamp}'
+
+			if filename[-5:] == ".flac":
+				name = attrs[3]
+				emotion = attrs[4]
+				quality = attrs[5]
+
+				audio_map[key] = {
+					"path": path,
+					'episode': episode,
+					"name": name,
+					"emotion": emotion,
+					"quality": quality,
+					"timestamp": timestamp,
+				}
+
+			elif filename[-4:] == ".txt":
+				text_map[key] = open(path, encoding="utf-8").read()
+txts = {}
+wavs = []
+
+for key, entry in audio_map.items():
+	path = entry['path']
+	name = entry['name']
+	emotion = entry['emotion']
+	quality = entry['quality']
+	episode = entry['episode']
+	path = entry['path']
+	timestamp = entry['timestamp']
+	transcription = text_map[key]
+	if name not in data:
+		data[name] = {}
+		os.makedirs(f'./training/{name}/', exist_ok=True)
+		os.makedirs(f'./voices/{name}/', exist_ok=True)
+
+	key = f'{episode}_{timestamp}.flac'
+	os.rename(path, f'./voices/{name}/{key}')
+
+	data[name][key] = {
+		"segments": [],
+		"language": "en",
+		"text": transcription,
+		"misc": {
+			"emotion": emotion,
+			"quality": quality,
+			"timestamp": timestamp,
+			"episode": episode,
+		}
+	}
+
+	path = f'./voices/{name}/{key}'
+	txts[path] = transcription
+	wavs.append(Path(path))
+
+for name in data.keys():
+	open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
+
+for key, text in tqdm(txts.items(), desc="Phonemizing..."):
+	path = Path(key)
+	phones = valle_phonemize(text)
+	open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
+
+for path in tqdm(wavs, desc="Quantizing..."):
+	qnt = valle_quantize(path, device=device)
+	torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
\ No newline at end of file
diff --git a/vall_e/data.py b/vall_e/data.py
index d780b23..88dfd85 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -92,11 +92,12 @@ def _load_paths_from_metadata(data_dir, type="training", validate=False):
 		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
 
 	metadata_path = data_dir / "metadata.json"
-	if not cfg.dataset.use_metadata or not metadata_path.exists():
-		return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
+	metadata = {}
+	if cfg.dataset.use_metadata and metadata_path.exists():
+		metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
 
-	speaker = cfg.get_spkr( data_dir / "dummy" )
-	metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+	if len(metadata) == 0:
+		return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
 
 	def key( dir, id ):
 		if not cfg.dataset.use_hdf5:
@@ -193,6 +194,7 @@ class Dataset(_Dataset):
 		if len(self.dataset) == 0:
 			self.dataset = cfg.dataset.training
 
+		# dict of paths keyed by speaker names
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
 
 		# cull speakers if they do not have enough utterances
@@ -205,6 +207,23 @@ class Dataset(_Dataset):
 
 		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
 		self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
+
+		# dict of speakers keyed by speaker group
+		self.spkrs_by_spkr_group = {}
+		for data_dir in self.dataset:
+			spkr = cfg.get_spkr( data_dir / "dummy" )
+			spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
+
+			if len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
+				continue
+
+			if spkr_group not in self.spkrs_by_spkr_group:
+				self.spkrs_by_spkr_group[spkr_group] = []
+
+			self.spkrs_by_spkr_group[spkr_group].append( spkr )
+
+		self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
+		self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
 		if cfg.dataset.sample_type == "path":
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
@@ -214,14 +233,15 @@ class Dataset(_Dataset):
 
 		self.phone_symmap = phone_symmap or self._get_phone_symmap()
 		self.spkr_symmap = self._get_spkr_symmap()
+		self.spkr_group_symmap = self._get_spkr_group_symmap()
 		self.lang_symmap = self._get_lang_symmap()
 		self.task_symmap = self._get_task_symmap()
 
 		# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
 		self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
 
-		if len(self.paths) == 0 and training:
-			raise ValueError("No valid path is found for training.")
+		if len(self.paths) == 0:
+			raise ValueError(f"No valid path is found for {self.dataset_type}")
 
 		#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
 		self.duration = _calculate_durations(self.dataset_type)
@@ -281,6 +301,9 @@ class Dataset(_Dataset):
 	def _get_spkr_symmap(self):
 		return {s: i for i, s in enumerate(self.spkrs)}
 
+	def _get_spkr_group_symmap(self):
+		return {s: i for i, s in enumerate(self.spkr_groups)}
+
 	def _get_lang_symmap(self):
 		return get_lang_symmap()
 
@@ -358,14 +381,31 @@ class Dataset(_Dataset):
 		return prom
 
 	def __getitem__(self, index):
-		if cfg.dataset.sample_type == "speaker":
+		if cfg.dataset.sample_type == "group":
+			spkr_group = self.spkr_groups[index]
+			spkr_group_id = self.spkr_group_symmap[spkr_group]
+			spkr_name = self.spkr_samplers[spkr_group].sample()
+			if spkr_name in self.spkr_symmap:
+				spkr_id = self.spkr_symmap[spkr_name]
+			else:
+				spkr_id = -1
+			try:
+				path = self.samplers[spkr_name].sample()
+			except Exception as e:
+				print( "ERROR", spkr_group, spkr_name )
+				raise e
+		elif cfg.dataset.sample_type == "speaker":
 			spkr_name = self.spkrs[index]
 			spkr_id = self.spkr_symmap[spkr_name]
 			path = self.samplers[spkr_name].sample()
+			spkr_group = self.get_speaker_group(path)
+			spkr_group_id = self.spkr_group_symmap[spkr_group]
 		else:
 			path = self.paths[index]
 			spkr_name = self.get_speaker(path)
 			spkr_id = self.spkr_symmap[spkr_name]
+			spkr_group = self.get_speaker_group(path)
+			spkr_group_id = self.spkr_group_symmap[spkr_group]
 
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
@@ -381,7 +421,6 @@ class Dataset(_Dataset):
 			text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
 			resps = _load_quants(path)
 
-		spkr_group = self.get_speaker_group(path)
 		lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
 
 		# append additional prompts in an attempt to artifically increase lengths / offer new data
@@ -611,6 +650,8 @@ class Dataset(_Dataset):
 		self.training = value
 
 	def __len__(self):
+		if cfg.dataset.sample_type == "group":
+			return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
 		if cfg.dataset.sample_type == "speaker":
 			return min(len(self.spkrs), self._head or len(self.spkrs))
 		return min(len(self.paths), self._head or len(self.paths))
@@ -679,6 +720,7 @@ def create_train_val_dataloader():
 
 	_logger.info(str(train_dataset.phone_symmap))
 	_logger.info(str(train_dataset.spkr_symmap))
+	_logger.info(str(train_dataset.spkr_group_symmap))
 
 	_logger.info(f"#samples (train): {len(train_dataset)}.")
 	_logger.info(f"#samples (val): {len(val_dataset)}.")
@@ -707,6 +749,10 @@ def create_dataset_metadata():
 
 	metadata = {}
 	for path in tqdm(paths, desc="Parsing paths"):
+		if isinstance(path, str):
+			print("str:", path)
+			path = Path(path)
+
 		speaker = cfg.get_spkr(path)
 		if speaker not in metadata:
 			metadata[speaker] = {}
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 7a2f9dc..a86af5e 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -171,7 +171,7 @@ def encode_from_file(path, device="cuda"):
 		return encode_from_files( path, device )
 	else:
 		path = str(path)
-		wav, sr = torchaudio.load(path, format=path[-3:])
+		wav, sr = torchaudio.load(path)
 
 	if wav.shape[0] == 2:
 		wav = wav[:1]
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index de29d80..f2949db 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -469,8 +469,9 @@ class Engines(dict[str, Engine]):
 
 		self._update()
 
-		stats["elapsed_time"] = total_elapsed_time
-		stats["global_step"] = self.global_step
-		#stats["micro_step"] = self.micro_step
+		if len(self.keys()) > 1:
+			stats["elapsed_time"] = total_elapsed_time
+
+		stats["it"] = self.global_step
 
 		return stats
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index bcd5704..4c1b642 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -146,13 +146,13 @@ class AR_NAR(Base):
 				quant_levels=quant_levels,
 			)
 
 		# is NAR
-		prev_list = resps_list
 		if max_levels == 0:
 			max_levels = self.n_resp_levels
+
+		prev_list = resps_list
 
-		while True:
+		for n in trange( max_levels, desc="NAR" ):
 			level = prev_list[0].shape[-1]
-
 			if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
 				break
@@ -195,14 +195,13 @@
 			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
 
-		sampling_beam_width_use_logs = True
 		scores = [ 1.0 ] * sampling_beam_width
 
 		if self.interleave:
 			max_steps *= self.n_prom_levels
 
 		# get next in sequence
-		for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
+		for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
 			# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
 			# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
 			if max_resp_context > 0:
@@ -245,17 +244,13 @@
 				r, s = r
 				# first step, expand batch
 				if batch_size == 1:
-					batch_size *= sampling_beam_width
+					batch_size = sampling_beam_width
 					text_list = text_list * sampling_beam_width
 					proms_list = proms_list * sampling_beam_width
 					sequence_list = sequence_list * sampling_beam_width
 					stopped = torch.zeros(batch_size, device=device).bool()
 
-				# update scores
-				if sampling_beam_width_use_logs:
-					scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
-				else:
-					scores = [ scores[i] * score for i, score in enumerate(s) ]
+				scores = [ scores[i] + score for i, score in enumerate(s) ]
 
 			# append tokens
 			for i, ri in enumerate(r):
@@ -270,13 +265,8 @@
 
 		# pick the best scoring candidate
 		# desu this is always going to be candidate 0
-		if sampling_beam_width and len(scores) > 0:
-			best_idx, best_score = (0, 0)
-			for idx, score in enumerate(scores):
-				if best_score > score:
-					best_idx, best_score = idx, score
-
-			sequence_list = [sequence_list[best_idx]]
+		if sampling_beam_width:
+			sequence_list = [ sequence_list[0] ]
 
 		return [self._prune(r) for r in sequence_list]
 
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 2214f9b..d9bf9e0 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -101,7 +101,7 @@ def _non_blocking_input():
 
 def _make_infinite_epochs(dl):
 	while True:
-		_logger.info("New epoch starts.")
+		#_logger.info("New epoch starts.")
 		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
 
 
@@ -158,10 +158,9 @@ def train(
 		#batch = to_device(batch, torch.cuda.current_device())
 
 		stats = engines.step(batch=batch, feeder=train_feeder)
-
-		stats['it'] = stats['global_step']
 		stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)
 
+		"""
 		stats['batch'] = {
 			'size': len(batch['text']),
 			'id': batch['spkr_id'],
@@ -170,8 +169,8 @@ def train(
 			'text_len': [ text.shape[0] for text in batch['text'] ],
 			'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
 			'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
 		}
+		"""
 
-		del stats['global_step']
 		elapsed_time = stats.get("elapsed_time", 0)
 		_logger.info(f"Training Metrics: {json.dumps(stats)}.")