added option to either naively concat codes to concat audio waveforms (prior behavior) or to decode => concat => encode instead (although this only currently happens for prom sampling if an utternace is too small)
This commit is contained in:
parent
97e768601c
commit
bccbb77a1a
|
@ -137,8 +137,9 @@ class Dataset:
|
|||
|
||||
hdf5_name: str = "data.h5"
|
||||
use_hdf5: bool = False
|
||||
use_metadata: bool = False
|
||||
hdf5_flag: str = "a"
|
||||
use_metadata: bool = False
|
||||
|
||||
validate: bool = True
|
||||
workers: int = 8
|
||||
cache: bool = True
|
||||
|
@ -163,6 +164,8 @@ class Dataset:
|
|||
sample_shuffle: bool = True #
|
||||
|
||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
|
||||
reencode_device: str = "cuda" # "cpu" is slower but saves memory
|
||||
|
||||
_frames_per_second: int = 0 # allows setting your own hint
|
||||
|
||||
|
@ -666,7 +669,7 @@ class Optimizations:
|
|||
class Config(BaseConfig):
|
||||
device: str = "cuda"
|
||||
mode: str = "training" # "inferencing"
|
||||
experimental: bool = False # So I can stop commenting out things when committing
|
||||
experimental: bool = False # Debug flag, unused now
|
||||
|
||||
dataset: Dataset = field(default_factory=lambda: Dataset)
|
||||
models: dict | list | None = field(default_factory=lambda: [])
|
||||
|
|
|
@ -11,7 +11,7 @@ import torch
|
|||
import itertools
|
||||
|
||||
from .config import cfg
|
||||
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
||||
from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file
|
||||
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
|
||||
from .utils.distributed import global_rank, local_rank, world_size
|
||||
|
||||
|
@ -541,12 +541,7 @@ class Dataset(_Dataset):
|
|||
self.tone_symmap = self._get_tone_symmap()
|
||||
self.task_symmap = self._get_task_symmap()
|
||||
|
||||
"""
|
||||
self.empty_text = tokenize(" ")
|
||||
if len(self.empty_text) == 4:
|
||||
self.empty_text = self.empty_text[:1] + self.empty_text[1:2] + self.empty_text[-1:]
|
||||
"""
|
||||
|
||||
# grab IDs for bos, space, and eos for easy input creation later
|
||||
self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
|
||||
|
||||
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
||||
|
@ -743,7 +738,7 @@ class Dataset(_Dataset):
|
|||
qnt = _load_quants(path, return_metadata=False)
|
||||
|
||||
if 0 < trim_length and trim_length < qnt.shape[0]:
|
||||
qnt = trim( qnt, trim_length )
|
||||
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat )
|
||||
|
||||
prom_list.append(qnt)
|
||||
prom_length += qnt.shape[0]
|
||||
|
@ -756,7 +751,7 @@ class Dataset(_Dataset):
|
|||
prom = torch.cat(prom_list)
|
||||
|
||||
if 0 < trim_length and trim_length < prom.shape[0]:
|
||||
prom = trim( prom, trim_length )
|
||||
prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat )
|
||||
|
||||
return prom
|
||||
|
||||
|
@ -814,15 +809,13 @@ class Dataset(_Dataset):
|
|||
lang = torch.tensor([self.lang_symmap[lang]]).to(torch.uint8)
|
||||
tone = torch.tensor([self.tone_symmap[tone]]).to(torch.uint8)
|
||||
|
||||
naive = True
|
||||
# a bool to easily experiment with two mindsets later
|
||||
naive = cfg.experimental
|
||||
|
||||
# disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
|
||||
# it probably is better to pad with silence instead of just stitching utterances and ruining things
|
||||
"""
|
||||
# append additional prompts in an attempt to artifically increase lengths / offer new data
|
||||
if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
|
||||
ignore_paths = []
|
||||
for _ in range( cfg.dataset.max_resps - 1 ):
|
||||
for _ in range( 1, cfg.dataset.max_resps ):
|
||||
path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
|
||||
ignore_paths.append(path)
|
||||
|
||||
|
@ -836,15 +829,8 @@ class Dataset(_Dataset):
|
|||
|
||||
# might be better to decode => concat waveforms with silence in between => reencode
|
||||
# as you technically can't just append encodec sequences together like this without issues
|
||||
resps = torch.concat([ resps, qnt ])
|
||||
"""
|
||||
|
||||
"""
|
||||
task = "tts"
|
||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||
"""
|
||||
|
||||
resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
|
||||
|
||||
"""
|
||||
resps = resps[:, :cfg.model.resp_levels]
|
||||
proms = proms[:, :cfg.model.resp_levels]
|
||||
|
@ -888,7 +874,7 @@ class Dataset(_Dataset):
|
|||
# extend the noise to fill the target audio
|
||||
noise = repeat_extend_audio(noise, resps.shape[0])
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
proms = merge_audio( resps, noise, scale=[1, noise_scale], device="cpu" )
|
||||
proms = merge_audio( resps, noise, scale=[1, noise_scale], device=cfg.dataset.reencode_device )
|
||||
# set the target to just be the noise if <sr>
|
||||
if task == "sr":
|
||||
resps = noise
|
||||
|
@ -907,10 +893,10 @@ class Dataset(_Dataset):
|
|||
# overlay the random speaker over the target audio
|
||||
smallest_size = min(resps.shape[0], other_proms.shape[0])
|
||||
if other_proms.shape[0] == smallest_size:
|
||||
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
||||
noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
|
||||
noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
|
||||
else:
|
||||
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
|
||||
noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
|
||||
noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
|
||||
|
||||
# stitch together the proms
|
||||
|
@ -970,7 +956,7 @@ class Dataset(_Dataset):
|
|||
# extend the noise to fill the target audio
|
||||
n = repeat_extend_audio(noise, p.shape[0])
|
||||
# merge the noise over the utterance
|
||||
return merge_audio(p, n, scale=[1, noise_scale], device="cpu")
|
||||
return merge_audio(p, n, scale=[1, noise_scale], device=cfg.dataset.reencode_device)
|
||||
|
||||
# apply noise to all pieces
|
||||
pre_prom = noise_proms( pre_prom )
|
||||
|
@ -988,10 +974,12 @@ class Dataset(_Dataset):
|
|||
]
|
||||
|
||||
# create new resp
|
||||
resps = torch.cat(
|
||||
resps = concat_audio(
|
||||
([ pre_prom ] if pre_prom is not None else []) +
|
||||
[ edit_prom ] +
|
||||
([ post_prom ] if post_prom is not None else [])
|
||||
([ post_prom ] if post_prom is not None else []),
|
||||
reencode=cfg.dataset.reencode_on_concat,
|
||||
device=cfg.dataset.reencode_device,
|
||||
)
|
||||
else:
|
||||
raise Exception(f'Undefined task: {task}')
|
||||
|
|
|
@ -431,7 +431,7 @@ def encode_from_file(path, device="cuda"):
|
|||
Helper Functions
|
||||
"""
|
||||
# trims from the start, up to `target`
|
||||
def trim( qnt, target ):
|
||||
def trim( qnt, target, reencode=False ):
|
||||
length = max( qnt.shape[0], qnt.shape[1] )
|
||||
if target > 0:
|
||||
start = 0
|
||||
|
@ -446,7 +446,16 @@ def trim( qnt, target ):
|
|||
if start < 0:
|
||||
start = 0
|
||||
|
||||
return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
|
||||
if not reencode:
|
||||
return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
|
||||
|
||||
# trims on the waveform itself
|
||||
# need to test
|
||||
start = start / cfg.dataset.frames_per_second * cfg.sample_rate
|
||||
end = end / cfg.dataset.frames_per_second * cfg.sample_rate
|
||||
|
||||
wav = decode(qnt)[0]
|
||||
return encode(wav[start:end], cfg.sample_rate)[0].t()
|
||||
|
||||
# trims a random piece of audio, up to `target`
|
||||
# to-do: try and align to EnCodec window
|
||||
|
@ -470,18 +479,47 @@ def repeat_extend_audio( qnt, target ):
|
|||
|
||||
return trim(torch.cat(pieces), target)
|
||||
|
||||
# interleaves between a list of audios
|
||||
# useful for interleaving silence
|
||||
def interleave_audio( *args, audio=None ):
|
||||
qnts = [*args]
|
||||
if audio is None:
|
||||
return qnts
|
||||
|
||||
# interleave silence
|
||||
# yes there's a better way
|
||||
res = []
|
||||
for i, qnt in enumerate(qnts):
|
||||
res.append( qnt )
|
||||
if i + 1 != len(qnts):
|
||||
res.append( audio )
|
||||
|
||||
return res
|
||||
|
||||
# concats two audios together
|
||||
def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
|
||||
qnts = [*args]
|
||||
# just naively combine the codes
|
||||
if not reencode:
|
||||
return torch.concat( qnts )
|
||||
|
||||
decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
|
||||
combined = torch.concat( decoded )
|
||||
return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
|
||||
|
||||
# merges two quantized audios together
|
||||
# I don't know if this works
|
||||
def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
|
||||
# requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
|
||||
def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
|
||||
qnts = [*args]
|
||||
decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
|
||||
|
||||
# useful to adjust the volumes of each waveform
|
||||
if len(scale) == len(decoded):
|
||||
for i in range(len(scale)):
|
||||
decoded[i] = decoded[i] * scale[i]
|
||||
|
||||
combined = sum(decoded) / len(decoded)
|
||||
return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
|
||||
return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
|
||||
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue
Block a user