more stuff for similar-speaker prompt sampling (to-do: actually test if this works...)
This commit is contained in:
parent 69f140ba45
commit 56f25f7a9b
@@ -15,6 +15,9 @@ from pathlib import Path
 from vall_e.config import cfg
+from vall_e.emb.g2p import encode as phonemize
+from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio

 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)
@@ -22,6 +25,13 @@ def process_items( items, stride=0, stride_offset=0 ):
     items = sorted( items )
     return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]

+def load_audio( path, device="cuda" ):
+    waveform, sample_rate = torchaudio.load(path)
+    if waveform.shape[0] > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    waveform = convert_audio(waveform, sample_rate, cfg.sample_rate, 1)
+    return waveform.to(device=device), cfg.sample_rate
+
 def process(
     audio_backend="encodec",
     input_audio="Emilia",
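The new load_audio helper folds the load, mono-downmix, and resample steps that used to be inlined at each call site (see the later hunk) into one place. As a rough standalone sketch of the same flow, substituting torchaudio.functional.resample for the repo's convert_audio and assuming a 24 kHz target rate (both stand-ins, not the repo's verified defaults):

import torch
import torchaudio

def load_audio_sketch(path, target_sr=24_000, device="cuda"):
    waveform, sample_rate = torchaudio.load(path)   # (channels, samples)
    if waveform.shape[0] > 1:
        # downmix multi-channel audio to mono
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != target_sr:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sr)
    return waveform.to(device=device), target_sr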
@@ -57,10 +67,6 @@ def process(
     cfg.inference.weight_dtype = dtype # "bfloat16"
     cfg.inference.amp = amp # False

-    # import after because we've overriden the config above
-    # need to validate if this is even necessary anymore
-    from vall_e.emb.g2p import encode as phonemize
-    from vall_e.emb.qnt import encode as quantize, _replace_file_extension

     output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
@@ -139,9 +145,7 @@ def process(
             text = metadata["text"]

             if waveform is None:
-                waveform, sample_rate = torchaudio.load(inpath)
-                if waveform.shape[0] > 1:
-                    waveform = torch.mean(waveform, dim=0, keepdim=True)
+                waveform, sample_rate = load_audio(inpath)

             wavs.append((
                 outpath,
@@ -148,16 +148,14 @@ class Dataset:
     phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
     duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
     prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be

+    # to-do: clean up the following block, it's a mess
     min_utterances: int = 2 # minimum number of utterances a speaker can have
     random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
     max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
     prompt_duration: float | None = None # legacy
     max_resps: int = 1 # number of samples to target for training
     p_resp_append: float = 1.0 # probability to append another sample to the training target
     p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
     sample_type: str = "path" # path | speaker
@@ -167,6 +165,8 @@ class Dataset:
     # for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
     sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0

+    prom_sample_similar: bool = True # if available, try and sample the prompt closest to the sampled response utterance (requires specific metadata generated)
+
     tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
     reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
     reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
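prom_sample_similar only takes effect once the "specific metadata" mentioned in its comment has been generated. Judging from how get_similar_utterance reads it in the Dataset(_Dataset) changes below, each speaker gets a JSON file keyed by utterance name, where every entry carries a "similar" list of [filename, score] pairs ranked most- to least-similar. A hypothetical example of that shape (names and scores made up):

# hypothetical per-speaker metadata shape, inferred from get_similar_utterance() below
speaker_metadata = {
    "utterance_000001": {
        "similar": [                      # ranked most- to least-similar
            ["utterance_000007", 0.93],
            ["utterance_000012", 0.88],
        ],
    },
    "utterance_000002": { "similar": [ ["utterance_000001", 0.91] ] },
}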
@@ -873,14 +873,29 @@ class Dataset(_Dataset):

         return path, text, resps

+    def get_similar_utterance(self, spkr_name, reference, offset=0 ):
+        # lots of boilerplate checks
+        metadata_path = Path(f"{metadata_root}/{speaker_name}.json")
+        if not metadata_path.exists():
+            return None
+        metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+        if reference not in metadata:
+            return None
+        reference_metadata = metadata[reference]
+        if "similar" not in reference_metadata:
+            return None
+        if len(reference_metadata["similar"]) >= offset:
+            offset = -1
+
+        return reference_metadata["similar"][offset][0]
+
-    def sample_prompts(self, spkr_name, ignore, should_trim=True, reference=None):
+    def sample_prompts(self, spkr_name, reference, should_trim=True):
         if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0:
             return None

         prom_list = []

-        choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
+        choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
         choices = [*choices]

         # no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
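The commit message already flags this as untested, and a few details in get_similar_utterance look off as committed: the body references speaker_name and metadata_root even though the parameter is spkr_name (and neither name is defined anywhere in this hunk), and the bounds check len(reference_metadata["similar"]) >= offset is true for any non-empty list, so offset is always forced to -1 (the least similar entry). A hedged sketch of what appears to be intended, with metadata_root turned into an explicit argument since its real source isn't visible here:

import json
from pathlib import Path

def get_similar_utterance_sketch(spkr_name, reference, offset=0, metadata_root="./metadata"):
    # metadata_root is a hypothetical stand-in for wherever the repo keeps per-speaker JSON
    metadata_path = Path(f"{metadata_root}/{spkr_name}.json")
    if not metadata_path.exists():
        return None
    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    entry = metadata.get(reference)
    if not entry or "similar" not in entry:
        return None
    similar = entry["similar"]            # ranked [filename, score] pairs
    if offset >= len(similar):            # clamp out-of-range offsets instead of the inverted check
        offset = -1
    return similar[offset][0]             # return just the filename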
@@ -895,9 +910,14 @@ class Dataset(_Dataset):
         prom_length = 0
         trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0

-        # to-do: if reference is not None, find the closest utterances to the reference
         for _ in range(cfg.dataset.max_prompts):
-            path = random.choice(choices)
+            if reference is not None and cfg.dataset.prom_sample_similar:
+                path = self.get_similar_utterance( spkr_name=spkr_name, reference=reference, offset = len(prom_list) )
+                # yuck
+                if not path:
+                    path = random.choice(choices)
+            else:
+                path = random.choice(choices)
             if cfg.dataset.use_hdf5:
                 key = _get_hdf5_path(path)
                 qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
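The offset = len(prom_list) argument means each additional prompt slot walks one step further down the similarity ranking, with a plain random pick as the fallback when no similar utterance is found. The policy in isolation, as a minimal sketch (similar_lookup is a hypothetical stand-in for get_similar_utterance):

import random

def pick_prompt_path(choices, similar_lookup, reference, already_picked):
    path = None
    if reference is not None:
        # each already-picked prompt advances the offset, so slot N takes the N-th most similar utterance
        path = similar_lookup(reference, offset=len(already_picked))
    return path if path else random.choice(choices)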
@@ -1011,7 +1031,7 @@ class Dataset(_Dataset):

         # Base TTS (<text><prompt> => <resp>)
         if task == "tts":
-            proms = self.sample_prompts(spkr_name, ignore=path, reference=resps)
+            proms = self.sample_prompts(spkr_name, reference=path)

             if cfg.dataset.inject_noise_in_prom:
                 # sample random noise
@@ -1087,7 +1107,7 @@ class Dataset(_Dataset):
         # target speech extraction ( <text><prom><resp + other resp> => <resp> )
         elif task == "tse":
             # sample a prompt
-            proms = self.sample_prompts(spkr_name, ignore=path)
+            proms = self.sample_prompts(spkr_name, reference=path)

             # sample another speaker
             _, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
@@ -491,8 +491,12 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat
     # AudioDec uses a different pathway
     if cfg.audio_backend == "audiodec":
         model = _load_audiodec_model(device)
-        wav = wav.unsqueeze(0)
-        wav = convert_audio(wav, sr, model.sample_rate, 1)
+        # reshape (channel, samples) => (batch, channel, samples)
+        if wav.dim() < 3:
+            wav = wav.unsqueeze(0)
+        # skip unnecessary resample
+        if sr != model.sample_rate and wav.shape[1] != 1:
+            wav = convert_audio(wav, sr, model.sample_rate, 1)
         wav = wav.to(device)

         # wav = rearrange(wav, "t c -> t 1 c").to(device)
@@ -502,8 +506,12 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", return_metadat

         # vocos does not encode wavs to encodecs, so just use normal encodec
         model = _load_encodec_model(device)
-        wav = wav.unsqueeze(0)
-        wav = convert_audio(wav, sr, model.sample_rate, model.channels)
+        # reshape (channel, samples) => (batch, channel, samples)
+        if wav.dim() < 3:
+            wav = wav.unsqueeze(0)
+        # skip unnecessary resample
+        if sr != model.sample_rate and wav.shape[1] != model.channels:
+            wav = convert_audio(wav, sr, model.sample_rate, model.channels)
         wav = wav.to(device)

         with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
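Both encode() branches now normalize the input to (batch, channel, samples) and only call convert_audio when something actually needs converting. Note the guard uses `and`, so a wav whose sample rate already matches but whose channel count does not (or vice versa) skips conversion entirely; an `or` would match the "skip unnecessary resample" intent more closely, though that is a guess. A standalone sketch of the normalization, with torchaudio.functional.resample standing in for convert_audio and only the mono-downmix case handled:

import torch
import torchaudio

def prepare_wav(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int = 1) -> torch.Tensor:
    if wav.dim() < 3:
        # (channel, samples) => (batch, channel, samples)
        wav = wav.unsqueeze(0)
    if sr != target_sr or wav.shape[1] != target_channels:
        if target_channels == 1 and wav.shape[1] > 1:
            wav = wav.mean(dim=1, keepdim=True)   # downmix to mono
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav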
@@ -48,6 +48,7 @@ def process(
     amp=False,

     verbose=False,
+    metadata_path=None,
 ):
     cfg.set_audio_backend(audio_backend)
     artifact_extension = cfg.audio_backend_extension
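The new metadata_path argument points the similarity pass at the directory of per-speaker JSON files to update; since the code later does metadata_path / f'{input_speaker}.json', it is expected to be a pathlib.Path (or at least support the / operator). A hypothetical call, with the directory made up for illustration:

from pathlib import Path

# hypothetical invocation; "./training/metadata" is an example location, not a repo default
process(
    audio_backend="encodec",
    metadata_path=Path("./training/metadata"),
    verbose=True,
)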
@@ -69,7 +70,6 @@ def process(
     for filename in tqdm(os.listdir(f'./{input_speaker}/'), desc="Encoding...", disable=not verbose):
         extension = filename.split(".")[-1]

-
         if text:
             if extension not in artifact_extension:
                 raise Exception("!")
@@ -140,17 +140,29 @@ def process(
                 if filename_a not in sorted_similarities[filename_b]:
                     sorted_similarities[filename_b][filename_a] = similarity

+    metadata = None
+    if metadata_path is not None:
+        metadata_path = metadata_path / f'{input_speaker}.json'
+        if metadata_path.exists():
+            metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+
     # sort similarities scores
     for key, sorted_similarity in sorted_similarities.items():
+        filename_a, filename_b = key.split(":")
         sorted_similarities[key] = sorted([ ( filename, similarity ) for filename, similarity in sorted_similarity.items() ], key=lambda x: x[1], reverse=True)

         most_filename, most_score = sorted_similarities[key][0]
         least_filename, least_score = sorted_similarities[key][-1]

+        if metadata is not None and filename_a in metadata:
+            metadata[filename_a] = sorted_similarities
+
         if verbose:
             print( f'{key}:\n\tMost: {most_filename} ({most_score:.3f})\n\tLeast: {least_filename} ({least_score:.3f})' )

-    # to-do: store this somewhere
+    if metadata is not None:
+        with open(str(metadata_path), "w", encoding="utf-8") as f:
+            f.write( json.dumps( metadata ) )

     return sorted_similarities
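One thing worth flagging, in the spirit of the commit message's own to-do: the loop stores the entire sorted_similarities dict under each filename (metadata[filename_a] = sorted_similarities), while get_similar_utterance above reads back metadata[reference]["similar"] as a ranked list of [filename, score] pairs for that one utterance. A hedged guess at the per-entry update the reader expects:

# hedged guess; assumes metadata[filename_a] is a dict of per-utterance fields
if metadata is not None and filename_a in metadata:
    if not isinstance(metadata[filename_a], dict):
        metadata[filename_a] = {}
    metadata[filename_a]["similar"] = sorted_similarities[key]   # just this key's ranked list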