From 56f25f7a9b25e39caf403aaaec1b31e60a6949b7 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 16 Sep 2024 23:10:29 -0500
Subject: [PATCH] more stuff for similar-speaker prompt sampling (to-do:
 actually test if this works...)

---
 scripts/process_emilia.py | 18 +++++++++++-------
 vall_e/config.py          |  8 ++++----
 vall_e/data.py            | 32 ++++++++++++++++++++++++++------
 vall_e/emb/qnt.py         | 16 ++++++++++++----
 vall_e/emb/similar.py     | 16 ++++++++++++++--
 5 files changed, 67 insertions(+), 23 deletions(-)

diff --git a/scripts/process_emilia.py b/scripts/process_emilia.py
index 1b54d28..a291959 100644
--- a/scripts/process_emilia.py
+++ b/scripts/process_emilia.py
@@ -15,6 +15,9 @@ from pathlib import Path
 
 from vall_e.config import cfg
 
+from vall_e.emb.g2p import encode as phonemize
+from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio
+
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
 
@@ -22,6 +25,13 @@ def process_items( items, stride=0, stride_offset=0 ):
 	items = sorted( items )
 	return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
 
+def load_audio( path, device="cuda" ):
+	waveform, sample_rate = torchaudio.load(path)
+	if waveform.shape[0] > 1:
+		waveform = torch.mean(waveform, dim=0, keepdim=True)
+	waveform = convert_audio(waveform, sample_rate, cfg.sample_rate, 1)
+	return waveform.to(device=device), cfg.sample_rate
+
 def process(
 	audio_backend="encodec",
 	input_audio="Emilia",
@@ -57,10 +67,6 @@ def process(
 	cfg.inference.weight_dtype = dtype # "bfloat16"
 	cfg.inference.amp = amp # False
 
-	# import after because we've overriden the config above
-	# need to validate if this is even necessary anymore
-	from vall_e.emb.g2p import encode as phonemize
-	from vall_e.emb.qnt import encode as quantize, _replace_file_extension
 
 	output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
 
@@ -139,9 +145,7 @@ def process(
 			text = metadata["text"]
 
 			if waveform is None:
-				waveform, sample_rate = torchaudio.load(inpath)
-				if waveform.shape[0] > 1:
-					waveform = torch.mean(waveform, dim=0, keepdim=True)
+				waveform, sample_rate = load_audio(inpath)
 
 			wavs.append((
 				outpath,
diff --git a/vall_e/config.py b/vall_e/config.py
index 5708fe7..3723413 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -148,16 +148,14 @@ class Dataset:
 	phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
 	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
 	prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
+
+	# to-do: clean up the following block, it's a mess
 	min_utterances: int = 2 # minimum number of utterances a speaker can have
-	random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
 	max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
-	prompt_duration: float | None = None # legacy
-
 	max_resps: int = 1 # number of samples to target for training
 	p_resp_append: float = 1.0 # probability to append another sample to the training target
-	p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
 
 	sample_type: str = "path" # path | speaker
@@ -167,6 +165,8 @@ class Dataset:
 	# for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
 	sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0
 
+	prom_sample_similar: bool = True # if available, try and sample the prompt closest to the sampled response utterance (requires specific metadata generated)
+
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
 	reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
 	reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
diff --git a/vall_e/data.py b/vall_e/data.py
index 41e7e73..eb9fdcb 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -873,14 +873,29 @@ class Dataset(_Dataset):
 
 		return path, text, resps
 
+	def get_similar_utterance(self, spkr_name, reference, offset=0 ):
+		# lots of boilerplate checks
+		metadata_path = Path(f"{metadata_root}/{spkr_name}.json")
+		if not metadata_path.exists():
+			return None
+		metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+		if reference not in metadata:
+			return None
+		reference_metadata = metadata[reference]
+		if "similar" not in reference_metadata:
+			return None
+		if offset >= len(reference_metadata["similar"]):
+			offset = -1
+
+		return reference_metadata["similar"][offset][0]
-	def sample_prompts(self, spkr_name, ignore, should_trim=True, reference=None):
+	def sample_prompts(self, spkr_name, reference, should_trim=True):
 		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
 			return None
 
 		prom_list = []
 
-		choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
+		choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
 		choices = [*choices]
 
 		# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
@@ -895,9 +910,14 @@ class Dataset(_Dataset):
 		prom_length = 0
 		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
 
-		# to-do: if reference is not None, find the closest utterances to the reference
 		for _ in range(cfg.dataset.max_prompts):
-			path = random.choice(choices)
+			if reference is not None and cfg.dataset.prom_sample_similar:
+				path = self.get_similar_utterance( spkr_name=spkr_name, reference=reference, offset = len(prom_list) )
+				# yuck
+				if not path:
+					path = random.choice(choices)
+			else:
+				path = random.choice(choices)
 			if cfg.dataset.use_hdf5:
 				key = _get_hdf5_path(path)
 				qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
@@ -1011,7 +1031,7 @@ class Dataset(_Dataset):
 
 		# Base TTS ( => )
 		if task == "tts":
-			proms = self.sample_prompts(spkr_name, ignore=path, reference=resps)
+			proms = self.sample_prompts(spkr_name, reference=path)
 
 			if cfg.dataset.inject_noise_in_prom:
 				# sample random noise
@@ -1087,7 +1107,7 @@ class Dataset(_Dataset):
 		# target speech extraction ( => )
 		elif task == "tse":
 			# sample a prompt
-			proms = self.sample_prompts(spkr_name, ignore=path)
+			proms = self.sample_prompts(spkr_name, reference=path)
 
 			# sample another speaker
 			_, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index f84c8e5..ffa9f16 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -491,8 +491,12 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
 	# AudioDec uses a different pathway
 	if cfg.audio_backend == "audiodec":
 		model = _load_audiodec_model(device)
-		wav = wav.unsqueeze(0)
-		wav = convert_audio(wav, sr, model.sample_rate, 1)
+		# reshape (channel, samples) => (batch, channel, samples)
+		if wav.dim() < 3:
+			wav = wav.unsqueeze(0)
+		# skip unnecessary resample
+		if sr != model.sample_rate or wav.shape[1] != 1:
+			wav = convert_audio(wav, sr, model.sample_rate, 1)
 		wav = wav.to(device)
 
 		# wav = rearrange(wav, "t c -> t 1 c").to(device)
@@ -502,8 +506,12 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
 	# vocos does not encode wavs to encodecs, so just use normal encodec
 	model = _load_encodec_model(device)
 
-	wav = wav.unsqueeze(0)
-	wav = convert_audio(wav, sr, model.sample_rate, model.channels)
+	# reshape (channel, samples) => (batch, channel, samples)
+	if wav.dim() < 3:
+		wav = wav.unsqueeze(0)
+	# skip unnecessary resample
+	if sr != model.sample_rate or wav.shape[1] != model.channels:
+		wav = convert_audio(wav, sr, model.sample_rate, model.channels)
 	wav = wav.to(device)
 
 	with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index 150140c..e1c84f4 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -48,6 +48,7 @@ def process(
 	amp=False,
 
 	verbose=False,
+	metadata_path=None,
 ):
 	cfg.set_audio_backend(audio_backend)
 	artifact_extension = cfg.audio_backend_extension
@@ -69,7 +70,6 @@ def process(
 	for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
 		extension = filename.split(".")[-1]
 
-
 		if text:
 			if extension not in artifact_extension:
 				raise Exception("!")
@@ -140,17 +140,29 @@ def process(
 			if filename_a not in sorted_similarities[filename_b]:
 				sorted_similarities[filename_b][filename_a] = similarity
 
+	metadata = None
+	if metadata_path is not None:
+		metadata_path = metadata_path / f'{input_speaker}.json'
+		if metadata_path.exists():
+			metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+
 	# sort similarities scores
 	for key, sorted_similarity in sorted_similarities.items():
+		filename_a = key.split(":")[0]
 		sorted_similarities[key] = sorted([ ( filename, similarity ) for filename, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
 
 		most_filename, most_score = sorted_similarities[key][0]
 		least_filename, least_score = sorted_similarities[key][-1]
 
+		if metadata is not None and filename_a in metadata:
+			metadata[filename_a]["similar"] = sorted_similarities[key]
+
 		if verbose:
 			print( f'{key}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
 
-	# to-do: store this somewhere
+	if metadata is not None:
+		with open(str(metadata_path), "w", encoding="utf-8") as f:
+			f.write( json.dumps( metadata ) )
 
 	return sorted_similarities
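
Note on the metadata format: the patch reads and writes <metadata root>/<speaker>.json but never spells out its shape. Below is a minimal sketch of the structure Dataset.get_similar_utterance() appears to expect from vall_e/emb/similar.py: a dict keyed by utterance name, where each entry carries a "similar" list of (filename, score) pairs sorted most-similar first. The utterance names and scores are made up for illustration.

# hypothetical contents of <metadata root>/<speaker>.json; only "similar" is read by this patch
example_speaker_metadata = {
	"utterance_000": {
		"similar": [
			["utterance_042", 0.913],
			["utterance_007", 0.871],
			["utterance_111", 0.402],
		],
	},
}

def pick_similar( metadata: dict, reference: str, offset: int = 0 ):
	# standalone mirror of the lookup done in Dataset.get_similar_utterance()
	entry = metadata.get( reference )
	if not entry or "similar" not in entry:
		return None
	similar = entry["similar"]
	if offset >= len( similar ):
		offset = -1 # out of range: fall back to the last (least similar) entry
	return similar[offset][0]

assert pick_similar( example_speaker_metadata, "utterance_000", offset=0 ) == "utterance_042"
assert pick_similar( example_speaker_metadata, "utterance_000", offset=99 ) == "utterance_111"
assert pick_similar( example_speaker_metadata, "missing" ) is None

sample_prompts() passes offset=len(prom_list), so each additional prompt for the same sample walks one step further down the similarity list before falling back to a random choice.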
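
For completeness, a sketch of how the new load_audio() helper in scripts/process_emilia.py is meant to feed the module-level quantize import (vall_e.emb.qnt.encode): mono-mix, resample to cfg.sample_rate, then encode. The file names and the torch.save call are placeholders for illustration, not necessarily what the script does with the resulting codes.

import torch
import torchaudio

from vall_e.config import cfg
from vall_e.emb.qnt import encode as quantize, convert_audio

def load_audio( path, device="cuda" ):
	# same helper as added in the patch: mono-mix, then resample to the configured rate
	waveform, sample_rate = torchaudio.load( path )
	if waveform.shape[0] > 1:
		waveform = torch.mean( waveform, dim=0, keepdim=True )
	waveform = convert_audio( waveform, sample_rate, cfg.sample_rate, 1 )
	return waveform.to( device=device ), cfg.sample_rate

waveform, sample_rate = load_audio( "speaker/utterance_000.mp3" ) # hypothetical input file
codes = quantize( waveform, sr=sample_rate, device="cuda" ) # codes for cfg.audio_backend (EnCodec, etc.)
torch.save( codes, "speaker/utterance_000.enc" ) # hypothetical output path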