removed the sampler as it's very misleading

This commit is contained in:
mrq 2023-08-18 14:47:48 -05:00
parent 8e7f900210
commit ced31fd9b7
5 changed files with 46 additions and 71 deletions

View File

@ -118,6 +118,7 @@ class Dataset:
hdf5_name: str = "data.h5"
use_hdf5: bool = False
hdf5_flag: str = "a"
validate: bool = True
workers: int = 8
cache: bool = True
@ -467,7 +468,7 @@ try:
# cached_property stopped working...
if cfg.dataset.use_hdf5:
try:
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a') # to-do, have an easy to set flag that determines if training or creating the dataset
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
cfg.dataset.use_hdf5 = False

View File

@ -10,7 +10,6 @@ import random
import torch
from .config import cfg
from .utils.sampler import Sampler
from collections import defaultdict
from functools import cache, cached_property
@ -168,12 +167,6 @@ class Dataset(_Dataset):
self.durations[spkr_id] = duration
else:
self.durations[spkr_id] += duration
if training and not cfg.distributed and self.sample_type == "path":
self.sampler = Sampler(self.paths, [cfg.get_spkr])
else:
self.sampler = None
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
ret = defaultdict(list)
for path in self.paths:
@ -267,10 +260,7 @@ class Dataset(_Dataset):
spkr_id = self.spkr_symmap[spkr_name]
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
else:
if self.training and self.sampler is not None:
path = self.sampler.sample()
else:
path = self.paths[index]
path = self.paths[index]
spkr_name = cfg.get_spkr(path)
spkr_id = self.spkr_symmap[spkr_name]
@ -299,12 +289,18 @@ class Dataset(_Dataset):
resps = self.sample_noise()
resps = extend_audio(resps, proms.shape[0])
# something to prepend a sr token to the beginning of proms
elif task == "tse:
elif task == "tse":
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
other_speaker = self.sample_speaker(ignore=[spkr_name])
other_proms = self.sample_prompts(other_speaker, ignore="")
proms = merge_audio(proms, other_proms)
# something to prepend a ns token to the beginning of proms
# something to prepend a tse token to the beginning of proms
"""
"""
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
# as I need to get a good clean point to trim into
elif task == "cse":
elif task == "nse":
"""
@ -595,9 +591,16 @@ if __name__ == "__main__":
create_dataset_hdf5()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
print("Training DL:", next(iter(train_dl)))
print("Training DL:", next(iter(train_dl)))
print("Evaluation DL:", next(iter(subtrain_dl)))
print("Evaluation DL:", next(iter(subtrain_dl)))
print("Validation DL:", next(iter(val_dl)))
print("Validation DL:", next(iter(val_dl)))
samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
for k, v in samples.items():
for i in range(len(v)):
del v[i]['proms']
del v[i]['resps']
print(f'{k}:', v)

View File

@ -35,6 +35,16 @@ class TTS():
cfg.load_yaml( config )
cfg.format()
"""
if cfg.trainer.load_state_dict:
for model in cfg.models.get():
path = cfg.ckpt_dir / model.full_name / "fp32.pth"
if model.name.startswith("ar"):
ar_ckpt = path
if model.name.startswith("nar"):
nar_ckpt = path
"""
if ar_ckpt and nar_ckpt:
self.ar_ckpt = ar_ckpt
@ -44,10 +54,16 @@ class TTS():
for name, model in models.items():
if name.startswith("ar"):
self.ar = model.to(self.device, dtype=torch.float32)
self.ar.load_state_dict(torch.load(self.ar_ckpt)['module'])
state = torch.load(self.ar_ckpt)
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
elif name.startswith("nar"):
self.nar = model.to(self.device, dtype=torch.float32)
self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
state = torch.load(self.nar_ckpt)
if "module" in state:
state = state['module']
self.nar.load_state_dict(state)
else:
self.load_models()

View File

@ -1,48 +0,0 @@
"""
A sampler that balances data by key_fns.
MIT License
Copyright (c) 2023 Zhe Niu
niuzhe.nz@outlook.com
"""
import random
class Sampler:
def __init__(self, l, key_fns):
self.tree = self._build(l, key_fns)
def _build(self, l, key_fns) -> dict[dict, list]:
if not key_fns:
return l
tree = {}
key_fn, *key_fns = key_fns
for x in l:
k = key_fn(x)
if k in tree:
tree[k].append(x)
else:
tree[k] = [x]
for k in tree:
tree[k] = self._build(tree[k], key_fns)
return tree
def _sample(self, tree: dict | list):
if isinstance(tree, list):
ret = random.choice(tree)
else:
key = random.choice([*tree.keys()])
ret = self._sample(tree[key])
return ret
def sample(self):
return self._sample(self.tree)

View File

@ -80,7 +80,10 @@ def load_engines():
if cfg.trainer.load_state_dict:
load_path = cfg.ckpt_dir / name / "fp32.pth"
model.load_state_dict(torch.load(load_path)['module'])
state = torch.load(load_path)
if "module" in state:
state = state["module"]
model.load_state_dict(state)
engines[name] = Engine(
model=model,