more stuff for similar-speaker prompt sampling (to-do: actually test if this works...)

This commit is contained in:
mrq 2024-09-16 23:10:29 -05:00
parent 69f140ba45
commit 56f25f7a9b
5 changed files with 67 additions and 23 deletions
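The glue between these files is a per-speaker JSON metadata file keyed by utterance name. Its exact shape is not documented in the diff, so the sketch below is inferred from the reader (`get_similar_utterance`) and the writer (the similarity script); names and scores are illustrative:

# <metadata_root>/<speaker>.json, as consumed by get_similar_utterance()
{
    "utterance_0001": {
        "similar": [
            ["utterance_0042", 0.93],  # most similar first
            ["utterance_0007", 0.88],
            ["utterance_0113", 0.12],  # least similar last
        ]
    }
}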

View File

@@ -15,6 +15,9 @@ from pathlib import Path
from vall_e.config import cfg
from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
@@ -22,6 +25,13 @@ def process_items( items, stride=0, stride_offset=0 ):
items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
def load_audio( path, device="cuda" ):
waveform, sample_rate = torchaudio.load(path)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
waveform = convert_audio(waveform, sample_rate, cfg.sample_rate, 1)
return waveform.to(device=device), cfg.sample_rate
def process(
audio_backend="encodec",
input_audio="Emilia",
@@ -57,10 +67,6 @@ def process(
cfg.inference.weight_dtype = dtype # "bfloat16"
cfg.inference.amp = amp # False
# import after because we've overridden the config above
# need to validate if this is even necessary anymore
from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as quantize, _replace_file_extension
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
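That folder-name expression is terse; for concreteness, the names it produces at each supported rate (assuming audio_backend="encodec"):

# cfg.sample_rate == 24_000 => "<output_dataset>/24KHz-encodec"
# cfg.sample_rate == 44_100 => "<output_dataset>/44KHz-encodec"
# cfg.sample_rate == 48_000 => "<output_dataset>/48KHz-encodec"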
@@ -139,9 +145,7 @@ def process(
text = metadata["text"]
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
waveform, sample_rate = load_audio(inpath)
wavs.append((
outpath,

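A minimal usage sketch of the new `load_audio` helper; the path is hypothetical, and the returned rate is always the configured one since the helper resamples internally:

waveform, sample_rate = load_audio("./Emilia/EN/B000001.mp3", device="cuda")
# waveform is (1, samples): mono after the channel mean, resampled to cfg.sample_rate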
View File

@@ -148,16 +148,14 @@ class Dataset:
phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
# to-do: clean up the following block, it's a mess
min_utterances: int = 2 # minimum number of utterances a speaker can have
random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
prompt_duration: float | None = None # legacy
max_resps: int = 1 # number of samples to target for training
p_resp_append: float = 1.0 # probability to append another sample to the training target
p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
sample_type: str = "path" # path | speaker
@@ -167,6 +165,8 @@ class Dataset:
# for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0
prom_sample_similar: bool = True # if available, try and sample the prompt closest to the sampled response utterance (requires specific metadata generated)
tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

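For reference, the knobs that drive prompt sampling gathered in one sketch, assuming every other `Dataset` field keeps its default:

dataset = Dataset(
    prompt_duration_range=[3.0, 6.0],  # seconds of audio per input prompt
    max_prompts=3,                     # cap on utterances stitched into one prompt
    min_utterances=2,                  # speakers below this are skipped
    prom_sample_similar=True,          # prefer utterances closest to the target
)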
View File

@@ -873,14 +873,29 @@ class Dataset(_Dataset):
return path, text, resps
def get_similar_utterance(self, spkr_name, reference, offset=0 ):
# lots of boilerplate checks
metadata_path = Path(f"{metadata_root}/{spkr_name}.json")
if not metadata_path.exists():
return None
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
if reference not in metadata:
return None
reference_metadata = metadata[reference]
if "similar" not in reference_metadata:
return None
# out-of-range offsets fall back to the last (least similar) entry
if offset >= len(reference_metadata["similar"]):
offset = -1
return reference_metadata["similar"][offset][0]
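A toy walkthrough of the offset handling above (scores are made up): offsets index from most similar, and anything past the end falls back to the least similar entry. Since the caller passes offset=len(prom_list), each extra prompt in the sampling loop walks one step further down the list, avoiding duplicate picks:

similar = [["utt_a", 0.9], ["utt_b", 0.7], ["utt_c", 0.2]]
# offset=0 => "utt_a", offset=1 => "utt_b", offset=2 => "utt_c"
# offset=3 is out of range, so it clamps to -1 and returns "utt_c"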
def sample_prompts(self, spkr_name, ignore, should_trim=True, reference=None):
def sample_prompts(self, spkr_name, reference, should_trim=True):
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
return None
prom_list = []
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
choices = [*choices]
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
@@ -895,9 +910,14 @@ class Dataset(_Dataset):
prom_length = 0
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
# to-do: if reference is not None, find the closest utterances to the reference
for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices)
if reference is not None and cfg.dataset.prom_sample_similar:
path = self.get_similar_utterance( spkr_name=spkr_name, reference=reference, offset = len(prom_list) )
# yuck
if not path:
path = random.choice(choices)
else:
path = random.choice(choices)
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
@@ -1011,7 +1031,7 @@ class Dataset(_Dataset):
# Base TTS (<text><prompt> => <resp>)
if task == "tts":
proms = self.sample_prompts(spkr_name, ignore=path, reference=resps)
proms = self.sample_prompts(spkr_name, reference=path)
if cfg.dataset.inject_noise_in_prom:
# sample random noise
@@ -1087,7 +1107,7 @@ class Dataset(_Dataset):
# target speech extraction ( <text><prom><resp + other resp> => <resp> )
elif task == "tse":
# sample a prompt
proms = self.sample_prompts(spkr_name, ignore=path)
proms = self.sample_prompts(spkr_name, reference=path)
# sample another speaker
_, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))

View File

@@ -491,8 +491,12 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
# AudioDec uses a different pathway
if cfg.audio_backend == "audiodec":
model = _load_audiodec_model(device)
wav = wav.unsqueeze(0)
wav = convert_audio(wav, sr, model.sample_rate, 1)
# reshape (channel, samples) => (batch, channel, samples)
if wav.dim() < 3:
wav = wav.unsqueeze(0)
# skip unnecessary resample
if sr != model.sample_rate or wav.shape[1] != 1:
wav = convert_audio(wav, sr, model.sample_rate, 1)
wav = wav.to(device)
# wav = rearrange(wav, "t c -> t 1 c").to(device)
@@ -502,8 +506,12 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
# vocos does not encode wavs to encodecs, so just use normal encodec
model = _load_encodec_model(device)
wav = wav.unsqueeze(0)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
# reshape (channel, samples) => (batch, channel, samples)
if wav.dim() < 3:
wav = wav.unsqueeze(0)
# skip unnecessary resample
if sr != model.sample_rate or wav.shape[1] != model.channels:
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.to(device)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):

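Both branches of `encode` now repeat the same normalize-then-maybe-convert pattern; a standalone sketch of the idea (the helper name is mine, `convert_audio` is the one already imported in this module):

def _prep_for_codec(wav, sr, target_sr, target_channels):
    # reshape (channel, samples) => (batch, channel, samples)
    if wav.dim() < 3:
        wav = wav.unsqueeze(0)
    # only resample/downmix when the input actually mismatches the model
    if sr != target_sr or wav.shape[1] != target_channels:
        wav = convert_audio(wav, sr, target_sr, target_channels)
    return wav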
View File

@@ -48,6 +48,7 @@ def process(
amp=False,
verbose=False,
metadata_path=None,
):
cfg.set_audio_backend(audio_backend)
artifact_extension = cfg.audio_backend_extension
@@ -69,7 +70,6 @@ def process(
for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
extension = filename.split(".")[-1]
if text:
if extension not in artifact_extension:
raise Exception(f"unexpected file extension: {filename}")
@@ -140,17 +140,29 @@ def process(
if filename_a not in sorted_similarities[filename_b]:
sorted_similarities[filename_b][filename_a] = similarity
metadata = None
if metadata_path is not None:
metadata_path = metadata_path / f'{input_speaker}.json'
if metadata_path.exists():
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
# sort similarity scores
for key, sorted_similarity in sorted_similarities.items():
filename_a, filename_b = key.split(":")
sorted_similarities[key] = sorted([ ( filename, similarity ) for filename, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
most_filename, most_score = sorted_similarities[key][0]
least_filename, least_score = sorted_similarities[key][-1]
if metadata is not None and filename_a in metadata:
metadata[filename_a]["similar"] = sorted_similarities[key]
if verbose:
print( f'{key}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
# to-do: store this somewhere
if metadata is not None:
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
return sorted_similarities
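Putting producer and consumer together: the script above writes the per-speaker JSON, and the dataloader's `get_similar_utterance` reads it back. A hedged end-to-end sketch, with paths and names purely illustrative:

import json
from pathlib import Path

speaker_json = Path("./metadata/Emilia_0001.json")  # hypothetical location
metadata = json.loads(speaker_json.read_text(encoding="utf-8"))
# the first entry of the "similar" list is what offset=0 resolves to
best_name, best_score = metadata["some_utterance"]["similar"][0]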