added sampling by speaker group name (might be better to de-emphasize the LibriVox/Audiobooks that are in large numbers, and emphasize the smaller pools), log cleanup
This commit is contained in:
parent a539f6889f
commit 09cda7d3f9
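The group-first sampling the commit message describes can be pictured with a minimal standalone sketch (the group names, speaker names, and pool sizes below are hypothetical, and this is not the repo's Sampler class): drawing a speaker group uniformly before drawing a speaker inside it flattens the per-group frequency, so a three-speaker pool gets visited as often as a thousand-speaker audiobook pool.

import random

speakers_by_group = {
    "librivox":   [f"librivox_{i}" for i in range(1000)],  # large audiobook pool
    "audiobooks": [f"audiobook_{i}" for i in range(500)],
    "small_pool": ["alice", "bob", "carol"],               # small pool to emphasize
}

def sample_speaker(groups):
    group = random.choice(list(groups))  # uniform over groups, not over speakers
    return random.choice(groups[group])  # then uniform within the chosen group

# Each group is drawn ~1/3 of the time, so "alice" appears in ~1/9 of draws
# instead of ~1/1503 under flat per-speaker sampling.
print(sample_speaker(speakers_by_group))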
scripts/parse_ppp.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import os
import json
import torch

from tqdm.auto import tqdm
from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension

device = "cuda"

target = "in"

audio_map = {}
text_map = {}

data = {}

for season in os.listdir(f"./{target}/"):
    if not os.path.isdir(f"./{target}/{season}/"):
        continue

    for episode in os.listdir(f"./{target}/{season}/"):
        if not os.path.isdir(f"./{target}/{season}/{episode}/"):
            continue

        for filename in os.listdir(f"./{target}/{season}/{episode}/"):
            path = f'./{target}/{season}/{episode}/{filename}'
            attrs = filename.split("_")
            timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'

            key = f'{episode}_{timestamp}'

            if filename[-5:] == ".flac":
                name = attrs[3]
                emotion = attrs[4]
                quality = attrs[5]

                audio_map[key] = {
                    "path": path,
                    "episode": episode,
                    "name": name,
                    "emotion": emotion,
                    "quality": quality,
                    "timestamp": timestamp,
                }
            elif filename[-4:] == ".txt":
                text_map[key] = open(path, encoding="utf-8").read()

txts = {}
wavs = []

for key, entry in audio_map.items():
    path = entry['path']
    name = entry['name']
    emotion = entry['emotion']
    quality = entry['quality']
    episode = entry['episode']
    timestamp = entry['timestamp']
    transcription = text_map[key]

    if name not in data:
        data[name] = {}
        os.makedirs(f'./training/{name}/', exist_ok=True)
        os.makedirs(f'./voices/{name}/', exist_ok=True)

    key = f'{episode}_{timestamp}.flac'
    os.rename(path, f'./voices/{name}/{key}')

    data[name][key] = {
        "segments": [],
        "language": "en",
        "text": transcription,
        "misc": {
            "emotion": emotion,
            "quality": quality,
            "timestamp": timestamp,
            "episode": episode,
        }
    }

    path = f'./voices/{name}/{key}'
    txts[path] = transcription
    wavs.append(Path(path))

for name in data.keys():
    open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )

for key, text in tqdm(txts.items(), desc="Phonemizing..."):
    path = Path(key)
    phones = valle_phonemize(text)
    open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))

for path in tqdm(wavs, desc="Quantizing..."):
    qnt = valle_quantize(path, device=device)
    torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
@@ -92,11 +92,12 @@ def _load_paths_from_metadata(data_dir, type="training", validate=False):
         return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones

     metadata_path = data_dir / "metadata.json"
-    if not cfg.dataset.use_metadata or not metadata_path.exists():
-        return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
+    metadata = {}
+    if cfg.dataset.use_metadata and metadata_path.exists():
+        metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())

     speaker = cfg.get_spkr( data_dir / "dummy" )
-    metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
+    if len(metadata) == 0:
+        return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )

     def key( dir, id ):
         if not cfg.dataset.use_hdf5:
@@ -193,6 +194,7 @@ class Dataset(_Dataset):
         if len(self.dataset) == 0:
             self.dataset = cfg.dataset.training

+        # dict of paths keyed by speaker names
         self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)

         # cull speakers if they do not have enough utterances
@@ -206,6 +208,23 @@ class Dataset(_Dataset):

         self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }

+        # dict of speakers keyed by speaker group
+        self.spkrs_by_spkr_group = {}
+        for data_dir in self.dataset:
+            spkr = cfg.get_spkr( data_dir / "dummy" )
+            spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
+
+            if len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
+                continue
+
+            if spkr_group not in self.spkrs_by_spkr_group:
+                self.spkrs_by_spkr_group[spkr_group] = []
+
+            self.spkrs_by_spkr_group[spkr_group].append( spkr )
+
+        self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
+        self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
+
         if cfg.dataset.sample_type == "path":
             self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
@@ -214,14 +233,15 @@ class Dataset(_Dataset):

         self.phone_symmap = phone_symmap or self._get_phone_symmap()
         self.spkr_symmap = self._get_spkr_symmap()
+        self.spkr_group_symmap = self._get_spkr_group_symmap()
         self.lang_symmap = self._get_lang_symmap()
         self.task_symmap = self._get_task_symmap()

         # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
         self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16

-        if len(self.paths) == 0 and training:
-            raise ValueError("No valid path is found for training.")
+        if len(self.paths) == 0:
+            raise ValueError(f"No valid path is found for {self.dataset_type}")

         #self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
         self.duration = _calculate_durations(self.dataset_type)
@@ -281,6 +301,9 @@ class Dataset(_Dataset):
     def _get_spkr_symmap(self):
         return {s: i for i, s in enumerate(self.spkrs)}

+    def _get_spkr_group_symmap(self):
+        return {s: i for i, s in enumerate(self.spkr_groups)}
+
     def _get_lang_symmap(self):
         return get_lang_symmap()
@@ -358,14 +381,31 @@ class Dataset(_Dataset):
         return prom

     def __getitem__(self, index):
-        if cfg.dataset.sample_type == "speaker":
+        if cfg.dataset.sample_type == "group":
+            spkr_group = self.spkr_groups[index]
+            spkr_group_id = self.spkr_group_symmap[spkr_group]
+            spkr_name = self.spkr_samplers[spkr_group].sample()
+            if spkr_name in self.spkr_symmap:
+                spkr_id = self.spkr_symmap[spkr_name]
+            else:
+                spkr_id = -1
+            try:
+                path = self.samplers[spkr_name].sample()
+            except Exception as e:
+                print( "ERROR", spkr_group, spkr_name )
+                raise e
+        elif cfg.dataset.sample_type == "speaker":
             spkr_name = self.spkrs[index]
             spkr_id = self.spkr_symmap[spkr_name]
             path = self.samplers[spkr_name].sample()
+            spkr_group = self.get_speaker_group(path)
+            spkr_group_id = self.spkr_group_symmap[spkr_group]
         else:
             path = self.paths[index]
             spkr_name = self.get_speaker(path)
             spkr_id = self.spkr_symmap[spkr_name]
+            spkr_group = self.get_speaker_group(path)
+            spkr_group_id = self.spkr_group_symmap[spkr_group]

         if cfg.dataset.use_hdf5:
             key = _get_hdf5_path(path)
@@ -381,7 +421,6 @@ class Dataset(_Dataset):
             text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
             resps = _load_quants(path)

-        spkr_group = self.get_speaker_group(path)
         lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)

         # append additional prompts in an attempt to artificially increase lengths / offer new data
@@ -611,6 +650,8 @@ class Dataset(_Dataset):
         self.training = value

     def __len__(self):
+        if cfg.dataset.sample_type == "group":
+            return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
         if cfg.dataset.sample_type == "speaker":
             return min(len(self.spkrs), self._head or len(self.spkrs))
         return min(len(self.paths), self._head or len(self.paths))
@@ -679,6 +720,7 @@ def create_train_val_dataloader():

     _logger.info(str(train_dataset.phone_symmap))
     _logger.info(str(train_dataset.spkr_symmap))
+    _logger.info(str(train_dataset.spkr_group_symmap))

     _logger.info(f"#samples (train): {len(train_dataset)}.")
     _logger.info(f"#samples (val): {len(val_dataset)}.")
@@ -707,6 +749,10 @@ def create_dataset_metadata():

     metadata = {}
     for path in tqdm(paths, desc="Parsing paths"):
+        if isinstance(path, str):
+            print("str:", path)
+            path = Path(path)
+
         speaker = cfg.get_spkr(path)
         if speaker not in metadata:
             metadata[speaker] = {}

@@ -171,7 +171,7 @@ def encode_from_file(path, device="cuda"):
         return encode_from_files( path, device )
     else:
         path = str(path)
-        wav, sr = torchaudio.load(path, format=path[-3:])
+        wav, sr = torchaudio.load(path)

     if wav.shape[0] == 2:
         wav = wav[:1]

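A note on the dropped format argument: path[-3:] only ever yields the last three characters of the filename, which is wrong for four-character extensions such as the .flac files this commit's script produces. A minimal illustration (the path is hypothetical), assuming torchaudio can detect the container on its own:

import torchaudio

path = "./voices/example/1_0h1m2s.flac"  # hypothetical path
print(path[-3:])                 # "lac" -- not a valid format hint for FLAC
wav, sr = torchaudio.load(path)  # let torchaudio detect the format instead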
@@ -469,8 +469,9 @@ class Engines(dict[str, Engine]):

         self._update()

-        stats["elapsed_time"] = total_elapsed_time
-        stats["global_step"] = self.global_step
-        #stats["micro_step"] = self.micro_step
+        if len(self.keys()) > 1:
+            stats["elapsed_time"] = total_elapsed_time
+
+        stats["it"] = self.global_step

         return stats

@@ -146,13 +146,13 @@ class AR_NAR(Base):
                 quant_levels=quant_levels,
             )
         # is NAR
+        prev_list = resps_list
         if max_levels == 0:
             max_levels = self.n_resp_levels

-        prev_list = resps_list
-
-        while True:
+        for n in trange( max_levels, desc="NAR" ):
             level = prev_list[0].shape[-1]
-            if level >= max_levels + 1:
+            if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
                 break
@@ -195,14 +195,13 @@ class AR_NAR(Base):
                 {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
             ] * batch_size if sampling_mirostat_tau > 0.0 else None

-            sampling_beam_width_use_logs = True
             scores = [ 1.0 ] * sampling_beam_width

             if self.interleave:
                 max_steps *= self.n_prom_levels

             # get next in sequence
-            for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
+            for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
                 # experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
                 # UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
                 if max_resp_context > 0:
@@ -245,17 +244,13 @@ class AR_NAR(Base):
                     r, s = r
                     # first step, expand batch
                     if batch_size == 1:
-                        batch_size *= sampling_beam_width
+                        batch_size = sampling_beam_width
                         text_list = text_list * sampling_beam_width
                         proms_list = proms_list * sampling_beam_width
                         sequence_list = sequence_list * sampling_beam_width
                         stopped = torch.zeros(batch_size, device=device).bool()

                     # update scores
-                    if sampling_beam_width_use_logs:
-                        scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
-                    else:
-                        scores = [ scores[i] * score for i, score in enumerate(s) ]
+                    scores = [ scores[i] + score for i, score in enumerate(s) ]

                     # append tokens
                     for i, ri in enumerate(r):
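The simplified score update assumes the per-candidate values in s are already log-probabilities, so accumulating them is a sum rather than a product. A small self-contained illustration of why log-space accumulation is the safer choice (the numbers are made up):

import math

probs = [0.05] * 400                    # 400 steps of per-step probability 0.05
print(math.prod(probs))                 # 0.0 -- the raw product underflows float64
print(sum(math.log(p) for p in probs))  # ~-1198.3 -- still comparable across beams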
@@ -270,13 +265,8 @@ class AR_NAR(Base):

             # pick the best scoring candidate
             # desu this is always going to be candidate 0
-            if sampling_beam_width and len(scores) > 0:
-                best_idx, best_score = (0, 0)
-                for idx, score in enumerate(scores):
-                    if best_score > score:
-                        best_idx, best_score = idx, score
-
-                sequence_list = [sequence_list[best_idx]]
+            if sampling_beam_width:
+                sequence_list = [ sequence_list[0] ]

             return [self._prune(r) for r in sequence_list]

@@ -101,7 +101,7 @@ def _non_blocking_input():

 def _make_infinite_epochs(dl):
     while True:
-        _logger.info("New epoch starts.")
+        #_logger.info("New epoch starts.")
         yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)

@@ -158,10 +158,9 @@ def train(
         #batch = to_device(batch, torch.cuda.current_device())
         stats = engines.step(batch=batch, feeder=train_feeder)

-        stats['it'] = stats['global_step']
         stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)

         """
         stats['batch'] = {
             'size': len(batch['text']),
             'id': batch['spkr_id'],
@@ -170,8 +169,8 @@ def train(
             'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
             'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
         }
         """

-
+        del stats['global_step']
         elapsed_time = stats.get("elapsed_time", 0)
         _logger.info(f"Training Metrics: {json.dumps(stats)}.")