removed the sampler as it's very misleading
This commit is contained in:
parent
8e7f900210
commit
ced31fd9b7
|
@ -118,6 +118,7 @@ class Dataset:
|
||||||
|
|
||||||
hdf5_name: str = "data.h5"
|
hdf5_name: str = "data.h5"
|
||||||
use_hdf5: bool = False
|
use_hdf5: bool = False
|
||||||
|
hdf5_flag: str = "a"
|
||||||
validate: bool = True
|
validate: bool = True
|
||||||
workers: int = 8
|
workers: int = 8
|
||||||
cache: bool = True
|
cache: bool = True
|
||||||
|
@ -467,7 +468,7 @@ try:
|
||||||
# cached_property stopped working...
|
# cached_property stopped working...
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
try:
|
try:
|
||||||
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a') # to-do, have an easy to set flag that determines if training or creating the dataset
|
cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
|
print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
|
||||||
cfg.dataset.use_hdf5 = False
|
cfg.dataset.use_hdf5 = False
|
||||||
|
|
|
@ -10,7 +10,6 @@ import random
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .config import cfg
|
from .config import cfg
|
||||||
from .utils.sampler import Sampler
|
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import cache, cached_property
|
from functools import cache, cached_property
|
||||||
|
@ -168,12 +167,6 @@ class Dataset(_Dataset):
|
||||||
self.durations[spkr_id] = duration
|
self.durations[spkr_id] = duration
|
||||||
else:
|
else:
|
||||||
self.durations[spkr_id] += duration
|
self.durations[spkr_id] += duration
|
||||||
|
|
||||||
if training and not cfg.distributed and self.sample_type == "path":
|
|
||||||
self.sampler = Sampler(self.paths, [cfg.get_spkr])
|
|
||||||
else:
|
|
||||||
self.sampler = None
|
|
||||||
|
|
||||||
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
|
def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
|
||||||
ret = defaultdict(list)
|
ret = defaultdict(list)
|
||||||
for path in self.paths:
|
for path in self.paths:
|
||||||
|
@ -267,10 +260,7 @@ class Dataset(_Dataset):
|
||||||
spkr_id = self.spkr_symmap[spkr_name]
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
|
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
|
||||||
else:
|
else:
|
||||||
if self.training and self.sampler is not None:
|
path = self.paths[index]
|
||||||
path = self.sampler.sample()
|
|
||||||
else:
|
|
||||||
path = self.paths[index]
|
|
||||||
spkr_name = cfg.get_spkr(path)
|
spkr_name = cfg.get_spkr(path)
|
||||||
spkr_id = self.spkr_symmap[spkr_name]
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
|
|
||||||
|
@ -299,12 +289,18 @@ class Dataset(_Dataset):
|
||||||
resps = self.sample_noise()
|
resps = self.sample_noise()
|
||||||
resps = extend_audio(resps, proms.shape[0])
|
resps = extend_audio(resps, proms.shape[0])
|
||||||
# something to prepend a sr token to the beginning of proms
|
# something to prepend a sr token to the beginning of proms
|
||||||
elif task == "tse:
|
elif task == "tse":
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps
|
||||||
other_speaker = self.sample_speaker(ignore=[spkr_name])
|
other_speaker = self.sample_speaker(ignore=[spkr_name])
|
||||||
other_proms = self.sample_prompts(other_speaker, ignore="")
|
other_proms = self.sample_prompts(other_speaker, ignore="")
|
||||||
proms = merge_audio(proms, other_proms)
|
proms = merge_audio(proms, other_proms)
|
||||||
# something to prepend a ns token to the beginning of proms
|
# something to prepend a tse token to the beginning of proms
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# speech editing would require higher quality transcription data (phoneme level/word level) unfortunately
|
||||||
|
# as I need to get a good clean point to trim into
|
||||||
|
elif task == "cse":
|
||||||
|
elif task == "nse":
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -595,9 +591,16 @@ if __name__ == "__main__":
|
||||||
create_dataset_hdf5()
|
create_dataset_hdf5()
|
||||||
|
|
||||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||||
print("Training DL:", next(iter(train_dl)))
|
|
||||||
print("Training DL:", next(iter(train_dl)))
|
samples = {
|
||||||
print("Evaluation DL:", next(iter(subtrain_dl)))
|
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
||||||
print("Evaluation DL:", next(iter(subtrain_dl)))
|
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||||
print("Validation DL:", next(iter(val_dl)))
|
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||||
print("Validation DL:", next(iter(val_dl)))
|
}
|
||||||
|
|
||||||
|
for k, v in samples.items():
|
||||||
|
for i in range(len(v)):
|
||||||
|
del v[i]['proms']
|
||||||
|
del v[i]['resps']
|
||||||
|
print(f'{k}:', v)
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,16 @@ class TTS():
|
||||||
cfg.load_yaml( config )
|
cfg.load_yaml( config )
|
||||||
|
|
||||||
cfg.format()
|
cfg.format()
|
||||||
|
|
||||||
|
"""
|
||||||
|
if cfg.trainer.load_state_dict:
|
||||||
|
for model in cfg.models.get():
|
||||||
|
path = cfg.ckpt_dir / model.full_name / "fp32.pth"
|
||||||
|
if model.name.startswith("ar"):
|
||||||
|
ar_ckpt = path
|
||||||
|
if model.name.startswith("nar"):
|
||||||
|
nar_ckpt = path
|
||||||
|
"""
|
||||||
|
|
||||||
if ar_ckpt and nar_ckpt:
|
if ar_ckpt and nar_ckpt:
|
||||||
self.ar_ckpt = ar_ckpt
|
self.ar_ckpt = ar_ckpt
|
||||||
|
@ -44,10 +54,16 @@ class TTS():
|
||||||
for name, model in models.items():
|
for name, model in models.items():
|
||||||
if name.startswith("ar"):
|
if name.startswith("ar"):
|
||||||
self.ar = model.to(self.device, dtype=torch.float32)
|
self.ar = model.to(self.device, dtype=torch.float32)
|
||||||
self.ar.load_state_dict(torch.load(self.ar_ckpt)['module'])
|
state = torch.load(self.ar_ckpt)
|
||||||
|
if "module" in state:
|
||||||
|
state = state['module']
|
||||||
|
self.ar.load_state_dict(state)
|
||||||
elif name.startswith("nar"):
|
elif name.startswith("nar"):
|
||||||
self.nar = model.to(self.device, dtype=torch.float32)
|
self.nar = model.to(self.device, dtype=torch.float32)
|
||||||
self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
|
state = torch.load(self.nar_ckpt)
|
||||||
|
if "module" in state:
|
||||||
|
state = state['module']
|
||||||
|
self.nar.load_state_dict(state)
|
||||||
else:
|
else:
|
||||||
self.load_models()
|
self.load_models()
|
||||||
|
|
||||||
|
|
|
@ -1,48 +0,0 @@
|
||||||
"""
|
|
||||||
A sampler that balances data by key_fns.
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2023 Zhe Niu
|
|
||||||
|
|
||||||
niuzhe.nz@outlook.com
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
|
|
||||||
class Sampler:
|
|
||||||
def __init__(self, l, key_fns):
|
|
||||||
self.tree = self._build(l, key_fns)
|
|
||||||
|
|
||||||
def _build(self, l, key_fns) -> dict[dict, list]:
|
|
||||||
if not key_fns:
|
|
||||||
return l
|
|
||||||
|
|
||||||
tree = {}
|
|
||||||
|
|
||||||
key_fn, *key_fns = key_fns
|
|
||||||
|
|
||||||
for x in l:
|
|
||||||
k = key_fn(x)
|
|
||||||
|
|
||||||
if k in tree:
|
|
||||||
tree[k].append(x)
|
|
||||||
else:
|
|
||||||
tree[k] = [x]
|
|
||||||
|
|
||||||
for k in tree:
|
|
||||||
tree[k] = self._build(tree[k], key_fns)
|
|
||||||
|
|
||||||
return tree
|
|
||||||
|
|
||||||
def _sample(self, tree: dict | list):
|
|
||||||
if isinstance(tree, list):
|
|
||||||
ret = random.choice(tree)
|
|
||||||
else:
|
|
||||||
key = random.choice([*tree.keys()])
|
|
||||||
ret = self._sample(tree[key])
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def sample(self):
|
|
||||||
return self._sample(self.tree)
|
|
|
@ -80,7 +80,10 @@ def load_engines():
|
||||||
|
|
||||||
if cfg.trainer.load_state_dict:
|
if cfg.trainer.load_state_dict:
|
||||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||||
model.load_state_dict(torch.load(load_path)['module'])
|
state = torch.load(load_path)
|
||||||
|
if "module" in state:
|
||||||
|
state = state["module"]
|
||||||
|
model.load_state_dict(state)
|
||||||
|
|
||||||
engines[name] = Engine(
|
engines[name] = Engine(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user