diff --git a/data/config.yaml b/data/config.yaml index b24a52f..93e14be 100755 --- a/data/config.yaml +++ b/data/config.yaml @@ -15,12 +15,12 @@ dataset: workers: 8 cache: True - phones_range: [4, 256] - duration_range: [1.0, 12.0] + phones_range: [4, 512] + duration_range: [1.0, 24.0] random_utterance: 1.0 - max_prompts: 3 - prompt_duration: 3.0 + max_prompts: 6 + prompt_duration: 6.0 models: _models: @@ -69,7 +69,8 @@ evaluation: size: 32 steps: 300 - temperature: 1.0 + ar_temperature: 1.0 + nar_temperature: 0.2 trainer: iterations: 100_000 @@ -91,7 +92,13 @@ trainer: weight_dtype: bfloat16 - zero_optimization_level: 2 - use_compression_training: True + backend: deepspeed + deepspeed: + zero_optimization_level: 0 + use_compression_training: True -use_vocos: False +inference: + use_vocos: True + +bitsandbytes: + enabled: false \ No newline at end of file diff --git a/vall_e/config.py b/vall_e/config.py index 0eca960..a27fa8b 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -232,7 +232,121 @@ class Evaluation: size: int = 64 steps: int = 500 - temperature: float = 1.0 + ar_temperature: float = 1.0 + nar_temperature: float = 0.2 + +@dataclass() +class DeepSpeed: + zero_optimization_level: int = 0 + use_compression_training: bool = False + + def get_ds_cfg(self, model): + weights = [ name[0] for name in model.named_parameters() ] + bits = 8 + + scheduler_params = {} + for k in cfg.hyperparameters.scheduler_params: + scheduler_params[k] = cfg.hyperparameters.scheduler_params[k] + + if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params: + scheduler_params['total_num_steps'] = cfg.trainer.iterations + + ds_cfg = { + "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size, + "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps, + "optimizer": { + "type": cfg.hyperparameters.optimizer, + "params": { + "lr": cfg.hyperparameters.learning_rate, + } + }, + "scheduler": { + "type": cfg.hyperparameters.scheduler_type, + "params": scheduler_params, + } if cfg.hyperparameters.scheduler_type != "" else None, + "gradient_clipping": cfg.hyperparameters.gradient_clipping, + "fp16": { + "enabled": True, + "auto_cast": True, + } if cfg.trainer.weight_dtype.lower() == "float16" else None, + "bf16": { + "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" + }, + "compression_training": { + "weight_quantization": { + "shared_parameters":{ + "enabled": True, + "quantizer_kernel": True, + "schedule_offset": 0, + "quantize_groups": 64, + "quantize_verbose": True, + "quantization_type": "symmetric", + "rounding": "nearest", + "quantize_weight_in_forward": True, + "fp16_mixed_quantize":{ + "enabled": False, + "quantize_change_ratio": 1 + } + }, + "different_groups": { + "wq1": { + "params": { + "start_bits": bits, + "target_bits": bits, + "quantization_period": 0 + }, + "modules": weights + } + } + }, + "activation_quantization": { + "shared_parameters":{ + "enabled": True, + "quantization_type": "symmetric", + "range_calibration": "dynamic", + "schedule_offset": 0 + }, + "different_groups": { + "aq1": { + "params": { + "bits": bits + }, + "modules": weights + } + } + } + } if self.use_compression_training else None, + "zero_optimization": { + "stage": self.zero_optimization_level, + "contiguous_gradients": True, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 5e8, + "allgather_bucket_size": 5e8, + "sub_group_size": 5e8, + "round_robin_gradients": True, + "offload_optimizer": { + "device": "cpu", + 
"pin_memory": True + }, + "offload_param": { + "device": "cpu", + "pin_memory": True + } + } if self.zero_optimization_level > 0 else None, + "comms_logger": { + "enabled": False + } + } + + null_keys = [ k for k in ds_cfg if not ds_cfg[k] ] + for k in null_keys: + del ds_cfg[k] + + if os.path.exists("./config/ds_config.json"): + ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8"))) + + return ds_cfg @dataclass() class Trainer: @@ -256,8 +370,9 @@ class Trainer: weight_dtype: str = "float16" - zero_optimization_level: int = 0 - use_compression_training: bool = False + backend: str = "deepspeed" + + deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) @dataclass() @@ -284,7 +399,6 @@ class Config(_Config): inference: Inference = field(default_factory=lambda: Inference) bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes) - @property def sample_rate(self): return 24_000 @@ -293,142 +407,6 @@ class Config(_Config): def get_spkr(self): return eval(self.dataset.speaker_name_getter) - @property - def scheduler(self): - cfg = { - "type": self.hyperparameters.scheduler_type, - "params": {}, - } - - for k in self.hyperparameters.scheduler_params: - cfg['params'][k] = self.hyperparameters.scheduler_params[k] - - if self.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in cfg['params']: - cfg['params']['total_num_steps'] = self.trainer.iterations - return cfg - - @property - def fp16_cfg(self): - if self.trainer.weight_dtype.lower() != "float16": - return None - return { - "enabled": True, - "auto_cast": True, - } - - @property - def bf16_cfg(self): - return { - "enabled": self.trainer.weight_dtype.lower() == "bfloat16" - } - - def get_compression_cfg(self, model): - if not self.trainer.use_compression_training: - return None - - weights = [ name[0] for name in model.named_parameters() ] - bits = 8 - - return { - "weight_quantization": { - "shared_parameters":{ - "enabled": True, - "quantizer_kernel": True, - "schedule_offset": 0, - "quantize_groups": 64, - "quantize_verbose": True, - "quantization_type": "symmetric", - "rounding": "nearest", - "quantize_weight_in_forward": True, - "fp16_mixed_quantize":{ - "enabled": False, - "quantize_change_ratio": 1 - } - }, - "different_groups": { - "wq1": { - "params": { - "start_bits": bits, - "target_bits": bits, - "quantization_period": 0 - }, - "modules": weights - } - } - }, - "activation_quantization": { - "shared_parameters":{ - "enabled": True, - "quantization_type": "symmetric", - "range_calibration": "dynamic", - "schedule_offset": 0 - }, - "different_groups": { - "aq1": { - "params": { - "bits": bits - }, - "modules": weights - } - } - } - } - - @property - def zero_cfg(self): - if self.trainer.zero_optimization_level == 0: - return None - - return { - "stage": self.trainer.zero_optimization_level, - "contiguous_gradients": True, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": 5e8, - "allgather_bucket_size": 5e8, - "sub_group_size": 5e8, - "round_robin_gradients": True, - "offload_optimizer": { - "device": "cpu", - "pin_memory": True - }, - "offload_param": { - "device": "cpu", - "pin_memory": True - } - } - - def get_ds_cfg(self, model): - cfg = { - "train_micro_batch_size_per_gpu": self.hyperparameters.batch_size, - "gradient_accumulation_steps": self.hyperparameters.gradient_accumulation_steps, - "optimizer": { - "type": self.hyperparameters.optimizer, - "params": { - "lr": self.hyperparameters.learning_rate, - } - }, - "scheduler": 
self.hyperparameters.scheduler if self.hyperparameters.scheduler_type != "" else None, - "gradient_clipping": self.hyperparameters.gradient_clipping, - "fp16": self.fp16_cfg, - "bf16": self.bf16_cfg, - "compression_training": self.get_compression_cfg(model), - "zero_optimization": self.zero_cfg, - "comms_logger": { - "enabled": False - } - } - - null_keys = [ k for k in cfg if not cfg[k] ] - for k in null_keys: - del cfg[k] - - if os.path.exists("./config/ds_config.json"): - ds_cfg = json.load(open("./config/ds_config.json", "r", encoding="utf-8")) - cfg.update(ds_cfg) - - return cfg - @property def cache_dir(self): return ".cache" / self.relpath @@ -455,6 +433,8 @@ cfg.trainer = Trainer(**cfg.trainer) cfg.inference = Inference(**cfg.inference) cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes) +cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed) + # cached_property stopped working... if cfg.dataset.use_hdf5: try: diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py new file mode 100644 index 0000000..3dcf775 --- /dev/null +++ b/vall_e/engines/__init__.py @@ -0,0 +1,9 @@ +from ..config import cfg + +if cfg.trainer.backend == "deepspeed": + from .deepspeed import Engine +elif cfg.trainer.backend == "local": + from .base import Engine + +from .base import Engines +from .base import TrainFeeder \ No newline at end of file diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py new file mode 100644 index 0000000..bec3656 --- /dev/null +++ b/vall_e/engines/base.py @@ -0,0 +1,374 @@ +from torch import Tensor +from typing import Any, Protocol + +Stats = dict[str, float] + +class TrainFeeder(Protocol): + def __call__( + self, *, engine: "Engine", batch: Any + ) -> None | tuple[Tensor, Stats]: + ... + +def default_feed(engine, batch): + if isinstance(batch, list): + engine( *batch ) + elif isinstance(batch, dict): + engine( **batch ) + else: + engine( batch ) + + losses = engine.gather_attribute("loss") + loss = torch.stack([*losses.values()]).sum() + + stats = {} + stats |= {k: v.item() for k, v in losses.items()} + + return loss, stats + + +from ..config import cfg +from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device + +import logging +import time +import torch +import torch.distributed +import os + +from torch import Tensor +from torch.distributed import all_reduce +from typing import Any, Protocol + +from .base import TrainFeeder + +_logger = logging.getLogger(__name__) + +# A very naive engine implementation using barebones PyTorch +class Engine(): + def __init__(self, *args, **kwargs): + self.module = kwargs['model'] + self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None + self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None + + self.global_steps = 0 + self.micro_steps = 0 + self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps + + def freeze(self): + for p in self.module.parameters(): + if p.requires_grad: + p.requires_grad_(False) + self._frozen_params.add(p) + + def unfreeze(self): + for p in self._frozen_params: + p.requires_grad_(True) + self._frozen_params.clear() + + @property + def global_step(self): + return self.global_steps + + @property + def micro_step(self): + return self.micro_steps + + def train_batch_size(self): + return cfg.hyperparameters.batch_size + + def gather_attribute(self, *args, **kwargs): + return gather_attribute(self.module, *args, **kwargs) + + def dispatch_attribute(self, *args, **kwargs): + return 
dispatch_attribute(self.module, *args, **kwargs) + + def save_checkpoint(self, save_dir, tag ): + torch.save({ + "global_step": self.global_step, + "micro_step": self.micro_step, + "module": self.module.state_dict(), + "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None, + "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, + }, save_dir / tag / "state.pth") + + def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True): + state = torch.load(load_dir / tag / "state.pth") + self.global_step = state['global_step'] + self.micro_step = state['micro_step'] + self.module.load_state_dict(state['module']) + + load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state + load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state + + if load_optimizer_states: + self.optimizer.load_state_dict(state['optimizer']) + + if load_lr_scheduler_states: + self.lr_scheduler.load_state_dict(state['lr_scheduler']) + + def eval(self): + return self.module.eval() + + def train(self): + return self.module.train() + + def to(self, *args, **kwargs): + self.module = self.module.to(*args, **kwargs) + return self.module + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward(self, *args, **kwargs): + return self.module.forward(*args, **kwargs) + + def backward(self, loss): + return (loss / self.gradient_accumulation_steps).backward() + + def step(self): + with torch.set_grad_enabled(self.gradient_accumulation_steps > 1): + self.micro_steps += 1 + + if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0: + self.global_steps += 1 + self.optimizer.step() + self.optimizer.zero_grad() + + def get_lr(self): + lrs = [] + for param_group in self.optimizer.param_groups: + if 'lr' in param_group: + lrs.append(param_group['lr']) + return lrs + + def set_lr(self, lr): + for param_group in self.optimizer.param_groups: + if 'lr' in param_group: + param_group['lr'] = lr + + def get_global_grad_norm(self): + return 0.0 + + def traverse(self, *args, **kwargs): + self.forward(*args, **kwargs) + losses = self.gather_attribute("loss") + loss = torch.stack([*losses.values()]).sum() + + stats = {} + stats |= {k: v.item() for k, v in losses.items()} + stats |= self.gather_attribute("scalar") + + self.backward(loss) + self.step() + + return stats + +# and now to ignore everything from the above +class Engines(dict[str, Engine]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setup() + + def setup(self): + self._global_step = 0 + self._micro_step = 0 + + @property + def global_step(self): + return self._global_step + + @property + def micro_step(self): + return self._micro_step + + def gather_attribute(self, *args, **kwargs): + ret = {} + for engine in self.values(): + ret |= engine.gather_attribute(*args, **kwargs) + return ret + + def dispatch_attribute(self, *args, **kwargs): + for engine in self.values(): + engine.dispatch_attribute(*args, **kwargs) + + def save_checkpoint(self, tag=None): + if not tag: + tag = cfg.trainer.save_tag + tag = tag.lower() + if tag[:2] == "it" or tag[:4] == "step": + tag = f'{self.global_step}' + + cfg.ckpt_dir.mkdir(parents=True, exist_ok=True) + for name, engine in self.items(): + engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag) + + def load_checkpoint(self, tag=None): + if not tag: + 
tag = cfg.trainer.load_tag + + for name, engine in self.items(): + load_dir = cfg.ckpt_dir / name + engine.load_checkpoint( + tag=tag, + load_dir=load_dir, + load_module_strict=cfg.trainer.strict_loading, + load_optimizer_states=cfg.trainer.load_states, + load_lr_scheduler_states=cfg.trainer.load_states, + ) + if cfg.trainer.restart_step_count: + engine.global_steps = 0 + + # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler + if cfg.hyperparameters.scheduler_type == "": + self.set_lr(cfg.hyperparameters.learning_rate) + + self._update_global_step() + self._update_micro_step() + + def set_lr(self, lr): + for engine in self.values(): + engine.set_lr(lr) + + def _update_global_step(self): + for engine in self.values(): + self._global_step = max(self._global_step, engine.global_step) + + def _update_micro_step(self): + for engine in self.values(): + self._micro_step = max(self._micro_step, engine.micro_step) + + def train_batch_size(self): + batch_size = 0 + for engine in self.values(): + batch_size = max(batch_size, engine.train_batch_size()) + + def eval(self): + for engine in self.values(): + engine.eval() + + def train(self): + for engine in self.values(): + engine.train() + + def traverse(self): + stats = {} + for name, engine in self.items(): + stat = engine.traverse() + stats.update(flatten_dict({ name.split("-")[0]: stat })) + return stats + + def step(self, batch, feeder: TrainFeeder = default_feed, device=torch.cuda.current_device()): + total_elapsed_time = 0 + + stats: Any = dict() + + if cfg.trainer.gc_mode == 'step': + do_gc() + + batch = to_device(batch, device) + + for name, engine in self.items(): + #torch.cuda.synchronize() + + if cfg.trainer.gc_mode == 'substep': + do_gc() + + start_time = time.time() + + tries = 4 + n_ooms = torch.zeros([], device=cfg.device) + + if cfg.trainer.aggressive_optimizations: + batch = to_device(batch, device) + + res = feeder( engine=engine, batch=batch ) + """ + while tries >= 0: + try: + res = feeder( engine=engine, batch=batch ) + break + except RuntimeError as e: + print("Forward", str(e)) + + if "out of memory" not in str(e): + self.save_checkpoint() + raise e + + # shrink batch size until it's happy + for k in batch: + batch[k] = batch[k][:-1] + + if tries <= 0: + # trigger OOM + n_ooms += 1 + else: + # also do GC + do_gc() + continue + + all_reduce(n_ooms) + if n_ooms.item() > 0: + self.save_checkpoint() + raise RuntimeError("Out of memory during forward pass!") + """ + + if res is None: + continue + + loss, engine_stats = res + engine_stats |= self.gather_attribute("scalar") + + n_ooms = torch.zeros([], device=cfg.device) + + if cfg.trainer.aggressive_optimizations: + batch = to_device(batch, 'cpu') + + engine.backward(loss) + """ + try: + engine.backward(loss) + except RuntimeError as e: + print("Backwards:", str(e)) + + if "out of memory" not in str(e): + self.save_checkpoint() + raise e + + n_ooms += 1 + + all_reduce(n_ooms) + if n_ooms.item() > 0: + self.save_checkpoint() + raise RuntimeError("Out of memory during backwards pass!") + """ + + engine.step() + + #torch.cuda.synchronize() + + elapsed_time = time.time() - start_time + total_elapsed_time += elapsed_time + + stats.update( + flatten_dict( + { + name.split("-")[0]: dict( + loss=loss.item(), + lr=engine.get_lr()[0], + grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation + elapsed_time=elapsed_time, + engine_step=engine.global_step, + 
**engine_stats, + ) + } + ), + ) + + self._update_global_step() + self._update_micro_step() + stats["batch_size"] = self.train_batch_size() # len(batch["text"]) + stats["elapsed_time"] = total_elapsed_time + stats["wall_time"] = time.time() + stats["global_step"] = self.global_step + + return stats diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py new file mode 100644 index 0000000..850a77f --- /dev/null +++ b/vall_e/engines/deepspeed.py @@ -0,0 +1,89 @@ +""" +# https://github.com/enhuiz/pytorch-training-utilities +""" + +# to-do: replace this +# to-do: swap out deepspeed + +from ..config import cfg +from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device + +import logging +import time +import torch +import torch.distributed + +from torch import Tensor +from torch.distributed import all_reduce +from typing import Any, Protocol + +from .base import TrainFeeder + +_logger = logging.getLogger(__name__) + +from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed +from deepspeed.accelerator import get_accelerator + +#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name()) +initialized_dist = False +if not initialized_dist: + initialized_dist = True + init_distributed() + +class Engine(DeepSpeedEngine): + def __init__(self, *args, **kwargs): + kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model']) + kwargs['config_class'] = DeepSpeedConfig(kwargs['config']) + + super().__init__(None, *args, **kwargs) + self._frozen_params = set() + + def freeze(self): + for p in self.module.parameters(): + if p.requires_grad: + p.requires_grad_(False) + self._frozen_params.add(p) + + def unfreeze(self): + for p in self._frozen_params: + p.requires_grad_(True) + self._frozen_params.clear() + + @property + def global_step(self): + return self.global_steps + + @property + def micro_step(self): + return self.micro_steps + + def gather_attribute(self, *args, **kwargs): + return gather_attribute(self.module, *args, **kwargs) + + def dispatch_attribute(self, *args, **kwargs): + return dispatch_attribute(self.module, *args, **kwargs) + + def set_lr(self, lr): + try: + if hasattr(self.optimizer, 'param_groups'): + print(self.optimizer.param_groups) + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + else: + self.optimizer.set_lr(lr) + except Exception as e: + print(str(e)) + + def traverse(self, *args, **kwargs): + self.forward(*args, **kwargs) + losses = self.gather_attribute("loss") + loss = torch.stack([*losses.values()]).sum() + + stats = {} + stats |= {k: v.item() for k, v in losses.items()} + stats |= self.gather_attribute("scalar") + + self.backward(loss) + self.step() + + return stats \ No newline at end of file diff --git a/vall_e/inference.py b/vall_e/inference.py index 944daa0..8e7914a 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -72,8 +72,8 @@ class TTS(): prom = to_device(prom, self.device).to(torch.int16) phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16) - resp_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp) - resps_list = [r.unsqueeze(-1) for r in resp_list] + resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp) + resps_list = [r.unsqueeze(-1) for r in resps_list] resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp) 
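The rename above also pins down the shape convention handed between the two models before the decode call that follows: the AR emits one code sequence per sample, unsqueeze(-1) lifts it to a single-level [t, 1] response, and the NAR fills in the remaining quantizer levels along that last dimension. A rough sketch with dummy tensors (the sizes and level count are illustrative assumptions, not values from this patch):

import torch

t, n_levels = 120, 8                       # illustrative length and level count only
ar_out = torch.randint(0, 1024, (t,))      # AR output for one sample: a single code level, shape [t]
resps = ar_out.unsqueeze(-1)               # -> [t, 1], the single-level form the NAR consumes
for _ in range(n_levels - 1):              # the NAR appends one sampled column per remaining level
    resps = torch.cat([resps, torch.randint(0, 1024, (t, 1))], dim=-1)
assert resps.shape == (t, n_levels)        # full codebook stack passed on to decoding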
wav, sr = qnt.decode_to_file(resps_list[0], out_path) diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index dc33d91..93fe7a5 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -1,20 +1,22 @@ from .ar import AR from .nar import NAR -def get_model(model): - if model.name == "ar": +def get_model(cfg): + if cfg.name == "ar": Model = AR - elif model.name == "nar": + elif cfg.name == "nar": Model = NAR else: - raise f"invalid model name: {model.name}" - name = model.name + raise f"invalid model name: {cfg.name}" + name = cfg.name + model = Model( n_tokens=model.tokens, d_model=model.dim, n_heads=model.heads, n_layers=model.layers, ) + model._cfg = cfg print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index fabcd7b..4c5e293 100755 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -50,112 +50,80 @@ class AR(Base): self, text_list: list[Tensor], proms_list: list[Tensor], - resp_list: list[Tensor] | None = None, + resps_list: list[Tensor] | None = None, max_steps: int = 1000, sampling_temperature: float = 1.0, naive: bool = True, ): - if resp_list is not None: + if resps_list is not None: + resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels + return super().forward( - text_list, - proms_list, - self._unsqueeze_list(resp_list), - resp_list, + text_list=text_list, + proms_list=proms_list, + resps_list=self._unsqueeze_list(resps_list), + targ_list=resps_list, quant_levels=None, shift_targ_list=True, return_all_resp=False, ) device = text_list[0].device - resp_list: list[Tensor] = [ + resps_list: list[Tensor] = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ] stopped = torch.zeros(len(text_list), device=device).bool() - if self.arch_type == "transformer": - naive = True - chunk_size = 1 # don't really know what to do about this desu state = None start = 0 - # prefill - if self.arch_type == "retnet/local": - # pre-process - state = [ - [ - torch.zeros(self.retnet.hidden_dim // self.retnet.heads, self.retnet.v_dim // self.retnet.heads, device=device).unsqueeze(0).repeat(len(text_list), 1, 1) - for _ in range(self.retnet.heads) - ] for _ in range(self.retnet.layers) - ] - resps_list = self._unsqueeze_list(resp_list) - x_list = self._samplewise_merge_tensors( - self.text_emb(text_list), - self.proms_emb(proms_list), - self.resps_emb(resps_list), - sep=self.sep, - ) - - x, m = list_to_tensor(x_list) - - start = x.shape[1] - - for i in trange(start-1): - _, state = self.retnet.forward_recurrent( x[:, i:i+1, :], state, i+1 ) - for n in trange(max_steps // chunk_size): # get next in sequence r, state = super().forward( text_list, proms_list, - self._unsqueeze_list(resp_list), + self._unsqueeze_list(resps_list), sampling_temperature=sampling_temperature, - state=state, + state=state if not naive else None, ) # append outputted token for i, ri in enumerate(r): - resp_list[i] = torch.cat([resp_list[i], ri[None]]) + resps_list[i] = torch.cat([resps_list[i], ri[None]]) # stop token found stopped |= r == self.stop_token if stopped.all().item(): break - pruned = [self._prune(r) for r in resp_list] + pruned = [self._prune(r) for r in resps_list] return pruned def example_usage(): + cfg.trainer.backend = "local" from functools import partial from einops import repeat from ..emb.qnt import decode_to_file - from ..utils import gather_attribute + from ..engines import Engine + from tqdm import tqdm device = "cpu" - + x8 
= partial(repeat, pattern="t -> t l", l=2) symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} def tokenize(content, lang_marker="en"): split = content.split(" ") phones = [f""] + [ " " if not p else p for p in split ] + [f""] return torch.tensor([*map(symmap.get, phones)]).to() - qnt = torch.load("data/qnt.pt")[0, 0].to(device) - kwargs = { - 'n_tokens': 1024, - 'd_model': 1024, - 'n_heads': 16, - 'n_layers': 12, - } + qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device) - model = AR(**kwargs).to(device) - - x8 = partial(repeat, pattern="t -> t l", l=2) text_list = [ #torch.tensor([1, 2, 3], device=device), tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), @@ -164,41 +132,41 @@ def example_usage(): x8(torch.tensor([1, 2, 3], device=device)), #qnt.to(device), ] - resp_list = [ + resps_list = [ qnt.to(device), ] text_list = text_list[:1] proms_list = proms_list[:1] - resp_list = resp_list[:1] + resps_list = resps_list[:1] - model.eval() - out = model(text_list, proms_list, max_steps=75)[0] - print("qnt:", qnt.shape, qnt) - print("out:", out.shape, out) - wav, sr = decode_to_file(out, "data/test/test.ar.init.wav", device=device) + kwargs = { + 'n_tokens': 1024, + 'd_model': 1024, + 'n_heads': 16, + 'n_layers': 12, + } + model = AR(**kwargs).to(device) + engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) + + def sample( name, steps=400 ): + engine.eval() + out = engine(text_list, proms_list, max_steps=steps) + for i, o in enumerate(out): + wav, sr = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device) - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + def train(): + engine.train() + t = trange(60) + for i in t: + stats = {"step": i} + stats |= engine.traverse(text_list=text_list, 
proms_list=proms_list, resps_list=resps_list) - model.train() - for i in trange(60): - optimizer.zero_grad() - _ = model(text_list, proms_list, resp_list) - - losses = gather_attribute(model, "loss") - loss = sum(losses.values()) - loss.backward() - optimizer.step() - - if i % 20 == 0: - print(f"iter={i}, {losses}.") - model.eval() - out = model(text_list, proms_list, max_steps=400) - print("qnt:", qnt.shape, qnt) - for i, o in enumerate(out): - print("out:", i, o.shape, o) - wav, sr = decode_to_file(o, f"data/test/test.ar.{i}.wav", device=device) + t.set_description(f"{stats}") + sample("init", 75) + train() + sample("final") if __name__ == "__main__": example_usage() diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 7eb3902..9eb0c36 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -176,17 +176,7 @@ class Base(nn.Module): self.retnet = RetNetDecoder( self.retnet_config ) - elif self.arch_type == "retnet/local": - self.retnet = RetNet( - layers=n_layers, - hidden_dim=d_model, - ffn_size=d_model * 4, - heads=n_heads, - dropout=p_dropout, - norm_type=self.norm_type, - n_levels=self.n_resp_levels, - double_v_dim=True - ) + self.classifier = nn.Linear(d_model, n_resp_tokens) self.accuracy_metric = MulticlassAccuracy( @@ -272,7 +262,7 @@ class Base(nn.Module): text_list: [t] * b proms_list: [t' l] * b, l quantization levels. resps_list: [t'' l] * b, l quantization levels. - targ_list: [t''] * b, one quantization level only, when given, loss will be computed + targ_list: [t''] * b, one quantization level only; when given, loss will be computed quant_levels: specify which quant_levels to feed forward, used in NAR mode. shift_targ_list: whether to shift target list when computing loss. True if AR. return_all_resp: True if NAR. 
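The NAR path described by these arguments drives quant_levels and grows the response one codebook at a time (return_all_resp=True, no target shifting), which is the loop nar.py builds further down. A self-contained sketch of that bookkeeping, with illustrative sizes that are assumptions rather than values from this patch:

import torch

b, t, n_resp_levels = 2, 120, 8
resps = [torch.randint(0, 1024, (t, n_resp_levels)) for _ in range(b)]   # ground-truth codes per sample

prev_list = [r[..., :1] for r in resps]              # start from the first (AR) level
for level in range(1, n_resp_levels):
    quant_levels = torch.full((b,), level)           # which codebook each sample predicts next
    # inference would call Base.forward(text_list, proms_list, prev_list,
    #     quant_levels=quant_levels, shift_targ_list=False, return_all_resp=True)
    sampled = [r[..., level] for r in resps]         # stand-in for the tokens the model would sample
    prev_list = [torch.cat([p, s.unsqueeze(-1)], dim=-1) for p, s in zip(prev_list, sampled)]

assert all(p.shape == (t, n_resp_levels) for p in prev_list)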
@@ -298,24 +288,12 @@ class Base(nn.Module): elif self.arch_type == "retnet": x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True) state = self.retnet.get_incremental_state( state, 'prev_state' ) - elif self.arch_type == "retnet/local": - # recurrent inferencing - if self.causal and state is not None: - last = x.shape[1] - x, state = self.retnet.forward_recurrent( - x[:, last-1:last, :], # nasty way to grab the last embedding to forward - state, - last - ) - else: - x = self.retnet( x, quant_levels ) x = self.classifier(x) * m # Remove padding h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))] - # compute loss if the target is given if targ_list is not None: if any([l == 0 for l in map(len, targ_list)]): @@ -337,20 +315,24 @@ class Base(nn.Module): # the NAR doesn't need to compute the loss for it if self.resp_loss_only: text_prom_list[i][:] = self.ignore_index + # roll the text/prompt for loss computing - # the AR benefits from this + # the AR benefits from this, for some reason I'll figure out later else: text_prom_list[i] = text_prom_list[i].roll(-1, dims=0) text_prom_list[i][-1] = self.ignore_index - # necessary to roll the target if recurrently/causally/autoregressively generating, or it won't be able to work + # for the AR, roll by one and mark the ending with a stop token + # this coerces the model into properly inferencing causally + + # why we don't just append a stop token in the dataloader, who knows if shift_targ_list: targ_list = [*targ_list] for i in range(len(targ_list)): targ_list[i] = targ_list[i].roll(-1, dims=0) targ_list[i][-1] = self.stop_token - # generate the sequence + # create the new target sequence to compute the loss against y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) self.loss = dict( @@ -367,6 +349,7 @@ class Base(nn.Module): del prom_list del text_prom_list del y_list + # return the entire generated token string if return_all: @@ -387,124 +370,90 @@ class Base(nn.Module): return ret, state def example_usage(): + from ..config import cfg + cfg.trainer.backend = "local" + from functools import partial from einops import repeat - from tqdm import trange - - from ..utils import gather_attribute + from ..emb.qnt import decode_to_file + from ..engines import Engine, Engines + from tqdm import tqdm, trange + from .ar import AR from .nar import NAR + device = "cpu" + x8 = partial(repeat, pattern="t -> t l", l=2) symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, 
')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} def tokenize(content, lang_marker="en"): split = content.split(" ") phones = [f""] + [ " " if not p else p for p in split ] + [f""] return torch.tensor([*map(symmap.get, phones)]).to() - device = "cpu" - kwargs = { 'n_tokens': 1024, 'd_model': 1024, 'n_heads': 16, 'n_layers': 12, } - model_ar = AR(**kwargs).to(device) - model_nar = NAR(**kwargs).to(device) + models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) } + engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() }) train = True - if train: - qnt = torch.load("data/qnt.pt").to(device) - text_list = [ - tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), - #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device), - ] - x8 = partial(repeat, pattern="t -> t l", l=2) - proms_list = [ - qnt[0][:2,:].t().to(device), - #x8(torch.tensor([1, 2, 3], device=device)), - # x8(torch.tensor([2, 3], device=device)), - ] + qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device) + text_list = [ + tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), + #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device), + ] - resp_list_ar = [ - qnt[0,0].to(device), - # qnt[0,0].to(device), - ] - - resp_list_nar = [ - qnt[0][:2,:].t().to(device), - # qnt[0][:2,:].t().to(device), - ] - - model_ar.train() - optimizer = torch.optim.AdamW(model_ar.parameters(), lr=1e-4) - for i in trange(60): - optimizer.zero_grad() - _ = model_ar(text_list, proms_list, resp_list_ar) - - losses = gather_attribute(model_ar, "loss") - loss = sum(losses.values()) - loss.backward() - optimizer.step() - - if i % 20 == 0: - print(f"iter={i}, {losses}.") - - model_nar.train() - optimizer = torch.optim.AdamW(model_nar.parameters(), lr=1e-4) - for i in trange(60): - optimizer.zero_grad() - - _ = model_nar(text_list, proms_list, resps_list=resp_list_nar) - - losses = gather_attribute(model_nar, "loss") - loss = sum(losses.values()) - loss.backward() - optimizer.step() - - if i % 20 == 0: - stats = {k: v.item() for k, v in losses.items()} - stats["loss"] = loss.item() - print(f"iter={i}, {stats}.") - else: - qnt = torch.load("data/test/test.qnt.pt")[0][:2,:].t().to(device) - text_list = [ - #tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), - tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ 
ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device), - ] - proms_list = [ - qnt.to(device), - ] - model_ar.load_state_dict(torch.load("data/test/ar.pth")) - model_nar.load_state_dict(torch.load("data/test/nar.pth")) - - model_ar.eval() - resp_list = model_ar(text_list, proms_list, max_steps=300, sampling_temperature=1.0) - resps_list = [r.unsqueeze(-1) for r in resp_list] + proms_list = [ + qnt.to(device), + ] + resps_list = [ + qnt.to(device), + ] - print("qnt:", qnt.shape, qnt) - print("out:", resp_list[0].shape, resp_list[0]) - wav, sr = decode_to_file(resp_list[0], "data/test/test.ar.init.wav", device=device) - print(wav, sr) + def sample( name, steps=400 ): + AR = None + NAR = None - model_nar.eval() - codes = model_nar( - text_list, - proms_list, - resps_list=resps_list, - sampling_temperature=1.0, - )[0] + engines.eval() + for name, engine in engines.items(): + if name[:2] == "ar": + AR = engine + elif name[:3] == "nar": + NAR = engine + resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0) + resps_list = [r.unsqueeze(-1) for r in resps_list] + codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 ) - print("qnt:", qnt.shape, qnt) - print("codes:", codes.shape, codes) + decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device) + decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device) + + if train: + sample("init", 15) - wav, sr = decode_to_file(codes, "data/test/test.ar+nar.init.wav", device=device) - print(wav, sr) + engines.train() + t = trange(60) + for i in t: + """ + stats = {"step": i} + for name, engine in engines.items(): + stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list) + """ + stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu") + t.set_description(f"{stats}") + else: + for name, engine in engines.items(): + engine.module.load_state_dict(torch.load(f"./data/{name}.pth")) + + sample("final") + if __name__ == "__main__": example_usage() diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 2cba56e..57dc966 100755 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -80,7 +80,6 @@ class NAR(Base): quant_levels=quant_levels, ) - # Yes, just nothing as we are training prev_list = [] else: prev_list = resps_list @@ -93,7 +92,7 @@ class NAR(Base): quant_levels = torch.full((len(text_list),), level, device=device) - resp_list, _ = super().forward( + resps_list, _ = super().forward( text_list, proms_list, prev_list, @@ -105,24 +104,22 @@ class NAR(Base): prev_list = [ torch.cat([rs, r.unsqueeze(-1)], dim=-1) - for rs, r in zip(prev_list, resp_list) + for rs, r in zip(prev_list, resps_list) ] return prev_list - def example_usage(): + cfg.trainer.backend = "local" from functools import partial - from pathlib import Path from einops import repeat from ..emb.qnt import decode_to_file - from ..utils import gather_attribute - from ..config import cfg + from ..engines import Engine + from tqdm import tqdm device = "cpu" - x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels) symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 
'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} def tokenize(content, lang_marker="en"): @@ -130,15 +127,8 @@ def example_usage(): phones = [f""] + [ " " if not p else p for p in split ] + [f""] return torch.tensor([*map(symmap.get, phones)]).to() - resps = torch.load("data/qnt.pt")[0][:2, :].to(device) - kwargs = { - 'n_tokens': 1024, - 'd_model': 1024, - 'n_heads': 16, - 'n_layers': 12, - } - - model = NAR(**kwargs).to(device) + # to-do: unmangle this and the resp shit + qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device) text_list = [ #torch.tensor([1, 2, 3], device=device), @@ -149,65 +139,36 @@ def example_usage(): x8(torch.tensor([2, 3], device=device)), ] - resps_x1_list = [ - resps[:1].t().to(device), + resps_list = [ + qnt.to(device), ] - resps_x8_list = [ - resps.t().to(device), - ] + kwargs = { + 'n_tokens': 1024, + 'd_model': 1024, + 'n_heads': 16, + 'n_layers': 12, + } + model = NAR(**kwargs).to(device) + engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) - model.eval() - codes = model( - text_list, - proms_list, - resps_list=resps_x1_list, - sampling_temperature=0.2, - )[0] + def sample( name ): + engine.eval() + codes = engine( text_list, proms_list, resps_list=[r[..., 0].unsqueeze(-1) for r in resps_list], sampling_temperature=0.2 ) + decode_to_file( codes[0], f"data/nar.{name}.wav", device ) - decode_to_file( - codes, - Path("data/test/test.nar.init.wav"), - device - ) + def train(): + engine.train() + t = trange(60) + for i in t: + stats = {"step": i} + stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list) - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + t.set_description(f"{stats}") - model.train() - for i in trange(50): - optimizer.zero_grad() - - _ = model(text_list, proms_list, resps_list=resps_x8_list) - - losses = gather_attribute(model, "loss") - loss = sum(losses.values()) - loss.backward() - - optimizer.step() - - if i % 20 == 0: - stats = {k: v.item() for k, v in losses.items()} - stats["loss"] = loss.item() - print(f"iter={i}, {stats}.") - - model.eval() - for i in trange(1, 2): # 
cfg.models.prom_levels): - resps_list = [ - resps[:i].t().to(device), - ] - - codes = model( - text_list, - proms_list, - resps_list=resps_list, - sampling_temperature=0.2, - )[0] - - decode_to_file( - codes, - Path(f"data/test/test.nar.1-{i}.wav"), - device - ) + sample("init") + train() + sample("final") if __name__ == "__main__": diff --git a/vall_e/train.py b/vall_e/train.py index 9bf8686..748f22d 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -1,17 +1,13 @@ # todo: clean this mess up -# todo: yank deepspeed dependent code out into its own thing from .config import cfg from .data import create_train_val_dataloader from .emb import qnt from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc -from .utils import wrapper as ml -from .models import get_models import auraloss -import deepspeed import json import logging import random @@ -21,10 +17,6 @@ import traceback from collections import defaultdict -from deepspeed import comm as dist -from deepspeed import DeepSpeedConfig -from deepspeed.accelerator import get_accelerator - from tqdm import tqdm mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda") @@ -38,210 +30,120 @@ def left_crop(x, len): return x[..., 0:len] _logger = logging.getLogger(__name__) -deepspeed._initialized_dist = False -def load_engines(): - if not deepspeed._initialized_dist: - deepspeed._initialized_dist = True - deepspeed.init_distributed() +def train_feeder(engine, batch): + engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] ) - models = get_models(cfg.models.get()) - engines = dict() + losses = engine.gather_attribute("loss") - for name in models: - model = models[name] + loss = torch.stack([*losses.values()]).sum() - optimizer = None - lr_scheduler = None + stats = {} + stats |= {k: v.item() for k, v in losses.items()} - Adam = ml.Adam - AdamW = ml.AdamW + return loss, stats - if cfg.hyperparameters.optimizer.lower() == "adamw-torch": - optimizer = AdamW( - model.parameters(), - lr=cfg.hyperparameters.learning_rate, - betas=(0.9, 0.96), - eps=1e-07, - weight_decay=0.01, - ) +@torch.inference_mode() +def run_eval(engines, eval_name, dl): + engines_stats = { + 'eval': eval_name + } - if cfg.trainer.load_state_dict: - load_dir = cfg.ckpt_dir / name / "fp32.pth" - model.load_state_dict(torch.load(load_dir)) + AR = None + NAR = None - ds_cfg=cfg.get_ds_cfg(model=model) - config_class=DeepSpeedConfig(ds_cfg) - engines[name] = trainer.Engine( - model=model, - config=ds_cfg, - config_class=config_class, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - ) + names = [] + for name, engine in engines.items(): + names.append(name) + if name[:2] == "ar": + AR = engine + elif name[:3] == "nar": + NAR = engine + + stats = defaultdict(list) + stats['loss'] = [] + + def process( name, batch, resps_list ): + for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]): + if len(hyp) == 0: + continue + + filename = f'{speaker}_{path.parts[-1]}' + + # to-do, refine the output dir to be sane-er + ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav") + hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav") + prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav") + + hyp_path.parent.mkdir(parents=True, exist_ok=True) + ref_path.parent.mkdir(parents=True, exist_ok=True) + prom_path.parent.mkdir(parents=True, exist_ok=True) + + 
ref_audio, sr = qnt.decode_to_file(ref, ref_path) + hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path) + prom_audio, sr = qnt.decode_to_file(prom, prom_path) + + # pseudo loss calculation since we don't get the logits during eval + min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] ) + ref_audio = ref_audio[..., 0:min_length] + hyp_audio = hyp_audio[..., 0:min_length] + + try: + stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item()) + except Exception as e: + print(str(e)) + + for batch in tqdm(dl): + batch: dict = to_device(batch, cfg.device) + + # if we're training both models, provide output for both + if AR is not None and NAR is not None: + name = "+".join(names) + + resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature) + resps_list = [ r.unsqueeze(-1) for r in resps_list ] + resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature) + + process( name, batch, resps_list ) + else: + for name in engines: + model = engines[name] + + if name.startswith("ar"): + resps_list = model( + text_list=batch["text"], + proms_list=batch["proms"], + max_steps=cfg.evaluation.steps, + sampling_temperature=cfg.evaluation.ar_temperature, + ) + resps_list = [r.unsqueeze(-1) for r in resps_list] + elif name.startswith("nar"): + resps_list = model( + text_list=batch["text"], + proms_list=batch["proms"], + resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]], + sampling_temperature=cfg.evaluation.nar_temperature, + ) + else: + raise NotImplementedError(name) + + process( name, batch, resps_list ) + + stats = {k: sum(v) / len(v) for k, v in stats.items()} + engines_stats.update(flatten_dict({ name: stats })) + + iteration = engines.global_step + engines_stats['it'] = iteration + engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl) + + _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") - return trainer.load_engines(engines, cfg) def main(): setup_logging(cfg.log_dir) - #dist.init_distributed(dist_backend=get_accelerator().communication_backend_name()) - if not deepspeed._initialized_dist: - deepspeed._initialized_dist = True - deepspeed.init_distributed() - train_dl, subtrain_dl, val_dl = create_train_val_dataloader() - - def train_feeder(engines, batch, name): - stats = {} - model = engines[name] - if name.startswith("ar"): - _ = model( - text_list=batch["text"], - proms_list=batch["proms"], - resp_list=[r[..., 0] for r in batch["resps"]], - ) - elif name.startswith("nar"): - _ = model( - text_list=batch["text"], - proms_list=batch["proms"], - resps_list=batch["resps"], - ) - else: - raise NotImplementedError(name) - - losses = model.gather_attribute("loss") - - loss = torch.stack([*losses.values()]).sum() - - stats = {} - stats |= {k: v.item() for k, v in losses.items()} - stats |= engines.gather_attribute("scalar") - - return loss, stats - - @torch.inference_mode() - def run_eval(engines, eval_name, dl): - engines_stats = { - 'eval': eval_name - } - - AR = None - NAR = None - - names = [] - for name in engines: - model = engines[name] - names.append(name) - if name[:2] == "ar": - AR = model - elif name[:3] == "nar": - NAR = model - - stats = defaultdict(list) - stats['loss'] = [] - - for batch in tqdm(dl): - batch: dict = to_device(batch, cfg.device) - - # if we're training both models, provide output for both - if AR is not None and NAR is not 
None: - name = "+".join(names) - - resp_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature) - resps_list = [ r.unsqueeze(-1) for r in resp_list ] - resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature) - - for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]): - if len(hyp) == 0: - continue - - filename = f'{speaker}_{path.parts[-1]}' - - # to-do, refine the output dir to be sane-er - ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav") - hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav") - prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav") - - hyp_path.parent.mkdir(parents=True, exist_ok=True) - ref_path.parent.mkdir(parents=True, exist_ok=True) - prom_path.parent.mkdir(parents=True, exist_ok=True) - - ref_audio, sr = qnt.decode_to_file(ref, ref_path) - hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path) - prom_audio, sr = qnt.decode_to_file(prom, prom_path) - - min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] ) - ref_audio = ref_audio[..., 0:min_length] - hyp_audio = hyp_audio[..., 0:min_length] - - stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item()) - else: - for name in engines: - model = engines[name] - - if name.startswith("ar"): - resp_list = model( - text_list=batch["text"], - proms_list=batch["proms"], - max_steps=cfg.evaluation.steps, - sampling_temperature=cfg.evaluation.temperature, - ) - resps_list = [r.unsqueeze(-1) for r in resp_list] - elif name.startswith("nar"): - resps_list = model( - text_list=batch["text"], - proms_list=batch["proms"], - resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]], - sampling_temperature=cfg.evaluation.temperature, - ) - else: - raise NotImplementedError(name) - - losses = model.gather_attribute("loss") - - batch_stats = {} - batch_stats |= {k: v.item() for k, v in losses.items()} - batch_stats |= engines.gather_attribute("scalar") - - for k, v in batch_stats.items(): - stats[k].append(v) - - for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]): - if len(hyp) == 0: - continue - - filename = f'{speaker}_{path.parts[-1]}' - - # to-do, refine the output dir to be sane-er - ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav") - hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav") - prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav") - - hyp_path.parent.mkdir(parents=True, exist_ok=True) - ref_path.parent.mkdir(parents=True, exist_ok=True) - prom_path.parent.mkdir(parents=True, exist_ok=True) - - ref_audio, sr = qnt.decode_to_file(ref, ref_path) - hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path) - prom_audio, sr = qnt.decode_to_file(prom, prom_path) - - # pseudo loss calculation since we don't get the logits during eval - min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] ) - ref_audio = ref_audio[..., 0:min_length] - hyp_audio = hyp_audio[..., 0:min_length] - - stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item()) - - stats = {k: sum(v) / len(v) for k, v in stats.items()} - engines_stats.update(flatten_dict({ name: stats })) - 
- iteration = engines.global_step - engines_stats['it'] = iteration - engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl) - - _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") - + def eval_fn(engines): try: run_eval(engines, "subtrain", subtrain_dl) @@ -256,12 +158,10 @@ def main(): qnt.unload_model() trainer.train( - engines_loader=load_engines, train_dl=train_dl, train_feeder=train_feeder, eval_fn=eval_fn, ) - if __name__ == "__main__": main() \ No newline at end of file diff --git a/vall_e/utils/engines.py b/vall_e/utils/engines.py deleted file mode 100755 index 6549b9a..0000000 --- a/vall_e/utils/engines.py +++ /dev/null @@ -1,252 +0,0 @@ -""" -# https://github.com/enhuiz/pytorch-training-utilities -""" - -# to-do: replace this -# to-do: swap out deepspeed - -from ..config import Config -from .distributed import fix_unset_envs -from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device - -import logging -import time -import torch -import torch.distributed - -from deepspeed import DeepSpeedEngine -from torch import Tensor -from torch.distributed import all_reduce -from typing import Any, Protocol - -Stats = dict[str, float] - -_logger = logging.getLogger(__name__) - - -class Engine(DeepSpeedEngine): - def __init__(self, *args, **kwargs): - fix_unset_envs() - super().__init__(None, *args, **kwargs) - self._frozen_params = set() - - def freeze(self): - for p in self.module.parameters(): - if p.requires_grad: - p.requires_grad_(False) - self._frozen_params.add(p) - - def unfreeze(self): - for p in self._frozen_params: - p.requires_grad_(True) - self._frozen_params.clear() - - @property - def global_step(self): - return self.global_steps - - def gather_attribute(self, *args, **kwargs): - return gather_attribute(self.module, *args, **kwargs) - - def dispatch_attribute(self, *args, **kwargs): - return dispatch_attribute(self.module, *args, **kwargs) - - -class TrainFeeder(Protocol): - def __call__( - self, *, engines: "Engines", batch: Any, name: str - ) -> None | tuple[Tensor, Stats]: - ... 
-
-
-class Engines(dict[str, Engine]):
-	def setup(self, cfg: Config):
-		self._cfg = cfg
-		self._global_step = 0
-
-	@property
-	def cfg(self) -> Config:
-		return self._cfg
-
-	@property
-	def config(self):
-		return self._cfg
-
-	@property
-	def global_step(self):
-		return self._global_step
-
-	def gather_attribute(self, *args, **kwargs):
-		ret = {}
-		for engine in self.values():
-			ret |= engine.gather_attribute(*args, **kwargs)
-		return ret
-
-	def dispatch_attribute(self, *args, **kwargs):
-		for engine in self.values():
-			engine.dispatch_attribute(*args, **kwargs)
-
-	def save_checkpoint(self, tag=None):
-		if not tag:
-			tag = self.cfg.trainer.save_tag
-		tag = tag.lower()
-		if tag[:2] == "it" or tag[:4] == "step":
-			tag = self.global_step
-
-		self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
-		for name, engine in self.items():
-			engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
-
-	def load_checkpoint(self, tag=None):
-		if not tag:
-			tag = self.cfg.trainer.load_tag
-
-		for name, engine in self.items():
-			load_dir = self.cfg.ckpt_dir / name
-			engine.load_checkpoint(
-				tag=tag,
-				load_dir=load_dir,
-				load_module_strict=self.cfg.trainer.strict_loading,
-				load_optimizer_states=self.cfg.trainer.load_states,
-				load_lr_scheduler_states=self.cfg.trainer.load_states,
-				load_module_only=False, # not self.cfg.trainer.load_states,
-			)
-			if self.cfg.trainer.restart_step_count:
-				engine.global_steps = 0
-
-		# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
-		if self.cfg.hyperparameters.scheduler_type == "":
-			self.set_lr(self.cfg.hyperparameters.learning_rate)
-
-		self._update_global_step()
-
-	def set_lr(self, lr):
-		try:
-			for engine in self.values():
-				if hasattr(engine.optimizer, 'param_groups'):
-					print(engine.optimizer.param_groups)
-					for param_group in engine.optimizer.param_groups:
-						param_group['lr'] = lr
-				else:
-					engine.optimizer.set_lr(lr)
-		except Exception as e:
-			print(str(e))
-
-	def _update_global_step(self):
-		for engine in self.values():
-			self._global_step = max(self._global_step, engine.global_step)
-
-	def eval(self):
-		for engine in self.values():
-			engine.eval()
-
-	def train(self):
-		for engine in self.values():
-			engine.train()
-
-	def step(self, feeder: TrainFeeder, batch):
-		total_elapsed_time = 0
-
-		stats: Any = dict()
-
-		if self.cfg.trainer.gc_mode == 'step':
-			do_gc()
-
-		batch = to_device(batch, torch.cuda.current_device())
-
-		for name, engine in self.items():
-			torch.cuda.synchronize()
-			if self.cfg.trainer.gc_mode == 'substep':
-				do_gc()
-
-			start_time = time.time()
-
-			tries = 4
-			n_ooms = torch.zeros([], device=self.cfg.device)
-			if self.cfg.trainer.aggressive_optimizations:
-				batch = to_device(batch, torch.cuda.current_device())
-				# engine = engine.to(torch.cuda.current_device())
-
-			while tries >= 0:
-				try:
-					res = feeder( engines=self, batch=batch, name=name )
-					break
-				except RuntimeError as e:
-					print("Forward", str(e))
-
-					if "out of memory" not in str(e):
-						self.save_checkpoint()
-						raise e
-
-					# shrink batch size until it's happy
-					for k in batch:
-						batch[k] = batch[k][:-1]
-
-					if tries <= 0:
-						# trigger OOM
-						n_ooms += 1
-					else:
-						# also do GC
-						do_gc()
-					continue
-
-			all_reduce(n_ooms)
-			if n_ooms.item() > 0:
-				self.save_checkpoint()
-				raise RuntimeError("Out of memory during forward pass!")
-
-			if res is None:
-				continue
-
-			loss, engine_stats = res
-
-			n_ooms = torch.zeros([], device=self.cfg.device)
-
-			if self.cfg.trainer.aggressive_optimizations:
-				batch = to_device(batch, 'cpu')
-
-			try:
-				engine.backward(loss)
-			except RuntimeError as e:
-				print("Backwards:", str(e))
-
-				if "out of memory" not in str(e):
-					self.save_checkpoint()
-					raise e
-
-				n_ooms += 1
-
-			all_reduce(n_ooms)
-			if n_ooms.item() > 0:
-				self.save_checkpoint()
-				raise RuntimeError("Out of memory during backwards pass!")
-
-			engine.step()
-			torch.cuda.synchronize()
-			elapsed_time = time.time() - start_time
-			total_elapsed_time += elapsed_time
-
-			stats.update(
-				flatten_dict(
-					{
-						name.split("-")[0]: dict(
-							loss=loss.item(),
-							lr=engine.get_lr()[0],
-							grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
-							elapsed_time=elapsed_time,
-							engine_step=engine.global_step,
-							**engine_stats,
-						)
-					}
-				),
-			)
-			del loss
-			# engine = engine.to('cpu')
-
-		self._update_global_step()
-		stats["batch_size"] = len(batch["text"])
-		stats["elapsed_time"] = total_elapsed_time
-		stats["wall_time"] = time.time()
-		stats["global_step"] = self.global_step
-
-		return stats
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index de47763..c796025 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -17,239 +17,258 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from typing import Protocol
 
-from ..config import Config
+from ..config import cfg
 from .distributed import (
-    global_leader_only,
-    global_rank,
-    is_global_leader,
-    is_local_leader,
-    local_leader_only,
+	fix_unset_envs,
+	global_leader_only,
+	global_rank,
+	is_global_leader,
+	is_local_leader,
+	local_leader_only,
 )
-from .engines import Engine, Engines, TrainFeeder
+from ..engines import Engine, Engines, TrainFeeder, default_feeder
+from ..models import get_models
+
 from .utils import to_device, do_gc
+from ..utils import wrapper as ml
 
 _logger = logging.getLogger(__name__)
 
 _engines: Engines
 _command: str
 
 def get_global_step():
-    try:
-        return _engines.global_step
-    except:
-        return None
+	try:
+		return _engines.global_step
+	except:
+		return None
 
 def get_micro_step():
-    try:
-        return _engines.micro_step
-    except:
-        return None
-
-
-def get_cfg():
-    try:
-        return _engines.cfg
-    except:
-        raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
-
+	try:
+		return _engines.micro_step
+	except:
+		return None
 
 def get_cmd():
-    try:
-        return _command
-    except:
-        raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
+	try:
+		return _command
+	except:
+		raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
 
 get_iteration = get_global_step
 
+def load_engines():
+	models = get_models(cfg.models.get())
+	engines = dict()
 
-class EnginesLoader(Protocol):
-    def __call__(self) -> Engines:
-        ...
+	for name in models:
+		model = models[name]
+		optimizer = None
+		lr_scheduler = None
 
-def load_engines(engines: dict[str, Engine], config: Config):
-    engines = Engines(engines)
-    engines.setup(config)
-    if not engines.cfg.trainer.load_state_dict:
-        engines.load_checkpoint()
-    return engines
+		if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
+			optimizer = ml.AdamW(
+				model.parameters(),
+				lr=cfg.hyperparameters.learning_rate,
+				betas=(0.9, 0.96),
+				eps=1e-07,
+				weight_decay=0.01,
+			)
+
+		if cfg.trainer.load_state_dict:
+			load_path = cfg.ckpt_dir / name / "fp32.pth"
+			model.load_state_dict(torch.load(load_path))
+
+		engines[name] = Engine(
+			model=model,
+			optimizer=optimizer,
+			lr_scheduler=lr_scheduler,
+		)
+
+	engines = Engines(engines)
+	engines.setup()
+
+	if not cfg.trainer.load_state_dict:
+		engines.load_checkpoint()
+
+	return engines
 
 class EvalFn(Protocol):
-    def __call__(self, *, engines: Engines):
-        ...
+	def __call__(self, *, engines: Engines):
+		...
 
 class Logger(Protocol):
-    def __call__(self, *, data: dict):
-        ...
+	def __call__(self, *, data: dict):
+		...
 
 @cache
 def _get_stdin_selector():
-    selector = selectors.DefaultSelector()
-    selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
-    return selector
+	selector = selectors.DefaultSelector()
+	selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
+	return selector
 
 def _non_blocking_input():
-    global _command
-    l = [""]
-    if is_global_leader():
-        s = ""
-        selector = _get_stdin_selector()
-        events = selector.select(timeout=0)
-        for key, _ in events:
-            s: str = key.fileobj.readline().strip()
-            _logger.info(f'Get stdin "{s}".')
-        l[0] = s
-    broadcast_object_list(l, src=0)
-    _command = l[0]
-    return _command
+	global _command
+	l = [""]
+	if is_global_leader():
+		s = ""
+		selector = _get_stdin_selector()
+		events = selector.select(timeout=0)
+		for key, _ in events:
+			s: str = key.fileobj.readline().strip()
+			_logger.info(f'Get stdin "{s}".')
+		l[0] = s
+	broadcast_object_list(l, src=0)
+	_command = l[0]
+	return _command
 
 def _make_infinite_epochs(dl):
-    while True:
-        _logger.info("New epoch starts.")
-        yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
+	while True:
+		_logger.info("New epoch starts.")
+		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
 
 @local_leader_only(default=None)
 def logger(data):
-    return _logger.info(json.dumps(data, default=str))
+	return _logger.info(json.dumps(data, default=str))
 
 def seed(seed):
-    # Set up random seeds, after fork()
-    random.seed(seed + global_rank())
-    np.random.seed(seed + global_rank())
-    torch.manual_seed(seed + global_rank())
+	# Set up random seeds, after fork()
+	random.seed(seed + global_rank())
+	np.random.seed(seed + global_rank())
+	torch.manual_seed(seed + global_rank())
 
 def train(
-    engines_loader: EnginesLoader,
-    train_dl: DataLoader,
-    train_feeder: TrainFeeder,
-    eval_fn: EvalFn,
-    logger: Logger = logger,
+	train_dl: DataLoader,
+	train_feeder: TrainFeeder = default_feeder,
+	eval_fn: EvalFn = lambda x: ...,
+	logger: Logger = logger,
 ):
-    engines = engines_loader()
-    cfg = engines.cfg
+	fix_unset_envs()
 
-    """
-    if is_local_leader():
-        cfg.dump()
-        _logger.info(cfg)
-    """
+	engines = load_engines()
 
-    # Setup global engines
-    global _engines
-    _engines = engines
+	"""
+	if is_local_leader():
+		cfg.dump()
+		_logger.info(cfg)
+	"""
 
-    events = []
+	# Setup global engines
+	global _engines
+	_engines = engines
 
-    eval_fn = global_leader_only(eval_fn)
+	events = []
 
-    # Pre-loop command
-    command = _non_blocking_input()
-    if command in ["eval", "eval_quit"]:
-        engines.eval()
-        eval_fn(engines=engines)
-        engines.train()
-    if command in ["quit", "eval_quit"]:
-        return
+	eval_fn = global_leader_only(eval_fn)
 
-    last_save_step = engines.global_step
-    last_eval_step = 0
+	# Pre-loop command
+	command = _non_blocking_input()
+	if command in ["eval", "eval_quit"]:
+		engines.eval()
+		eval_fn(engines=engines)
+		engines.train()
+	if command in ["quit", "eval_quit"]:
+		return
 
-    # Training loop
-    for batch in _make_infinite_epochs(train_dl):
-        if engines.global_step >= cfg.trainer.iterations:
-            break
+	last_save_step = engines.global_step
+	last_eval_step = 0
 
-        #batch = to_device(batch, torch.cuda.current_device())
-        stats = engines.step(feeder=train_feeder, batch=batch)
+	# Training loop
+	for batch in _make_infinite_epochs(train_dl):
+		if engines.global_step >= cfg.trainer.iterations:
+			break
 
-        iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
-        stats['it'] = iteration
-        stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+		#batch = to_device(batch, torch.cuda.current_device())
+		stats = engines.step(batch=batch, feeder=train_feeder)
 
-        stats['batch'] = {
-            'size': stats['batch_size'],
-            'id': batch['spkr_id'],
-            'index': [ index for index in batch['index'] ],
-            'text_len': [ text.shape[0] for text in batch['text'] ],
-            'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
-            'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
-        }
+		iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
+		stats['it'] = iteration
+		stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
 
-        del stats['batch_size']
-        del stats['wall_time']
-        del stats['global_step']
+		stats['batch'] = {
+			'size': stats['batch_size'],
+			'id': batch['spkr_id'],
+			'index': [ index for index in batch['index'] ],
+			'text_len': [ text.shape[0] for text in batch['text'] ],
+			'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
+			'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
+		}
 
-        elapsed_time = stats.get("elapsed_time", 0)
-        _logger.info(f"Training Metrics: {json.dumps(stats)}.")
+		del stats['batch_size']
+		del stats['wall_time']
+		del stats['global_step']
 
-        command = _non_blocking_input()
+		elapsed_time = stats.get("elapsed_time", 0)
+		_logger.info(f"Training Metrics: {json.dumps(stats)}.")
 
-        if "@" in command:
-            what, when = command.split("@")
-            try:
-                events.append((what, int(when)))
-                _logger.info(f"Event {command} registered.")
-            except Exception as e:
-                _logger.error(e)
-            command = ""
+		command = _non_blocking_input()
 
-        # Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
-        events = [e for e in events if e[1] >= engines.global_step]
-        commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
+		if "@" in command:
+			what, when = command.split("@")
+			try:
+				events.append((what, int(when)))
+				_logger.info(f"Event {command} registered.")
+			except Exception as e:
+				_logger.error(e)
+			command = ""
 
-        for command in commands:
-            if command in ["event show", "event"]:
-                msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
-                _logger.info(msg)
+		# Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
+		events = [e for e in events if e[1] >= engines.global_step]
+		commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
 
-            if command == "event clear":
-                events.clear()
+		for command in commands:
+			if command in ["event show", "event"]:
+				msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
+				_logger.info(msg)
 
-            if "time" in command:
-                target_iter = cfg.trainer.iterations
-                if " to " in command:
-                    try:
-                        target_iter = int(command.split(" to ")[-1])
-                    except Exception as e:
-                        _logger.error(e)
-                remaining_iters = target_iter - engines.global_step + 1
-                remaining_time = int(remaining_iters * elapsed_time)
-                _logger.info(humanize.precisedelta(remaining_time))
+			if command == "event clear":
+				events.clear()
 
-            if "lr" in command:
-                rate = float(command.split(" ")[-1])
-                engines.set_lr(rate)
-                print("Updating LR to:", rate)
+			if "time" in command:
+				target_iter = cfg.trainer.iterations
+				if " to " in command:
+					try:
+						target_iter = int(command.split(" to ")[-1])
+					except Exception as e:
+						_logger.error(e)
+				remaining_iters = target_iter - engines.global_step + 1
+				remaining_time = int(remaining_iters * elapsed_time)
+				_logger.info(humanize.precisedelta(remaining_time))
 
-            save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
+			if "lr" in command:
+				rate = float(command.split(" ")[-1])
+				engines.set_lr(rate)
+				print("Updating LR to:", rate)
 
-            saving_commands = ["save"]
+			save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
 
-            if cfg.trainer.save_on_quit:
-                saving_commands.append("quit")
+			saving_commands = ["save"]
 
-            if engines.global_step != last_save_step:
-                if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
-                    engines.save_checkpoint()
-                    last_save_step = engines.global_step
+			if cfg.trainer.save_on_quit:
+				saving_commands.append("quit")
 
-            if engines.global_step != last_eval_step:
-                if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
-                    do_gc()
+			if engines.global_step != last_save_step:
+				if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
+					engines.save_checkpoint()
+					last_save_step = engines.global_step
 
-                    engines.eval()
-                    eval_fn(engines=engines)
-                    engines.train()
-                    last_eval_step = engines.global_step
+			if engines.global_step != last_eval_step:
+				if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
+					do_gc()
 
-            if command in ["quit"]:
-                return
\ No newline at end of file
+					engines.eval()
+					eval_fn(engines=engines)
+					engines.train()
+					last_eval_step = engines.global_step
+
+			if command in ["quit"]:
+				return
\ No newline at end of file
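
With `engines_loader` dropped from `trainer.train()`, a training script only wires up the dataloaders and a feeder; engine construction now happens inside `load_engines()` from the global `cfg`. Below is a minimal sketch of driving the refactored entry point, assuming the existing `create_train_val_dataloader` helper in `vall_e/data.py` and the `train_feeder` defined in `vall_e/train.py` keep their current names; the no-op `eval_fn` is a hypothetical stand-in for the real evaluation callback.

    # sketch: invoking the refactored trainer entry point (illustrative only)
    from vall_e.config import cfg                      # global config, as imported by the new trainer.py
    from vall_e.data import create_train_val_dataloader  # assumed helper returning (train, subtrain, val) loaders
    from vall_e.train import train_feeder              # per-batch feeder referenced in this diff
    from vall_e.utils import trainer

    def main():
        # dataloaders come from the dataset side; engines are built inside trainer.train() via load_engines()
        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
        trainer.train(
            train_dl=train_dl,
            train_feeder=train_feeder,
            eval_fn=lambda *, engines: None,  # hypothetical stand-in; the real eval_fn runs run_eval() on subtrain/val
        )

    if __name__ == "__main__":
        main()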