added sampling by speaker group name (might be better to de-emphasize the LibriVox/Audiobooks that are in large numbers, and emphasize the smaller pools), log cleanup

This commit is contained in:
mrq 2023-10-16 19:30:38 -05:00
parent a539f6889f
commit 09cda7d3f9
6 changed files with 166 additions and 34 deletions

96
scripts/parse_ppp.py Normal file
View File

@ -0,0 +1,96 @@
import os
import json
import torch
from tqdm.auto import tqdm
from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
device = "cuda"
target = "in"
audio_map = {}
text_map = {}
data = {}
for season in os.listdir(f"./{target}/"):
if not os.path.isdir(f"./{target}/{season}/"):
continue
for episode in os.listdir(f"./{target}/{season}/"):
if not os.path.isdir(f"./{target}/{season}/{episode}/"):
continue
for filename in os.listdir(f"./{target}/{season}/{episode}/"):
path = f'./{target}/{season}/{episode}/{filename}'
attrs = filename.split("_")
timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'
key = f'{episode}_{timestamp}'
if filename[-5:] == ".flac":
name = attrs[3]
emotion = attrs[4]
quality = attrs[5]
audio_map[key] = {
"path": path,
'episode': episode,
"name": name,
"emotion": emotion,
"quality": quality,
"timestamp": timestamp,
}
elif filename[-4:] == ".txt":
text_map[key] = open(path, encoding="utf-8").read()
txts = {}
wavs = []
for key, entry in audio_map.items():
path = entry['path']
name = entry['name']
emotion = entry['emotion']
quality = entry['quality']
episode = entry['episode']
path = entry['path']
timestamp = entry['timestamp']
transcription = text_map[key]
if name not in data:
data[name] = {}
os.makedirs(f'./training/{name}/', exist_ok=True)
os.makedirs(f'./voices/{name}/', exist_ok=True)
key = f'{episode}_{timestamp}.flac'
os.rename(path, f'./voices/{name}/{key}')
data[name][key] = {
"segments": [],
"language": "en",
"text": transcription,
"misc": {
"emotion": emotion,
"quality": quality,
"timestamp": timestamp,
"episode": episode,
}
}
path = f'./voices/{name}/{key}'
txts[path] = transcription
wavs.append(Path(path))
for name in data.keys():
open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
for key, text in tqdm(txts.items(), desc="Phonemizing..."):
path = Path(key)
phones = valle_phonemize(text)
open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
for path in tqdm(wavs, desc="Quantizing..."):
qnt = valle_quantize(path, device=device)
torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))

View File

@ -92,12 +92,13 @@ def _load_paths_from_metadata(data_dir, type="training", validate=False):
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
metadata_path = data_dir / "metadata.json"
if not cfg.dataset.use_metadata or not metadata_path.exists():
return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
speaker = cfg.get_spkr( data_dir / "dummy" )
metadata = {}
if cfg.dataset.use_metadata and metadata_path.exists():
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
if len(metadata) == 0:
return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
def key( dir, id ):
if not cfg.dataset.use_hdf5:
return data_dir / id
@ -193,6 +194,7 @@ class Dataset(_Dataset):
if len(self.dataset) == 0:
self.dataset = cfg.dataset.training
# dict of paths keyed by speaker names
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
# cull speakers if they do not have enough utterances
@ -206,6 +208,23 @@ class Dataset(_Dataset):
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
# dict of speakers keyed by speaker group
self.spkrs_by_spkr_group = {}
for data_dir in self.dataset:
spkr = cfg.get_spkr( data_dir / "dummy" )
spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
if len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
continue
if spkr_group not in self.spkrs_by_spkr_group:
self.spkrs_by_spkr_group[spkr_group] = []
self.spkrs_by_spkr_group[spkr_group].append( spkr )
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
if cfg.dataset.sample_type == "path":
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
@ -214,14 +233,15 @@ class Dataset(_Dataset):
self.phone_symmap = phone_symmap or self._get_phone_symmap()
self.spkr_symmap = self._get_spkr_symmap()
self.spkr_group_symmap = self._get_spkr_group_symmap()
self.lang_symmap = self._get_lang_symmap()
self.task_symmap = self._get_task_symmap()
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
if len(self.paths) == 0 and training:
raise ValueError("No valid path is found for training.")
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
self.duration = _calculate_durations(self.dataset_type)
@ -281,6 +301,9 @@ class Dataset(_Dataset):
def _get_spkr_symmap(self):
return {s: i for i, s in enumerate(self.spkrs)}
def _get_spkr_group_symmap(self):
return {s: i for i, s in enumerate(self.spkr_groups)}
def _get_lang_symmap(self):
return get_lang_symmap()
@ -358,14 +381,31 @@ class Dataset(_Dataset):
return prom
def __getitem__(self, index):
if cfg.dataset.sample_type == "speaker":
if cfg.dataset.sample_type == "group":
spkr_group = self.spkr_groups[index]
spkr_group_id = self.spkr_group_symmap[spkr_group]
spkr_name = self.spkr_samplers[spkr_group].sample()
if spkr_name in self.spkr_symmap:
spkr_id = self.spkr_symmap[spkr_name]
else:
spkr_id = -1
try:
path = self.samplers[spkr_name].sample()
except Exception as e:
print( "ERROR", spkr_group, spkr_name )
raise e
elif cfg.dataset.sample_type == "speaker":
spkr_name = self.spkrs[index]
spkr_id = self.spkr_symmap[spkr_name]
path = self.samplers[spkr_name].sample()
spkr_group = self.get_speaker_group(path)
spkr_group_id = self.spkr_group_symmap[spkr_group]
else:
path = self.paths[index]
spkr_name = self.get_speaker(path)
spkr_id = self.spkr_symmap[spkr_name]
spkr_group = self.get_speaker_group(path)
spkr_group_id = self.spkr_group_symmap[spkr_group]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
@ -381,7 +421,6 @@ class Dataset(_Dataset):
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
resps = _load_quants(path)
spkr_group = self.get_speaker_group(path)
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
# append additional prompts in an attempt to artifically increase lengths / offer new data
@ -611,6 +650,8 @@ class Dataset(_Dataset):
self.training = value
def __len__(self):
if cfg.dataset.sample_type == "group":
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
if cfg.dataset.sample_type == "speaker":
return min(len(self.spkrs), self._head or len(self.spkrs))
return min(len(self.paths), self._head or len(self.paths))
@ -679,6 +720,7 @@ def create_train_val_dataloader():
_logger.info(str(train_dataset.phone_symmap))
_logger.info(str(train_dataset.spkr_symmap))
_logger.info(str(train_dataset.spkr_group_symmap))
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
@ -707,6 +749,10 @@ def create_dataset_metadata():
metadata = {}
for path in tqdm(paths, desc="Parsing paths"):
if isinstance(path, str):
print("str:", path)
path = Path(path)
speaker = cfg.get_spkr(path)
if speaker not in metadata:
metadata[speaker] = {}

View File

@ -171,7 +171,7 @@ def encode_from_file(path, device="cuda"):
return encode_from_files( path, device )
else:
path = str(path)
wav, sr = torchaudio.load(path, format=path[-3:])
wav, sr = torchaudio.load(path)
if wav.shape[0] == 2:
wav = wav[:1]

View File

@ -469,8 +469,9 @@ class Engines(dict[str, Engine]):
self._update()
if len(self.keys()) > 1:
stats["elapsed_time"] = total_elapsed_time
stats["global_step"] = self.global_step
#stats["micro_step"] = self.micro_step
stats["it"] = self.global_step
return stats

View File

@ -146,13 +146,13 @@ class AR_NAR(Base):
quant_levels=quant_levels,
)
# is NAR
prev_list = resps_list
if max_levels == 0:
max_levels = self.n_resp_levels
while True:
level = prev_list[0].shape[-1]
prev_list = resps_list
for n in trange( max_levels, desc="NAR" ):
level = prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
@ -195,14 +195,13 @@ class AR_NAR(Base):
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
] * batch_size if sampling_mirostat_tau > 0.0 else None
sampling_beam_width_use_logs = True
scores = [ 1.0 ] * sampling_beam_width
if self.interleave:
max_steps *= self.n_prom_levels
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
if max_resp_context > 0:
@ -245,17 +244,13 @@ class AR_NAR(Base):
r, s = r
# first step, expand batch
if batch_size == 1:
batch_size *= sampling_beam_width
batch_size = sampling_beam_width
text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()
# update scores
if sampling_beam_width_use_logs:
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
else:
scores = [ scores[i] * score for i, score in enumerate(s) ]
scores = [ scores[i] + score for i, score in enumerate(s) ]
# append tokens
for i, ri in enumerate(r):
@ -270,13 +265,8 @@ class AR_NAR(Base):
# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width and len(scores) > 0:
best_idx, best_score = (0, 0)
for idx, score in enumerate(scores):
if best_score > score:
best_idx, best_score = idx, score
sequence_list = [sequence_list[best_idx]]
if sampling_beam_width:
sequence_list = [ sequence_list[0] ]
return [self._prune(r) for r in sequence_list]

View File

@ -101,7 +101,7 @@ def _non_blocking_input():
def _make_infinite_epochs(dl):
while True:
_logger.info("New epoch starts.")
#_logger.info("New epoch starts.")
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
@ -158,10 +158,9 @@ def train(
#batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(batch=batch, feeder=train_feeder)
stats['it'] = stats['global_step']
stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)
"""
stats['batch'] = {
'size': len(batch['text']),
'id': batch['spkr_id'],
@ -170,8 +169,8 @@ def train(
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
}
"""
del stats['global_step']
elapsed_time = stats.get("elapsed_time", 0)
_logger.info(f"Training Metrics: {json.dumps(stats)}.")