added an option to either naively concat codes when concatenating audio (prior behavior) or to decode => concat => encode instead (although this currently only happens for prom sampling if an utterance is too small)
parent 97e768601c
commit bccbb77a1a
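For context before the diff: the commit touches the config dataclasses, the dataloader's prompt/response sampling, and the quantizer helpers. The two concatenation strategies the new flag toggles between can be sketched roughly as follows; this is a minimal sketch, assuming `[frames, levels]` code tensors, with `decode`/`encode` standing in for the repo's EnCodec wrappers:

import torch

# prior behavior: stitch the quantized codes end-to-end; cheap, but the
# frames around the seam never came from encoding a real waveform
def concat_naive(*qnts: torch.Tensor) -> torch.Tensor:
    return torch.concat(qnts)

# opt-in behavior: decode each utterance, join the waveforms, then
# re-encode the combined audio so the seam is acoustically coherent
def concat_reencode(*qnts, decode, encode, sample_rate: int) -> torch.Tensor:
    waves = [ decode(qnt)[0] for qnt in qnts ]
    return encode(torch.concat(waves), sample_rate)[0].t()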
@@ -137,8 +137,9 @@ class Dataset:
     hdf5_name: str = "data.h5"
     use_hdf5: bool = False
-    use_metadata: bool = False
     hdf5_flag: str = "a"
+    use_metadata: bool = False
+
     validate: bool = True
     workers: int = 8
     cache: bool = True
@@ -163,6 +164,8 @@ class Dataset:
     sample_shuffle: bool = True #
 
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])
+    reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
+    reencode_device: str = "cuda" # "cpu" is slower but saves memory
 
     _frames_per_second: int = 0 # allows setting your own hint
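Assuming these dataclass fields are reachable through the global `cfg` the way the diff below uses them, opting in would look something like:

from vall_e.config import cfg  # assumes this repo's package layout

cfg.dataset.reencode_on_concat = True  # decode => concat => encode instead of naive code concat
cfg.dataset.reencode_device = "cpu"    # per the field comment: slower, but saves memory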
@@ -666,7 +669,7 @@ class Optimizations:
 class Config(BaseConfig):
     device: str = "cuda"
     mode: str = "training" # "inferencing"
-    experimental: bool = False # So I can stop commenting out things when committing
+    experimental: bool = False # Debug flag, unused now
 
     dataset: Dataset = field(default_factory=lambda: Dataset)
     models: dict | list | None = field(default_factory=lambda: [])
@@ -11,7 +11,7 @@ import torch
 import itertools
 
 from .config import cfg
-from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
+from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
@@ -541,12 +541,7 @@ class Dataset(_Dataset):
         self.tone_symmap = self._get_tone_symmap()
         self.task_symmap = self._get_task_symmap()
 
-        """
-        self.empty_text = tokenize(" ")
-        if len(self.empty_text) == 4:
-            self.empty_text = self.empty_text[:1] + self.empty_text[1:2] + self.empty_text[-1:]
-        """
-
+        # grab IDs for bos, space, and eos for easy input creation later
         self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
 
         # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
@@ -743,7 +738,7 @@ class Dataset(_Dataset):
             qnt = _load_quants(path, return_metadata=False)
 
             if 0 < trim_length and trim_length < qnt.shape[0]:
-                qnt = trim( qnt, trim_length )
+                qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat )
 
             prom_list.append(qnt)
             prom_length += qnt.shape[0]
@@ -756,7 +751,7 @@ class Dataset(_Dataset):
         prom = torch.cat(prom_list)
 
         if 0 < trim_length and trim_length < prom.shape[0]:
-            prom = trim( prom, trim_length )
+            prom = trim( prom, trim_length, reencode=cfg.dataset.reencode_on_concat )
 
         return prom
@@ -814,15 +809,13 @@ class Dataset(_Dataset):
         lang = torch.tensor([self.lang_symmap[lang]]).to(torch.uint8)
         tone = torch.tensor([self.tone_symmap[tone]]).to(torch.uint8)
 
-        naive = True
+        # a bool to easily experiment with two mindsets later
+        naive = cfg.experimental
 
-        # disabled because I haven't actually needed to use it myself, and I can't be assed to validate if it still works
-        # it probably is better to pad with silence instead of just stitching utterances and ruining things
-        """
         # append additional prompts in an attempt to artifically increase lengths / offer new data
         if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
             ignore_paths = []
-            for _ in range( cfg.dataset.max_resps - 1 ):
+            for _ in range( 1, cfg.dataset.max_resps ):
                 path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
                 ignore_paths.append(path)
@@ -836,14 +829,7 @@ class Dataset(_Dataset):
 
                 # might be better to decode => concat waveforms with silence in between => reencode
                 # as you technically can't just append encodec sequences together like this without issues
-                resps = torch.concat([ resps, qnt ])
-        """
-
-        """
-        task = "tts"
-        trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
-        proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
-        """
+                resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
 
         """
         resps = resps[:, :cfg.model.resp_levels]
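The retained comment about concatenating waveforms with silence in between maps onto the helpers added further down; a hypothetical sketch combining `interleave_audio` and `concat_audio` for that (the zero-filled "silence" codes and their length are illustrative only; real silence would come from encoding an actual silent waveform):

import torch

# illustrative placeholder: ~0.5s of codes at 75 frames/sec with 8 RVQ levels
silence = torch.zeros( 38, 8, dtype=torch.int64 )

pieces = interleave_audio( resps, qnt, audio=silence )  # [resps, silence, qnt]
resps = concat_audio( *pieces, reencode=True, device="cuda" )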
@@ -888,7 +874,7 @@ class Dataset(_Dataset):
             # extend the noise to fill the target audio
             noise = repeat_extend_audio(noise, resps.shape[0])
             # create the input prompt by merging the target audio with the noise
-            proms = merge_audio( resps, noise, scale=[1, noise_scale], device="cpu" )
+            proms = merge_audio( resps, noise, scale=[1, noise_scale], device=cfg.dataset.reencode_device )
             # set the target to just be the noise if <sr>
             if task == "sr":
                 resps = noise
@@ -907,10 +893,10 @@ class Dataset(_Dataset):
             # overlay the random speaker over the target audio
             smallest_size = min(resps.shape[0], other_proms.shape[0])
             if other_proms.shape[0] == smallest_size:
-                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
+                noisy_proms = merge_audio( resps[:smallest_size, :], other_proms, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
                 noisy_proms = torch.cat( [ noisy_proms, resps[smallest_size:, :] ] )
             else:
-                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device="cpu" )
+                noisy_proms = merge_audio( resps, other_proms[:smallest_size, :], scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
                 noisy_proms = torch.cat( [ noisy_proms, other_proms[smallest_size:, :] ] )
 
             # stitch together the proms
@@ -970,7 +956,7 @@ class Dataset(_Dataset):
                 # extend the noise to fill the target audio
                 n = repeat_extend_audio(noise, p.shape[0])
                 # merge the noise over the utterance
-                return merge_audio(p, n, scale=[1, noise_scale], device="cpu")
+                return merge_audio(p, n, scale=[1, noise_scale], device=cfg.dataset.reencode_device)
 
             # apply noise to all pieces
             pre_prom = noise_proms( pre_prom )
@@ -988,10 +974,12 @@ class Dataset(_Dataset):
             ]
 
             # create new resp
-            resps = torch.cat(
+            resps = concat_audio(
                 ([ pre_prom ] if pre_prom is not None else []) +
                 [ edit_prom ] +
-                ([ post_prom ] if post_prom is not None else [])
+                ([ post_prom ] if post_prom is not None else []),
+                reencode=cfg.dataset.reencode_on_concat,
+                device=cfg.dataset.reencode_device,
             )
         else:
             raise Exception(f'Undefined task: {task}')
@@ -431,7 +431,7 @@ def encode_from_file(path, device="cuda"):
 Helper Functions
 """
 # trims from the start, up to `target`
-def trim( qnt, target ):
+def trim( qnt, target, reencode=False ):
     length = max( qnt.shape[0], qnt.shape[1] )
     if target > 0:
         start = 0
@@ -446,8 +446,17 @@ def trim( qnt, target ):
     if start < 0:
         start = 0
 
-    return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
+    if not reencode:
+        return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
+
+    # trims on the waveform itself
+    # need to test
+    start = start / cfg.dataset.frames_per_second * cfg.sample_rate
+    end = end / cfg.dataset.frames_per_second * cfg.sample_rate
+
+    wav = decode(qnt)[0]
+    return encode(wav[start:end], cfg.sample_rate)[0].t()
 
 # trims a random piece of audio, up to `target`
 # to-do: try and align to EnCodec window
 def trim_random( qnt, target ):
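The re-encode branch above maps frame indices to sample indices using the codec's fixed frame rate; with EnCodec at 24 kHz and 75 frames per second, one code frame spans 320 samples. A quick worked check (note the computed indices are floats and would still need an int cast before slicing a tensor):

sample_rate = 24_000
frames_per_second = 75
samples_per_frame = sample_rate // frames_per_second  # 320

start_frame, end_frame = 10, 160
start = int( start_frame / frames_per_second * sample_rate )  # 3200
end = int( end_frame / frames_per_second * sample_rate )      # 51200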
@@ -470,18 +479,47 @@ def repeat_extend_audio( qnt, target ):
 
     return trim(torch.cat(pieces), target)
 
+# interleaves between a list of audios
+# useful for interleaving silence
+def interleave_audio( *args, audio=None ):
+    qnts = [*args]
+    if audio is None:
+        return qnts
+
+    # interleave silence
+    # yes there's a better way
+    res = []
+    for i, qnt in enumerate(qnts):
+        res.append( qnt )
+        if i + 1 != len(qnts):
+            res.append( audio )
+
+    return res
+
+# concats two audios together
+def concat_audio( *args, reencode=False, device="cuda", levels=cfg.model.max_levels ):
+    qnts = [*args]
+    # just naively combine the codes
+    if not reencode:
+        return torch.concat( qnts )
+
+    decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
+    combined = torch.concat( decoded )
+    return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
+
 # merges two quantized audios together
-# I don't know if this works
-def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
+# requires re-encoding because there's no good way to combine the waveforms of two audios without relying on some embedding magic
+def merge_audio( *args, device="cuda", scale=[], levels=cfg.model.max_levels ):
     qnts = [*args]
     decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
 
+    # useful to adjust the volumes of each waveform
     if len(scale) == len(decoded):
         for i in range(len(scale)):
             decoded[i] = decoded[i] * scale[i]
 
     combined = sum(decoded) / len(decoded)
-    return encode(combined, cfg.sample_rate, device="cpu", levels=levels)[0].t()
+    return encode(combined, cfg.sample_rate, device=device, levels=levels)[0].t()
 
 """
 if __name__ == "__main__":
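A quick usage sketch of the new and changed helpers; `qnt_a`/`qnt_b` are hypothetical `[frames, levels]` code tensors loaded elsewhere:

# naive path: equivalent to torch.concat on the codes (prior behavior)
joined = concat_audio( qnt_a, qnt_b )

# opt-in path: decode both, join the waveforms, re-encode the result
joined = concat_audio( qnt_a, qnt_b, reencode=True, device="cuda" )

# overlay two utterances, halving the second one's volume
mixed = merge_audio( qnt_a, qnt_b, scale=[1, 0.5], device="cuda" )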