diff --git a/vall_e/config.py b/vall_e/config.py
index 87c4963..6f8c9c0 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -137,8 +137,9 @@ class Dataset:
 
 	hdf5_name: str = "data.h5"
 	use_hdf5: bool = False
-	use_metadata: bool = False
 	hdf5_flag: str = "a"
+	use_metadata: bool = False
+	validate: bool = True
 
 	workers: int = 8
 	cache: bool = True
@@ -163,6 +164,8 @@ class Dataset:
 	sample_shuffle: bool = True
 
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
+	reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
+	reencode_device: str = "cuda" # "cpu" is slower but saves memory
 
 	_frames_per_second: int = 0 # allows setting your own hint
 
@@ -666,7 +669,7 @@ class Optimizations:
 class Config(BaseConfig):
 	device: str = "cuda"
 	mode: str = "training" # "inferencing"
-	experimental: bool = False # So I can stop commenting out things when committing
+	experimental: bool = False # debug flag; toggles the naive concatenation mindset in data.py
 
 	dataset: Dataset = field(default_factory=lambda: Dataset)
 	models: dict | list | None = field(default_factory=lambda: [])
diff --git a/vall_e/data.py b/vall_e/data.py
index 9310290..15a4aa3 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -11,7 +11,7 @@ import torch
 import itertools
 
 from .config import cfg
-from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
+from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file
 
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
@@ -541,12 +541,7 @@ class Dataset(_Dataset):
 		self.tone_symmap = self._get_tone_symmap()
 		self.task_symmap = self._get_task_symmap()
 
-		"""
-		self.empty_text = tokenize(" ")
-		if len(self.empty_text) == 4:
-			self.empty_text = self.empty_text[:1] + self.empty_text[1:2] + self.empty_text[-1:]
-		"""
-
+		# grab IDs for bos, space, and eos for easy input creation later
 		self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
 
 		# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
@@ -743,7 +738,7 @@ class Dataset(_Dataset):
 			qnt = _load_quants(path, return_metadata=False)
 
 			if 0 < trim_length and trim_length < qnt.shape[0]:
-				qnt = trim( qnt, trim_length )
+				qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat )
 
 			prom_list.append(qnt)
 			prom_length += qnt.shape[0]
@@ -756,7 +751,7 @@ class Dataset(_Dataset):
 		prom = torch.cat(prom_list)
 
 		if 0 < trim_length and trim_length < prom.shape[0]:
-			prom = trim( prom, trim_length )
+			prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat )
 
 		return prom
 
@@ -814,15 +809,13 @@ class Dataset(_Dataset):
 		lang = torch.tensor([self.lang_symmap[lang]]).to(torch.uint8)
 		tone = torch.tensor([self.tone_symmap[tone]]).to(torch.uint8)
 
-		naive = True
+		# a bool to easily experiment with two mindsets later
+		naive = cfg.experimental
 
-		# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
-		# it probably is better to pad with silence instead of just stitching utterances and ruining things
-		"""
 		# append additional prompts in an attempt to artifically increase lengths / offer new data
 		if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
 			ignore_paths = []
-			for _ in range( cfg.dataset.max_resps - 1 ):
+			for _ in range( 1, cfg.dataset.max_resps ):
 				path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
 				ignore_paths.append(path)
 
@@ -836,15 +829,8 @@
 
 				# might be better to decode => concat waveforms with silence in between => reencode
 				# as you technically can't just append encodec sequences together like this without issues
-				resps = torch.concat([ resps, qnt ])
-		"""
-
-		"""
-		task = "tts"
-		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
-		proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
-		"""
-
+				resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
+
 		resps = resps[:, :cfg.model.resp_levels]
 		proms = proms[:, :cfg.model.resp_levels]
@@ -888,7 +874,7 @@ class Dataset(_Dataset):
 			# extend the noise to fill the target audio
 			noise = repeat_extend_audio(noise, resps.shape[0])
 			# create the input prompt by merging the target audio with the noise
-			proms = merge_audio( resps, noise, scale=[1, noise_scale], device="cpu" )
+			proms = merge_audio( resps, noise, scale=[1, noise_scale], device=cfg.dataset.reencode_device )
 			# set the target to just be the noise if
 			if task == "sr":
 				resps = noise
@@ -907,10 +893,10 @@
 				# overlay the random speaker over the target audio
 				smallest_size = min(resps.shape[0], other_proms.shape[0])
 				if other_proms.shape[0] == smallest_size:
-					noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
+					noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
 					noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
 				else:
-					noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
+					noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
 					noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
 
 				# stitch together the proms
@@ -970,7 +956,7 @@
 				# extend the noise to fill the target audio
 				n = repeat_extend_audio(noise, p.shape[0])
 				# merge the noise over the utterance
-				return merge_audio(p, n, scale=[1, noise_scale], device="cpu")
+				return merge_audio(p, n, scale=[1, noise_scale], device=cfg.dataset.reencode_device)
 
 			# apply noise to all pieces
 			pre_prom = noise_proms( pre_prom )
@@ -988,10 +974,12 @@
 			]
 
 			# create new resp
-			resps = torch.cat(
+			resps = concat_audio(
 				([ pre_prom ] if pre_prom is not None else []) +
 				[ edit_prom ] +
-				([ post_prom ] if post_prom is not None else [])
+				([ post_prom ] if post_prom is not None else []),
+				reencode=cfg.dataset.reencode_on_concat,
+				device=cfg.dataset.reencode_device,
 			)
 		else:
 			raise Exception(f'Undefined task: {task}')
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 1dd2ebe..47a3ebd 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -431,7 +431,7 @@ def encode_from_file(path, device="cuda"):
 Helper Functions
 """
 # trims from the start, up to `target`
-def trim( qnt, target ):
+def trim( qnt, target, reencode=False ):
 	length = max( qnt.shape[0], qnt.shape[1] )
 	if target > 0:
 		start = 0
@@ -446,7 +446,16 @@ def trim( qnt, target ):
 		if start < 0:
 			start = 0
 
-	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
+	if not reencode:
+		return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
+
+	# trims on the waveform itself
+	# need to test
+	start = int( start / cfg.dataset.frames_per_second * cfg.sample_rate )
+	end = int( end / cfg.dataset.frames_per_second * cfg.sample_rate )
+
+	wav = decode(qnt)[0]
+	return encode(wav[start:end], cfg.sample_rate)[0].t()
 
 # trims a random piece of audio, up to `target`
 # to-do: try and align to EnCodec window
@@ -470,18 +479,47 @@ def repeat_extend_audio( qnt, target ):
 	return trim(torch.cat(pieces), target)
 
+
+# interleaves between a list of audios
+# useful for interleaving silence
+def interleave_audio( *args, audio=None ):
+	qnts = [*args]
+	if audio is None:
+		return qnts
+
+	# interleave silence
+	# yes there's a better way
+	res = []
+	for i, qnt in enumerate(qnts):
+		res.append( qnt )
+		if i + 1 != len(qnts):
+			res.append( audio )
+	return res
+
+# concats two audios together
+def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
+	qnts = [*args]
+	# just naively combine the codes
+	if not reencode:
+		return torch.concat( qnts )
+
+	decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
+	combined = torch.concat( decoded )
+	return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
+
 # merges two quantized audios together
-# I don't know if this works
-def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
+# requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
+def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
 	qnts = [*args]
 	decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
 
+	# useful to adjust the volumes of each waveform
 	if len(scale) == len(decoded):
 		for i in range(len(scale)):
 			decoded[i] = decoded[i] * scale[i]
 
 	combined = sum(decoded) / len(decoded)
-	return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
+	return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
 
 """
 if __name__ == "__main__":
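
For reference, a minimal usage sketch of the new helpers follows; it is not part of the patch. It assumes the package is importable as vall_e, that cfg has already been initialized the usual way, that the wav paths are placeholders, and that encode_from_file returns codes that the same [0].t() idiom used in the patch converts to a (frames, levels) layout.

# usage sketch for concat_audio / interleave_audio (assumptions: vall_e importable,
# cfg initialized, wav paths are placeholders)
import torch

from vall_e.config import cfg
from vall_e.emb.qnt import encode, encode_from_file, concat_audio, interleave_audio

# quantize two utterances; [0].t() mirrors the layout convention used in the patch
a = encode_from_file("./utterance_a.wav")[0].t()
b = encode_from_file("./utterance_b.wav")[0].t()

# naive join: append EnCodec code frames directly, no decode/encode round-trip
naive = concat_audio(a, b, reencode=False)

# decode => concat => re-encode, i.e. what dataset.reencode_on_concat = True enables
clean = concat_audio(a, b, reencode=True, device=cfg.dataset.reencode_device)

# half a second of encoded silence interleaved between the utterances before joining
# (assumes encode() accepts a (channels, samples) waveform, like the decode() output it is fed elsewhere)
silence = encode(torch.zeros(1, cfg.sample_rate // 2), cfg.sample_rate)[0].t()
padded = concat_audio(*interleave_audio(a, b, audio=silence), reencode=True, device=cfg.dataset.reencode_device)

The trade-off the flag exposes: naive concatenation is cheap but splices EnCodec code frames from unrelated utterances, while the re-encode path pays for a decode/encode round-trip on reencode_device in exchange for a waveform-level join.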