added an option to either naively concat codes when concatenating audio (prior behavior) or to decode => concat => encode instead (although this currently only happens for prom sampling when an utterance is too short)

This commit is contained in:
mrq 2024-07-18 16:48:41 -05:00
parent 97e768601c
commit bccbb77a1a
3 changed files with 65 additions and 36 deletions
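For context, a minimal sketch of the two strategies the commit message describes; decode/encode here are stand-ins for the helpers in vall_e/emb/qnt.py, and the 24kHz sample rate is an assumption:

import torch

def concat_codes_naive(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # prior behavior: splice EnCodec code sequences along the time axis
    # (cheap, but codes don't necessarily stitch cleanly at the seam)
    return torch.concat([a, b])

def concat_codes_reencode(a, b, decode, encode, sample_rate=24_000):
    # decode => concat => encode: join in waveform space, then re-quantize
    # (slower, but avoids splicing artifacts in code space)
    wav = torch.concat([decode(a), decode(b)])
    return encode(wav, sample_rate)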

vall_e/config.py View File

@@ -137,8 +137,9 @@ class Dataset:
hdf5_name: str = "data.h5"
use_hdf5: bool = False
use_metadata: bool = False
hdf5_flag: str = "a"
use_metadata: bool = False
validate: bool = True
workers: int = 8
cache: bool = True
@@ -163,6 +164,8 @@ class Dataset:
sample_shuffle: bool = True #
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
reencode_device: str = "cuda" # "cpu" is slower but saves memory
_frames_per_second: int = 0 # allows setting your own hint
@@ -666,7 +669,7 @@ class Optimizations:
class Config(BaseConfig):
device: str = "cuda"
mode: str = "training" # "inferencing"
experimental: bool = False # So I can stop commenting out things when committing
experimental: bool = False # Debug flag, unused now
dataset: Dataset = field(default_factory=lambda: Dataset)
models: dict | list | None = field(default_factory=lambda: [])
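Assuming the YAML config maps one-to-one onto these dataclass fields (and that cfg is the loaded global from .config), enabling the new behavior programmatically would look something like this sketch:

from vall_e.config import cfg

# enable decode => concat => encode joins; use CPU if VRAM is tight
cfg.dataset.reencode_on_concat = True
cfg.dataset.reencode_device = "cpu"  # slower than the default "cuda", saves memory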

vall_e/data.py View File

@@ -11,7 +11,7 @@ import torch
import itertools
from .config import cfg
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
@@ -541,12 +541,7 @@ class Dataset(_Dataset):
self.tone_symmap = self._get_tone_symmap()
self.task_symmap = self._get_task_symmap()
"""
self.empty_text = tokenize(" ")
if len(self.empty_text) == 4:
self.empty_text = self.empty_text[:1] + self.empty_text[1:2] + self.empty_text[-1:]
"""
# grab IDs for bos, space, and eos for easy input creation later
self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
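For illustration, the result is just a three-ID sequence; with a made-up vocab it would look like this (token IDs are hypothetical):

# hypothetical token IDs: {"<bos>": 1, "<eos>": 2, " ": 3}
empty_text = [1, 3, 2]  # <bos>, space, <eos>: the minimal valid text input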
@@ -743,7 +738,7 @@ class Dataset(_Dataset):
qnt = _load_quants(path, return_metadata=False)
if 0 < trim_length and trim_length < qnt.shape[0]:
qnt = trim( qnt, trim_length )
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat )
prom_list.append(qnt)
prom_length += qnt.shape[0]
@@ -756,7 +751,7 @@ class Dataset(_Dataset):
prom = torch.cat(prom_list)
if 0 < trim_length and trim_length < prom.shape[0]:
prom = trim( prom, trim_length )
prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat )
return prom
@@ -814,15 +809,13 @@ class Dataset(_Dataset):
lang = torch.tensor([self.lang_symmap[lang]]).to(torch.uint8)
tone = torch.tensor([self.tone_symmap[tone]]).to(torch.uint8)
naive = True
# a bool to easily experiment with two mindsets later
naive = cfg.experimental
# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
# it probably is better to pad with silence instead of just stitching utterances and ruining things
"""
# append additional prompts in an attempt to artificially increase lengths / offer new data
if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
ignore_paths = []
for _ in range( cfg.dataset.max_resps - 1 ):
for _ in range( 1, cfg.dataset.max_resps ):
path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
ignore_paths.append(path)
@@ -836,15 +829,8 @@ class Dataset(_Dataset):
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = torch.concat([ resps, qnt ])
"""
"""
task = "tts"
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
"""
resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
"""
resps = resps[:, :cfg.model.resp_levels]
proms = proms[:, :cfg.model.resp_levels]
@@ -888,7 +874,7 @@ class Dataset(_Dataset):
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio( resps, noise, scale=[1, noise_scale], device="cpu" )
proms = merge_audio( resps, noise, scale=[1, noise_scale], device=cfg.dataset.reencode_device )
# set the target to just be the noise if <sr>
if task == "sr":
resps = noise
@@ -907,10 +893,10 @@ class Dataset(_Dataset):
# overlay the random speaker over the target audio
smallest_size = min(resps.shape[0], other_proms.shape[0])
if other_proms.shape[0] == smallest_size:
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
else:
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
# stitch together the proms
@@ -970,7 +956,7 @@ class Dataset(_Dataset):
# extend the noise to fill the target audio
n = repeat_extend_audio(noise, p.shape[0])
# merge the noise over the utterance
return merge_audio(p, n, scale=[1, noise_scale], device="cpu")
return merge_audio(p, n, scale=[1, noise_scale], device=cfg.dataset.reencode_device)
# apply noise to all pieces
pre_prom = noise_proms( pre_prom )
@@ -988,10 +974,12 @@ class Dataset(_Dataset):
]
# create new resp
resps = torch.cat(
resps = concat_audio(
([ pre_prom ] if pre_prom is not None else []) +
[ edit_prom ] +
([ post_prom ] if post_prom is not None else [])
([ post_prom ] if post_prom is not None else []),
reencode=cfg.dataset.reencode_on_concat,
device=cfg.dataset.reencode_device,
)
else:
raise Exception(f'Undefined task: {task}')

vall_e/emb/qnt.py View File

@@ -431,7 +431,7 @@ def encode_from_file(path, device="cuda"):
Helper Functions
"""
# trims from the start, up to `target`
def trim( qnt, target ):
def trim( qnt, target, reencode=False ):
length = max( qnt.shape[0], qnt.shape[1] )
if target > 0:
start = 0
@@ -446,7 +446,16 @@ def trim( qnt, target ):
if start < 0:
start = 0
return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
if not reencode:
return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
# trims on the waveform itself
# need to test
start = start / cfg.dataset.frames_per_second * cfg.sample_rate
end = end / cfg.dataset.frames_per_second * cfg.sample_rate
wav = decode(qnt)[0]
return encode(wav[start:end], cfg.sample_rate)[0].t()
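One caveat worth flagging alongside the author's "# need to test": true division leaves start/end as floats, so the waveform slice would need an int cast. A worked example of the frame-to-sample conversion under typical EnCodec 24kHz numbers (frames_per_second = 75 and sample_rate = 24000 are assumptions):

# frame indices from the code-level trim
start_frame, end_frame = 150, 300
# convert code frames to waveform samples, casting back to int for slicing
start_sample = int(start_frame / 75 * 24_000)  # 48000, i.e. the 2-second mark
end_sample   = int(end_frame   / 75 * 24_000)  # 96000, i.e. the 4-second mark
# wav[start_sample:end_sample] then spans exactly the trimmed two seconds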
# trims a random piece of audio, up to `target`
# to-do: try and align to EnCodec window
@@ -470,18 +479,47 @@ def repeat_extend_audio( qnt, target ):
return trim(torch.cat(pieces), target)
# interleaves between a list of audios
# useful for interleaving silence
def interleave_audio( *args, audio=None ):
qnts = [*args]
if audio is None:
return qnts
# interleave silence
# yes there's a better way
res = []
for i, qnt in enumerate(qnts):
res.append( qnt )
if i + 1 != len(qnts):
res.append( audio )
return res
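A usage sketch pairing this with concat_audio below; utterance_a, utterance_b, and silence_codes are illustrative stand-ins for encoded audio tensors:

# interleave encoded silence between utterances:
# => [utterance_a, silence_codes, utterance_b]
pieces = interleave_audio(utterance_a, utterance_b, audio=silence_codes)
# then join, optionally through decode => concat => encode
stitched = concat_audio(*pieces, reencode=cfg.dataset.reencode_on_concat)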
# concats two audios together
def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
qnts = [*args]
# just naively combine the codes
if not reencode:
return torch.concat( qnts )
decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
combined = torch.concat( decoded )
return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
# merges two quantized audios together
# I don't know if this works
def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
# requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
qnts = [*args]
decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
# useful to adjust the volumes of each waveform
if len(scale) == len(decoded):
for i in range(len(scale)):
decoded[i] = decoded[i] * scale[i]
combined = sum(decoded) / len(decoded)
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
"""
if __name__ == "__main__":
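One design note on merge_audio above: it scales each decoded waveform, then averages, so every input's level is also divided by the input count. A tiny sketch of just that mixing step (stand-in tensors, not real audio):

import torch

a, b = torch.ones(4), torch.ones(4)   # stand-ins for two decoded waveforms
scaled = [a * 1.0, b * 0.25]          # i.e. scale=[1, 0.25]
combined = sum(scaled) / len(scaled)  # merge_audio's mixing step
print(combined)                       # tensor([0.6250, 0.6250, 0.6250, 0.6250]): the target is halved too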