more stuff for similar-speaker prompt sampling (to-do: actually test if this works...)
This commit is contained in:
parent
69f140ba45
commit
56f25f7a9b
|
@ -15,6 +15,9 @@ from pathlib import Path
|
|||
|
||||
from vall_e.config import cfg
|
||||
|
||||
from vall_e.emb.g2p import encode as phonemize
|
||||
from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio
|
||||
|
||||
def pad(num, zeroes):
|
||||
return str(num).zfill(zeroes+1)
|
||||
|
||||
|
@ -22,6 +25,13 @@ def process_items( items, stride=0, stride_offset=0 ):
|
|||
items = sorted( items )
|
||||
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
|
||||
|
||||
def load_audio( path, device="cuda" ):
|
||||
waveform, sample_rate = torchaudio.load(path)
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
waveform = convert_audio(waveform, sample_rate, cfg.sample_rate, 1)
|
||||
return waveform.to(device=device), cfg.sample_rate
|
||||
|
||||
def process(
|
||||
audio_backend="encodec",
|
||||
input_audio="Emilia",
|
||||
|
@ -57,10 +67,6 @@ def process(
|
|||
cfg.inference.weight_dtype = dtype # "bfloat16"
|
||||
cfg.inference.amp = amp # False
|
||||
|
||||
# import after because we've overriden the config above
|
||||
# need to validate if this is even necessary anymore
|
||||
from vall_e.emb.g2p import encode as phonemize
|
||||
from vall_e.emb.qnt import encode as quantize, _replace_file_extension
|
||||
|
||||
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
|
||||
|
||||
|
@ -139,9 +145,7 @@ def process(
|
|||
text = metadata["text"]
|
||||
|
||||
if waveform is None:
|
||||
waveform, sample_rate = torchaudio.load(inpath)
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
waveform, sample_rate = load_audio(inpath)
|
||||
|
||||
wavs.append((
|
||||
outpath,
|
||||
|
|
|
@ -148,16 +148,14 @@ class Dataset:
|
|||
phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
|
||||
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
|
||||
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
|
||||
|
||||
# to-do: clean up the following block, it's a mess
|
||||
min_utterances: int = 2 # minimum number of utterances a speaker can have
|
||||
|
||||
random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
|
||||
max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
|
||||
|
||||
prompt_duration: float | None = None # legacy
|
||||
|
||||
max_resps: int = 1 # number of samples to target for training
|
||||
p_resp_append: float = 1.0 # probability to append another sample to the training target
|
||||
|
||||
p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
|
||||
|
||||
sample_type: str = "path" # path | speaker
|
||||
|
@ -167,6 +165,8 @@ class Dataset:
|
|||
# for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
|
||||
sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0
|
||||
|
||||
prom_sample_similar: bool = True # if available, try and sample the prompt closest to the sampled response utterance (requires specific metadata generated)
|
||||
|
||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
|
||||
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
|
||||
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
|
||||
|
|
|
@ -873,14 +873,29 @@ class Dataset(_Dataset):
|
|||
|
||||
return path, text, resps
|
||||
|
||||
def get_similar_utterance(self, spkr_name, reference, offset=0 ):
|
||||
# lots of boilerplate checks
|
||||
metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
|
||||
if not metadata_path.exists():
|
||||
return None
|
||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
||||
if reference not in metadata:
|
||||
return None
|
||||
reference_metadata = metadata[reference]
|
||||
if "similar" not in reference_metadata:
|
||||
return None
|
||||
if len(reference_metadata["similar"]) >= offset:
|
||||
offset = -1
|
||||
|
||||
return reference_metadata["similar"][offset][0]
|
||||
|
||||
def sample_prompts(self, spkr_name, ignore, should_trim=True, reference=None):
|
||||
def sample_prompts(self, spkr_name, reference, should_trim=True):
|
||||
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
|
||||
return None
|
||||
|
||||
prom_list = []
|
||||
|
||||
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
|
||||
choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
|
||||
choices = [*choices]
|
||||
|
||||
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
|
||||
|
@ -895,9 +910,14 @@ class Dataset(_Dataset):
|
|||
prom_length = 0
|
||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
|
||||
|
||||
# to-do: if reference is not None, find the closest utterances to the reference
|
||||
for _ in range(cfg.dataset.max_prompts):
|
||||
path = random.choice(choices)
|
||||
if reference is not None and cfg.dataset.prom_sample_similar:
|
||||
path = self.get_similar_utterance( spkr_name=spkr_name, reference=reference, offset = len(prom_list) )
|
||||
# yuck
|
||||
if not path:
|
||||
path = random.choice(choices)
|
||||
else:
|
||||
path = random.choice(choices)
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||
|
@ -1011,7 +1031,7 @@ class Dataset(_Dataset):
|
|||
|
||||
# Base TTS (<text><prompt> => <resp>)
|
||||
if task == "tts":
|
||||
proms = self.sample_prompts(spkr_name, ignore=path, reference=resps)
|
||||
proms = self.sample_prompts(spkr_name, reference=path)
|
||||
|
||||
if cfg.dataset.inject_noise_in_prom:
|
||||
# sample random noise
|
||||
|
@ -1087,7 +1107,7 @@ class Dataset(_Dataset):
|
|||
# target speech extraction ( <text><prom><resp + other resp> => <resp> )
|
||||
elif task == "tse":
|
||||
# sample a prompt
|
||||
proms = self.sample_prompts(spkr_name, ignore=path)
|
||||
proms = self.sample_prompts(spkr_name, reference=path)
|
||||
|
||||
# sample another speaker
|
||||
_, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
|
||||
|
|
|
@ -491,8 +491,12 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
|
|||
# AudioDec uses a different pathway
|
||||
if cfg.audio_backend == "audiodec":
|
||||
model = _load_audiodec_model(device)
|
||||
wav = wav.unsqueeze(0)
|
||||
wav = convert_audio(wav, sr, model.sample_rate, 1)
|
||||
# reshape (channel, samples) => (batch, channel, samples)
|
||||
if wav.dim() < 3:
|
||||
wav = wav.unsqueeze(0)
|
||||
# skip unnecessary resample
|
||||
if sr != model.sample_rate and wav.shape[1] != 1:
|
||||
wav = convert_audio(wav, sr, model.sample_rate, 1)
|
||||
wav = wav.to(device)
|
||||
|
||||
# wav = rearrange(wav, "t c -> t 1 c").to(device)
|
||||
|
@ -502,8 +506,12 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
|
|||
|
||||
# vocos does not encode wavs to encodecs, so just use normal encodec
|
||||
model = _load_encodec_model(device)
|
||||
wav = wav.unsqueeze(0)
|
||||
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||
# reshape (channel, samples) => (batch, channel, samples)
|
||||
if wav.dim() < 3:
|
||||
wav = wav.unsqueeze(0)
|
||||
# skip unnecessary resample
|
||||
if sr != model.sample_rate and wav.shape[1] != model.channels:
|
||||
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
||||
wav = wav.to(device)
|
||||
|
||||
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||
|
|
|
@ -48,6 +48,7 @@ def process(
|
|||
amp=False,
|
||||
|
||||
verbose=False,
|
||||
metadata_path=None,
|
||||
):
|
||||
cfg.set_audio_backend(audio_backend)
|
||||
artifact_extension = cfg.audio_backend_extension
|
||||
|
@ -69,7 +70,6 @@ def process(
|
|||
for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
|
||||
extension = filename.split(".")[-1]
|
||||
|
||||
|
||||
if text:
|
||||
if extension not in artifact_extension:
|
||||
raise Exception("!")
|
||||
|
@ -140,17 +140,29 @@ def process(
|
|||
if filename_a not in sorted_similarities[filename_b]:
|
||||
sorted_similarities[filename_b][filename_a] = similarity
|
||||
|
||||
metadata = None
|
||||
if metadata_path is not None:
|
||||
metadata_path = metadata_path / f'{input_speaker}.json'
|
||||
if metadata_path.exists():
|
||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
||||
|
||||
# sort similarities scores
|
||||
for key, sorted_similarity in sorted_similarities.items():
|
||||
filename_a, filename_b = key.split(":")
|
||||
sorted_similarities[key] = sorted([ ( filename, similarity ) for filename, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)
|
||||
|
||||
most_filename, most_score = sorted_similarities[key][0]
|
||||
least_filename, least_score = sorted_similarities[key][-1]
|
||||
|
||||
if metadata is not None and filename_a in metadata:
|
||||
metadata[filename_a] = sorted_similarities
|
||||
|
||||
if verbose:
|
||||
print( f'{key}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )
|
||||
|
||||
# to-do: store this somewhere
|
||||
if metadata is not None:
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
|
||||
return sorted_similarities
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user