removed the sampler as it's very misleading

mrq 2023-08-18 14:47:48 -05:00
parent 8e7f900210
commit ced31fd9b7
5 changed files with 46 additions and 71 deletions

View File

@@ -118,6 +118,7 @@ class Dataset:
     hdf5_name: str = "data.h5"
     use_hdf5: bool = False
+    hdf5_flag: str = "a"
     validate: bool = True
     workers: int = 8
     cache: bool = True
@@ -467,7 +468,7 @@ try:
     # cached_property stopped working...
     if cfg.dataset.use_hdf5:
         try:
-            cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a') # to-do, have an easy to set flag that determines if training or creating the dataset
+            cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
         except Exception as e:
             print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
             cfg.dataset.use_hdf5 = False
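
The open mode for the dataset HDF5 file is now taken from the new cfg.dataset.hdf5_flag field instead of being inferred from cfg.distributed. A minimal standalone sketch of the two intended modes, assuming nothing beyond h5py itself ("sketch.h5" is a made-up file, not a path from the repo):

    import h5py

    # "a" (the new default) opens read/write and creates the file if missing -- dataset creation
    with h5py.File("sketch.h5", "a") as f:
        f.require_group("metadata")

    # "r" opens the same file read-only -- the safer value once training starts
    with h5py.File("sketch.h5", "r") as f:
        print(list(f.keys()))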

View File

@@ -10,7 +10,6 @@ import random
 import torch
 from .config import cfg
-from .utils.sampler import Sampler
 from collections import defaultdict
 from functools import cache, cached_property
@@ -168,12 +167,6 @@ class Dataset(_Dataset):
                 self.durations[spkr_id] = duration
             else:
                 self.durations[spkr_id] += duration
-        if training and not cfg.distributed and self.sample_type == "path":
-            self.sampler = Sampler(self.paths, [cfg.get_spkr])
-        else:
-            self.sampler = None
     def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
         ret = defaultdict(list)
         for path in self.paths:
@@ -267,10 +260,7 @@ class Dataset(_Dataset):
             spkr_id = self.spkr_symmap[spkr_name]
             path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
         else:
-            if self.training and self.sampler is not None:
-                path = self.sampler.sample()
-            else:
-                path = self.paths[index]
+            path = self.paths[index]
             spkr_name = cfg.get_spkr(path)
             spkr_id = self.spkr_symmap[spkr_name]
@@ -299,12 +289,18 @@ class Dataset(_Dataset):
             resps = self.sample_noise()
             resps = extend_audio(resps, proms.shape[0])
             # something to prepend a sr token to the beginning of proms
-        elif task == "tse:
+        elif task == "tse":
             proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
             other_speaker = self.sample_speaker(ignore=[spkr_name])
             other_proms = self.sample_prompts(other_speaker, ignore="")
             proms = merge_audio(proms, other_proms)
-            # something to prepend a ns token to the beginning of proms
+            # something to prepend a tse token to the beginning of proms
+        """
+        """
+        # speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
+        # as I need to get a good clean point to trim into
+        elif task == "cse":
+        elif task == "nse":
         """
@@ -595,9 +591,16 @@ if __name__ == "__main__":
         create_dataset_hdf5()
     train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
-    print("Training DL:", next(iter(train_dl)))
-    print("Training DL:", next(iter(train_dl)))
-    print("Evaluation DL:", next(iter(subtrain_dl)))
-    print("Evaluation DL:", next(iter(subtrain_dl)))
-    print("Validation DL:", next(iter(val_dl)))
-    print("Validation DL:", next(iter(val_dl)))
+    samples = {
+        "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
+        "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
+        "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+    }
+    for k, v in samples.items():
+        for i in range(len(v)):
+            del v[i]['proms']
+            del v[i]['resps']
+        print(f'{k}:', v)

View File

@@ -35,6 +35,16 @@ class TTS():
             cfg.load_yaml( config )
             cfg.format()
+        """
+        if cfg.trainer.load_state_dict:
+            for model in cfg.models.get():
+                path = cfg.ckpt_dir / model.full_name / "fp32.pth"
+                if model.name.startswith("ar"):
+                    ar_ckpt = path
+                if model.name.startswith("nar"):
+                    nar_ckpt = path
+        """
         if ar_ckpt and nar_ckpt:
             self.ar_ckpt = ar_ckpt
@@ -44,10 +54,16 @@ class TTS():
             for name, model in models.items():
                 if name.startswith("ar"):
                     self.ar = model.to(self.device, dtype=torch.float32)
-                    self.ar.load_state_dict(torch.load(self.ar_ckpt)['module'])
+                    state = torch.load(self.ar_ckpt)
+                    if "module" in state:
+                        state = state['module']
+                    self.ar.load_state_dict(state)
                 elif name.startswith("nar"):
                     self.nar = model.to(self.device, dtype=torch.float32)
-                    self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
+                    state = torch.load(self.nar_ckpt)
+                    if "module" in state:
+                        state = state['module']
+                    self.nar.load_state_dict(state)
         else:
             self.load_models()
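
Both checkpoint paths above (and load_engines further down) now tolerate two layouts: a plain state dict, or one nested under a "module" key as some trainer exports produce. A standalone sketch of that pattern; the tiny Linear model and file names are made up for illustration:

    import torch

    def load_flexible_state_dict(model: torch.nn.Module, path: str):
        # same pattern as the diff: unwrap {"module": ...} checkpoints, use plain ones as-is
        state = torch.load(path)
        if "module" in state:
            state = state["module"]
        model.load_state_dict(state)

    model = torch.nn.Linear(4, 4)
    torch.save(model.state_dict(), "plain.pth")
    torch.save({"module": model.state_dict()}, "wrapped.pth")
    load_flexible_state_dict(model, "plain.pth")    # raw state dict
    load_flexible_state_dict(model, "wrapped.pth")  # wrapped export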

View File

@@ -1,48 +0,0 @@
-"""
-A sampler that balances data by key_fns.
-
-MIT License
-
-Copyright (c) 2023 Zhe Niu
-
-niuzhe.nz@outlook.com
-"""
-
-import random
-
-
-class Sampler:
-    def __init__(self, l, key_fns):
-        self.tree = self._build(l, key_fns)
-
-    def _build(self, l, key_fns) -> dict[dict, list]:
-        if not key_fns:
-            return l
-
-        tree = {}
-        key_fn, *key_fns = key_fns
-
-        for x in l:
-            k = key_fn(x)
-            if k in tree:
-                tree[k].append(x)
-            else:
-                tree[k] = [x]
-
-        for k in tree:
-            tree[k] = self._build(tree[k], key_fns)
-
-        return tree
-
-    def _sample(self, tree: dict | list):
-        if isinstance(tree, list):
-            ret = random.choice(tree)
-        else:
-            key = random.choice([*tree.keys()])
-            ret = self._sample(tree[key])
-
-        return ret
-
-    def sample(self):
-        return self._sample(self.tree)
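
For context, the dataset code above constructed this as Sampler(self.paths, [cfg.get_spkr]), so a draw first picks a key (the speaker) uniformly and only then an utterance within it. A standalone sketch of that behaviour with made-up paths (not code from the repo):

    import random
    from collections import Counter, defaultdict

    paths = [f"speaker_a/{i}.qnt.pt" for i in range(99)] + ["speaker_b/0.qnt.pt"]  # hypothetical data
    by_spkr = defaultdict(list)
    for p in paths:
        by_spkr[p.split("/")[0]].append(p)

    def balanced_sample():
        # pick a speaker uniformly, then an utterance within that speaker
        spkr = random.choice([*by_spkr.keys()])
        return random.choice(by_spkr[spkr])

    # speaker_b holds 1% of the paths but is drawn roughly half the time
    print(Counter(balanced_sample().split("/")[0] for _ in range(10_000)))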

View File

@@ -80,7 +80,10 @@ def load_engines():
         if cfg.trainer.load_state_dict:
             load_path = cfg.ckpt_dir / name / "fp32.pth"
-            model.load_state_dict(torch.load(load_path)['module'])
+            state = torch.load(load_path)
+            if "module" in state:
+                state = state["module"]
+            model.load_state_dict(state)
         engines[name] = Engine(
             model=model,