sampler update (some brainworm meant there never actually was a sampler for sampler_type=path)
parent b3b67f34ac
commit 31f71fa134
@@ -12,7 +12,7 @@ import itertools
 from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
-from .utils.sampler import Sampler
+from .utils.sampler import PoolSampler, OrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
 
 from collections import defaultdict
 
@@ -424,6 +424,7 @@ class Dataset(_Dataset):
     ):
         super().__init__()
         self._head = None
+        self.shuffle = False
         self.sampler = None
 
         self.paths = []
@@ -503,7 +504,6 @@ class Dataset(_Dataset):
             # just interleave
             self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
-        self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
 
         # dict of speakers keyed by speaker group
         self.spkrs_by_spkr_group = {}
@@ -521,8 +521,6 @@ class Dataset(_Dataset):
 
         self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
 
-        self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
-
         self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
         self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
 
@@ -539,6 +537,20 @@ class Dataset(_Dataset):
         if len(self.paths) == 0:
             raise ValueError(f"No valid path is found for {self.dataset_type}")
 
+
+        sampler_path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+
+        if self.sampler_type == "path":
+            self.sampler = OrderedSampler( len(self) )
+            self.samplers = {}
+            self.spkr_samplers = {}
+        else:
+            self.sampler = RandomSampler( len(self) )
+            self.samplers = { name: PoolSampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
+            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
+
+        self.load_state_dict()
+
     def get_speaker(self, path):
         if isinstance(path, str):
             path = Path(path)
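For context: this is the branch that finally gives sampler_type == "path" a real sampler. Path-level sampling walks the path list in order (the list was already shuffled, sorted by duration, and interleaved when it was built), while the other modes draw from a random permutation plus per-speaker pools. A minimal sketch of the two behaviors, assuming the repo's package imports as vall_e and using a made-up dataset size:

    from vall_e.utils.sampler import OrderedSampler, RandomSampler

    num_paths = 1024                     # hypothetical dataset length

    ordered = OrderedSampler(num_paths)  # sampler_type == "path": fixed order 0..N-1
    shuffled = RandomSampler(num_paths)  # otherwise: one random permutation per epoch

    assert list(ordered)[:3] == [0, 1, 2]
    print(list(shuffled)[:3])            # three indices from this epoch's permutation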
@@ -568,21 +580,39 @@ class Dataset(_Dataset):
     def tasks(self):
         return cfg.dataset.tasks_list # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"] # , "cse", "nse"
 
-    def save_state_dict(self, path):
-        state_dict = {
-            "samplers": { name: sampler.current_pool for name, sampler in self.samplers.items() }
-        }
+    def save_state_dict(self, path = None):
+        if path is None:
+            path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+
+        if self.sampler_type == "path":
+            state_dict = self.sampler.get_state()
+        else:
+            state_dict = {
+                "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
+                "spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
+            }
         torch.save(state_dict, path)
 
-    def load_state_dict(self, path):
-        state_dict = torch.load(path)
-
-        if "samplers" in state_dict:
-            # better than naively setting the entire object
+    def load_state_dict(self, path = None):
+        if path is None:
+            path = cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt"
+
+        if not path.exists():
+            return
+
+        state_dict = torch.load(path)
+
+        if self.sampler_type == "path":
+            self.sampler.set_state(state_dict)
+        else:
             for name, sampler in state_dict["samplers"].items():
                 if name not in self.samplers:
                     continue
-                self.samplers[name].current_pool = sampler
+                self.samplers[name].set_state( sampler )
+
+            for name, sampler in state_dict["spkr_samplers"].items():
+                if name not in self.spkr_samplers:
+                    continue
+                self.spkr_samplers[name].set_state( sampler )
 
     def _get_phone_symmap(self):
         return get_phone_symmap()
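The rewrite above swaps pickling each sampler's raw current_pool for explicit get_state()/set_state() dicts, and both methods now default their path to cfg.rel_path / f"sampler.{self.sampler_type}.rank{global_rank()}.pt". A rough round-trip sketch for the path case, under the same import assumption and with a throwaway file path:

    import torch
    from vall_e.utils.sampler import OrderedSampler

    sampler = OrderedSampler(10)
    state = sampler.get_state()          # {"position": 0, "length": 10}
    state["position"] = 4                # pretend training stopped mid-epoch
    torch.save(state, "/tmp/sampler.path.rank0.pt")

    restored = OrderedSampler(10)
    restored.set_state(torch.load("/tmp/sampler.path.rank0.pt"))
    assert next(iter(restored)) == 4     # resumes at index 4, not 0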
@@ -965,36 +995,29 @@ def _seed_worker(worker_id):
 
 
 def _create_dataloader(dataset, training):
-    sampler = None
-    shuffle = True
-
     """
    if cfg.distributed and training:
        sampler = DistributedSampler(dataset)
        shuffle = False
    """
 
     return DataLoader(
         dataset=dataset,
         batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
-        shuffle=shuffle,
+        shuffle=dataset.shuffle,
         drop_last=training,
         num_workers=cfg.dataset.workers,
         collate_fn=collate_fn,
         persistent_workers=cfg.dataset.workers > 1,
         pin_memory=False, # True,
         worker_init_fn=_seed_worker,
-        sampler=sampler,
+        sampler=dataset.sampler,
     )
 
 def create_datasets():
     train_dataset = Dataset( training=True )
     val_dataset = Dataset( phone_symmap=train_dataset.phone_symmap, training=False )
 
-    train_state_path = cfg.rel_path / f"sampler.rank{global_rank()}.pt"
-    if train_state_path.exists():
-        train_dataset.load_state_dict( train_state_path )
-
     return train_dataset, val_dataset
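PyTorch's DataLoader treats a custom sampler and shuffle=True as mutually exclusive and raises a ValueError if it gets both, which is why the dataset now carries shuffle = False alongside its sampler and the loader reads both attributes off the dataset. A self-contained demonstration (the toy dataset is illustrative):

    from torch.utils.data import DataLoader, Dataset, SequentialSampler

    class ToyDataset(Dataset):
        def __len__(self): return 8
        def __getitem__(self, i): return i

    ds = ToyDataset()
    DataLoader(ds, sampler=SequentialSampler(ds), shuffle=False)  # fine

    try:
        DataLoader(ds, sampler=SequentialSampler(ds), shuffle=True)
    except ValueError as e:
        print(e)  # "sampler option is mutually exclusive with shuffle"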
@@ -1312,8 +1335,6 @@ if __name__ == "__main__":
             for i in range(len(v)):
                 print(f'{k}[{i}]:', v[i])
 
-    #train_dl.dataset.save_state_dict(cfg.rel_path / "train_dataset.pt")
-
     elif args.action == "tasks":
         index = 0
         cfg.dataset.tasks_list = args.tasks.split(",")
@@ -2,11 +2,15 @@ from dataclasses import dataclass
 from typing import Any
 import random
 
-@dataclass
-class Sampler():
+import torch
+from torch.utils.data import Sampler
+
+# Randomly picks an index from an array of indices
+class PoolSampler():
     def __init__( self, pool = [], keep_all = False ):
+        self.length = len(pool)
         self.global_pool = pool if keep_all else None
-        self.global_indices = [ i for i in range(len(pool)) ]
+        self.global_indices = [ i for i in range(self.length) ]
         self.reset()
 
     def reset(self):
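PoolSampler is the old Sampler under a clearer name: it draws indices without replacement from current_pool, and with keep_all=True it maps each drawn index back to the original value; the new __iter__ (added below) yields exactly one full pass. A small sketch, again assuming the package imports as vall_e:

    from vall_e.utils.sampler import PoolSampler

    speakers = ["alice", "bob", "carol"]   # hypothetical speaker names
    sampler = PoolSampler(speakers, keep_all=True)

    one_pass = list(sampler)               # each speaker exactly once, random order
    assert sorted(one_pass) == sorted(speakers)

    sampler.reset()                        # refill the pool for another pass
    pick = sampler()                       # __call__ forwards to sample()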
@@ -25,5 +29,78 @@ class Sampler():
         # map indices to our real values
         return pool[index] if pool is not None else index
 
+    def __len__(self):
+        return self.length # len(self.current_pool)
+
+    def __iter__(self):
+        while len(self.current_pool) > 0:
+            yield self.sample()
+
     def __call__(self, *args, **kwargs):
         return self.sample(*args, **kwargs)
 
+    def get_state(self):
+        return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
+
+    def set_state(self, state):
+        self.length = state["length"]
+        self.global_pool = state["global_pool"]
+        self.global_indices = state["global_indices"]
+        self.current_pool = state["current_pool"]
+
+# "Samples" through a fixed sequence from 0 to length
+# Necessary for our "shuffle+sort by duration+interleave" sampling method
+# Allows saving and loading state
+class OrderedSampler(Sampler):
+    def __init__( self, length ):
+        self.position = 0
+        self.length = length
+
+    def __len__(self):
+        return self.length
+
+    def __iter__(self):
+        if self.position >= self.length:
+            self.position = 0
+
+        while self.position < self.length:
+            yield self.position
+            self.position += 1
+
+    def get_state(self):
+        return { "position": self.position, "length": self.length }
+
+    def set_state(self, state):
+        self.position = state["position"]
+        self.length = state["length"]
+
+# Randomly samples indices from a given sequence from 0 to length
+# Allows saving and loading state
+class RandomSampler(Sampler):
+    def __init__( self, length ):
+        self.position = 0
+        self.length = length
+
+        self.generator = torch.Generator()
+        self.perm = torch.randperm(self.length, generator=self.generator)
+
+    def __len__(self):
+        return self.length
+
+    def __iter__(self):
+        if self.position >= self.length:
+            self.position = 0
+            self.perm = torch.randperm(self.length, generator=self.generator)
+
+        while self.position < self.length:
+            yield self.perm[self.position]
+            self.position += 1
+
+    def get_state(self):
+        return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
+
+    def set_state(self, state):
+        self.position = state["position"]
+        self.length = state["length"]
+        self.perm = state["perm"]
+        self.generator.set_state(state["generator"])
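The reason RandomSampler checkpoints perm and the torch.Generator state, not just its position, is that a resumed run then walks the same shuffled order instead of drawing a fresh permutation. A sketch of that guarantee, under the same import assumption:

    import torch
    from vall_e.utils.sampler import RandomSampler

    a = RandomSampler(8)
    state = a.get_state()                # position, length, perm, generator state

    b = RandomSampler(8)                 # starts with its own unrelated permutation
    b.set_state(state)

    assert torch.equal(a.perm, b.perm)   # identical permutation after restore
    assert [int(i) for i in a] == [int(i) for i in b]   # identical epoch order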
@@ -218,7 +218,7 @@ def train(
             print("Failed to set LR rate to:", rate, str(e))
 
         if "export" in command:
-            train_dl.dataset.save_state_dict(cfg.rel_path / f"sampler.rank{global_rank()}.pt")
+            train_dl.dataset.save_state_dict()
             engines.save_checkpoint()
             last_save_step = engines.global_step
 
@@ -241,7 +241,7 @@ def train(
 
         if engines.global_step != last_save_step:
             if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
-                train_dl.dataset.save_state_dict(cfg.rel_path / f"sampler.rank{global_rank()}.pt")
+                train_dl.dataset.save_state_dict()
                 engines.save_checkpoint()
                 last_save_step = engines.global_step