added per-speaker samplers
This commit is contained in:
parent
81b05dabb9
commit
8a6c203277
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -7,4 +7,6 @@ __pycache__
|
||||||
/*.egg-info
|
/*.egg-info
|
||||||
/vall_e/version.py
|
/vall_e/version.py
|
||||||
/build
|
/build
|
||||||
/.cache
|
/.cache
|
||||||
|
|
||||||
|
/vall_e/ext/interleaver.py
|
|
@ -162,6 +162,7 @@ class Model:
|
||||||
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
|
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
|
||||||
arch_type: str = "transformer"
|
arch_type: str = "transformer"
|
||||||
training: bool = True
|
training: bool = True
|
||||||
|
interleave_pattern: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def full_name(self):
|
def full_name(self):
|
||||||
|
|
|
@ -12,6 +12,7 @@ import itertools
|
||||||
|
|
||||||
from .config import cfg
|
from .config import cfg
|
||||||
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
|
||||||
|
from .utils.sampler import Sampler
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import cache, cached_property
|
from functools import cache, cached_property
|
||||||
|
@ -173,6 +174,8 @@ class Dataset(_Dataset):
|
||||||
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
||||||
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
||||||
|
|
||||||
|
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
|
||||||
|
|
||||||
if cfg.dataset.sample_type == "path":
|
if cfg.dataset.sample_type == "path":
|
||||||
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||||
|
|
||||||
|
@ -215,6 +218,22 @@ class Dataset(_Dataset):
|
||||||
def tasks(self):
|
def tasks(self):
|
||||||
return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
|
return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
|
||||||
|
|
||||||
|
def save_state_dict(self, path):
|
||||||
|
state_dict = {
|
||||||
|
"samplers": { name: sampler.current_pool for name, sampler in self.samplers.items() }
|
||||||
|
}
|
||||||
|
torch.save(state_dict, path)
|
||||||
|
|
||||||
|
def load_state_dict(self, path):
|
||||||
|
state_dict = torch.load(path)
|
||||||
|
|
||||||
|
if "samplers" in state_dict:
|
||||||
|
# better than naively setting the entire object
|
||||||
|
for name, sampler in state_dict["samplers"].items():
|
||||||
|
if name not in self.samplers:
|
||||||
|
continue
|
||||||
|
self.samplers[name].current_pool = sampler
|
||||||
|
|
||||||
def _get_phone_symmap(self):
|
def _get_phone_symmap(self):
|
||||||
return get_phone_symmap()
|
return get_phone_symmap()
|
||||||
|
|
||||||
|
@ -290,7 +309,7 @@ class Dataset(_Dataset):
|
||||||
if cfg.dataset.sample_type == "speaker":
|
if cfg.dataset.sample_type == "speaker":
|
||||||
spkr_name = self.spkrs[index]
|
spkr_name = self.spkrs[index]
|
||||||
spkr_id = self.spkr_symmap[spkr_name]
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
|
path = self.samplers[spkr_name].sample()
|
||||||
else:
|
else:
|
||||||
path = self.paths[index]
|
path = self.paths[index]
|
||||||
spkr_name = self.get_speaker(path)
|
spkr_name = self.get_speaker(path)
|
||||||
|
@ -543,6 +562,10 @@ def create_datasets():
|
||||||
train_dataset = Dataset( training=True )
|
train_dataset = Dataset( training=True )
|
||||||
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
|
||||||
|
|
||||||
|
train_state_path = cfg.relpath / "train_dataset.pt"
|
||||||
|
if train_state_path.exists():
|
||||||
|
train_dataset.load_state_dict( train_state_path )
|
||||||
|
|
||||||
return train_dataset, val_dataset
|
return train_dataset, val_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@ -752,6 +775,8 @@ if __name__ == "__main__":
|
||||||
del v[i]['resps']
|
del v[i]['resps']
|
||||||
print(f'{k}:', v)
|
print(f'{k}:', v)
|
||||||
|
|
||||||
|
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
|
||||||
|
|
||||||
elif args.action == "tasks":
|
elif args.action == "tasks":
|
||||||
index = 0
|
index = 0
|
||||||
cfg.dataset.tasks_list = args.tasks.split(",")
|
cfg.dataset.tasks_list = args.tasks.split(",")
|
||||||
|
|
|
@ -15,6 +15,7 @@ def get_model(cfg):
|
||||||
d_model=cfg.dim,
|
d_model=cfg.dim,
|
||||||
n_heads=cfg.heads,
|
n_heads=cfg.heads,
|
||||||
n_layers=cfg.layers,
|
n_layers=cfg.layers,
|
||||||
|
config = cfg
|
||||||
)
|
)
|
||||||
model._cfg = cfg
|
model._cfg = cfg
|
||||||
|
|
||||||
|
|
|
@ -22,8 +22,8 @@ class AR(Base):
|
||||||
return "ln"
|
return "ln"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def arch_type(self) -> bool:
|
def arch_type(self) -> str:
|
||||||
if hasattr(self, "_cfg"):
|
if hasattr(self, "_cfg") and self._cfg:
|
||||||
return self._cfg.arch_type
|
return self._cfg.arch_type
|
||||||
return cfg.models.ar.arch_type
|
return cfg.models.ar.arch_type
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ class AR(Base):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_resp_levels(self) -> int:
|
def n_resp_levels(self) -> int:
|
||||||
if hasattr(self, "_cfg"):
|
if hasattr(self, "_cfg") and self._cfg:
|
||||||
return self._cfg.resp_levels
|
return self._cfg.resp_levels
|
||||||
return cfg.models.ar.resp_levels
|
return cfg.models.ar.resp_levels
|
||||||
|
|
||||||
|
@ -146,8 +146,8 @@ def example_usage():
|
||||||
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
||||||
]
|
]
|
||||||
proms_list = [
|
proms_list = [
|
||||||
x8(torch.tensor([1, 2, 3], device=device)),
|
#x8(torch.tensor([1, 2, 3], device=device)),
|
||||||
#qnt.to(device),
|
qnt.to(device),
|
||||||
]
|
]
|
||||||
resps_list = [
|
resps_list = [
|
||||||
qnt.to(device),
|
qnt.to(device),
|
||||||
|
@ -161,7 +161,7 @@ def example_usage():
|
||||||
'n_tokens': 1024,
|
'n_tokens': 1024,
|
||||||
'd_model': 1024,
|
'd_model': 1024,
|
||||||
'n_heads': 16,
|
'n_heads': 16,
|
||||||
'n_layers': 12,
|
'n_layers': 24,
|
||||||
}
|
}
|
||||||
model = AR(**kwargs).to(device)
|
model = AR(**kwargs).to(device)
|
||||||
engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
|
engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
|
||||||
|
|
|
@ -16,8 +16,8 @@ class NAR(Base):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def arch_type(self) -> bool:
|
def arch_type(self) -> str:
|
||||||
if hasattr(self, "_cfg"):
|
if hasattr(self, "_cfg") and self._cfg:
|
||||||
return self._cfg.arch_type
|
return self._cfg.arch_type
|
||||||
return cfg.models.nar.arch_type
|
return cfg.models.nar.arch_type
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class NAR(Base):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_resp_levels(self) -> int:
|
def n_resp_levels(self) -> int:
|
||||||
if hasattr(self, "_cfg"):
|
if hasattr(self, "_cfg") and self._cfg:
|
||||||
return self._cfg.resp_levels
|
return self._cfg.resp_levels
|
||||||
return cfg.models.nar.resp_levels
|
return cfg.models.nar.resp_levels
|
||||||
|
|
||||||
|
|
|
@ -1,2 +1,29 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
import random
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Sampler():
|
class Sampler():
|
||||||
...
|
def __init__( self, pool = [], keep_all = False ):
|
||||||
|
self.global_pool = pool if keep_all else None
|
||||||
|
self.global_indices = [ i for i in range(len(pool)) ]
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.current_pool = [ i for i in self.global_indices ]
|
||||||
|
|
||||||
|
def sample(self, pool = None):
|
||||||
|
if pool is None:
|
||||||
|
pool = self.global_pool
|
||||||
|
# check if we need to reset
|
||||||
|
index = random.choice( self.current_pool )
|
||||||
|
# remove from pool
|
||||||
|
self.current_pool.remove(index)
|
||||||
|
# reset if needed
|
||||||
|
if len(self.current_pool) == 0:
|
||||||
|
self.reset()
|
||||||
|
# map indices to our real values
|
||||||
|
return pool[index] if pool is not None else index
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.sample(*args, **kwargs)
|
|
@ -311,6 +311,7 @@ def train(
|
||||||
print("Failed to set LR rate to:", rate, str(e))
|
print("Failed to set LR rate to:", rate, str(e))
|
||||||
|
|
||||||
if "export" in command:
|
if "export" in command:
|
||||||
|
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
|
||||||
engines.save_checkpoint()
|
engines.save_checkpoint()
|
||||||
last_save_step = engines.global_step
|
last_save_step = engines.global_step
|
||||||
|
|
||||||
|
@ -333,6 +334,7 @@ def train(
|
||||||
|
|
||||||
if engines.global_step != last_save_step:
|
if engines.global_step != last_save_step:
|
||||||
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
|
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
|
||||||
|
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
|
||||||
engines.save_checkpoint()
|
engines.save_checkpoint()
|
||||||
last_save_step = engines.global_step
|
last_save_step = engines.global_step
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user