removed the sampler as it's very misleading

This commit is contained in:
parent 8e7f900210
commit ced31fd9b7
@@ -118,6 +118,7 @@ class Dataset:
 	hdf5_name: str = "data.h5"
 	use_hdf5: bool = False
+	hdf5_flag: str = "a"
 	validate: bool = True
 	workers: int = 8
 	cache: bool = True
@@ -467,7 +468,7 @@ try:
 	# cached_property stopped working...
 	if cfg.dataset.use_hdf5:
 		try:
-			cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a') # to-do, have an easy to set flag that determines if training or creating the dataset
+			cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
		except Exception as e:
			print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
			cfg.dataset.use_hdf5 = False
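For reference, h5py's standard open modes are 'r' (read-only, file must exist), 'r+' (read/write, file must exist), 'w' (create/truncate), and 'a' (read/write, create if missing). A minimal sketch of what the new hdf5_flag field selects; the path below is illustrative only, not the repository's actual layout:

import h5py

hdf5_flag = "r"      # read-only while training against an already-built data.h5
# hdf5_flag = "a"    # read/write (create if missing) while building the dataset

hdf5 = h5py.File("./data.h5", hdf5_flag)
print(list(hdf5.keys()))  # inspect top-level groups
hdf5.close()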
@@ -10,7 +10,6 @@ import random
 import torch
 
 from .config import cfg
-from .utils.sampler import Sampler
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -168,12 +167,6 @@ class Dataset(_Dataset):
 				self.durations[spkr_id] = duration
 			else:
 				self.durations[spkr_id] += duration
-
-		if training and not cfg.distributed and self.sample_type == "path":
-			self.sampler = Sampler(self.paths, [cfg.get_spkr])
-		else:
-			self.sampler = None
-
	def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
		ret = defaultdict(list)
		for path in self.paths:
@@ -267,10 +260,7 @@ class Dataset(_Dataset):
 			spkr_id = self.spkr_symmap[spkr_name]
 			path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
 		else:
-			if self.training and self.sampler is not None:
-				path = self.sampler.sample()
-			else:
-				path = self.paths[index]
+			path = self.paths[index]
 			spkr_name = cfg.get_spkr(path)
 			spkr_id = self.spkr_symmap[spkr_name]
 
@@ -299,12 +289,18 @@ class Dataset(_Dataset):
 			resps = self.sample_noise()
 			resps = extend_audio(resps, proms.shape[0])
 			# something to prepend a sr token to the beginning of proms
-		elif task == "tse:
+		elif task == "tse":
 			proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
 			other_speaker = self.sample_speaker(ignore=[spkr_name])
 			other_proms = self.sample_prompts(other_speaker, ignore="")
 			proms = merge_audio(proms, other_proms)
 			# something to prepend a ns token to the beginning of proms
 			# something to prepend a tse token to the beginning of proms
 		"""
+		"""
+		# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
+		# as I need to get a good clean point to trim into
+		elif task == "cse":
+		elif task == "nse":
+		"""
 
@@ -595,9 +591,16 @@ if __name__ == "__main__":
 		create_dataset_hdf5()
 
	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
-	print("Training DL:", next(iter(train_dl)))
-	print("Training DL:", next(iter(train_dl)))
-	print("Evaluation DL:", next(iter(subtrain_dl)))
-	print("Evaluation DL:", next(iter(subtrain_dl)))
-	print("Validation DL:", next(iter(val_dl)))
-	print("Validation DL:", next(iter(val_dl)))
+
+	samples = {
+		"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
+		"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
+		"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+	}
+
+	for k, v in samples.items():
+		for i in range(len(v)):
+			del v[i]['proms']
+			del v[i]['resps']
+			print(f'{k}:', v)
+
@@ -35,6 +35,16 @@ class TTS():
 			cfg.load_yaml( config )
 
 		cfg.format()
 
+		"""
+		if cfg.trainer.load_state_dict:
+			for model in cfg.models.get():
+				path = cfg.ckpt_dir / model.full_name / "fp32.pth"
+				if model.name.startswith("ar"):
+					ar_ckpt = path
+				if model.name.startswith("nar"):
+					nar_ckpt = path
+		"""
+
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
@@ -44,10 +54,16 @@ class TTS():
			for name, model in models.items():
				if name.startswith("ar"):
					self.ar = model.to(self.device, dtype=torch.float32)
-					self.ar.load_state_dict(torch.load(self.ar_ckpt)['module'])
+					state = torch.load(self.ar_ckpt)
+					if "module" in state:
+						state = state['module']
+					self.ar.load_state_dict(state)
				elif name.startswith("nar"):
					self.nar = model.to(self.device, dtype=torch.float32)
-					self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
+					state = torch.load(self.nar_ckpt)
+					if "module" in state:
+						state = state['module']
+					self.nar.load_state_dict(state)
		else:
			self.load_models()
 
@@ -1,48 +0,0 @@
-"""
-A sampler that balances data by key_fns.
-
-MIT License
-
-Copyright (c) 2023 Zhe Niu
-
-niuzhe.nz@outlook.com
-"""
-
-import random
-
-
-class Sampler:
-	def __init__(self, l, key_fns):
-		self.tree = self._build(l, key_fns)
-
-	def _build(self, l, key_fns) -> dict[dict, list]:
-		if not key_fns:
-			return l
-
-		tree = {}
-
-		key_fn, *key_fns = key_fns
-
-		for x in l:
-			k = key_fn(x)
-
-			if k in tree:
-				tree[k].append(x)
-			else:
-				tree[k] = [x]
-
-		for k in tree:
-			tree[k] = self._build(tree[k], key_fns)
-
-		return tree
-
-	def _sample(self, tree: dict | list):
-		if isinstance(tree, list):
-			ret = random.choice(tree)
-		else:
-			key = random.choice([*tree.keys()])
-			ret = self._sample(tree[key])
-		return ret
-
-	def sample(self):
-		return self._sample(self.tree)
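For context, the class deleted above was constructed in the dataset as Sampler(self.paths, [cfg.get_spkr]), so each call to sample() picked a speaker uniformly at random and then a path uniformly within that speaker, independent of the DataLoader index. A toy reproduction of that behaviour, with a made-up get_spkr stand-in for illustration:

import random

def get_spkr(path: str) -> str:
	# toy stand-in for cfg.get_spkr: treat the leading directory as the speaker
	return path.split("/")[0]

paths = ["spkr_a/1.wav", "spkr_a/2.wav", "spkr_a/3.wav", "spkr_b/1.wav"]

# build the same one-level tree the removed Sampler._build produced
tree = {}
for p in paths:
	tree.setdefault(get_spkr(p), []).append(p)

# removed Sampler._sample: uniform over speakers, then uniform within the speaker
def sample() -> str:
	return random.choice(tree[random.choice(list(tree))])

print([sample() for _ in range(8)])  # spkr_b/1.wav is drawn ~half the time despite being 1 of 4 paths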
@@ -80,7 +80,10 @@ def load_engines():
 
 		if cfg.trainer.load_state_dict:
 			load_path = cfg.ckpt_dir / name / "fp32.pth"
-			model.load_state_dict(torch.load(load_path)['module'])
+			state = torch.load(load_path)
+			if "module" in state:
+				state = state["module"]
+			model.load_state_dict(state)
 
 		engines[name] = Engine(
 			model=model,
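The same "module"-unwrapping pattern now appears in three places (the AR and NAR loads in TTS and the load here); a small helper along these lines could factor it out. This is only a sketch, not part of the commit, and map_location is an added assumption for CPU-only loading:

import torch

def load_state(model, path):
	# checkpoints exported by the trainer may wrap the weights under a "module"
	# key, while a plain torch.save()'d state dict will not; accept either form.
	state = torch.load(path, map_location="cpu")
	if "module" in state:
		state = state["module"]
	model.load_state_dict(state)
	return model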