diff --git a/data/config.yaml b/data/config.yaml
index b24a52f..93e14be 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -15,12 +15,12 @@ dataset:
workers: 8
cache: True
- phones_range: [4, 256]
- duration_range: [1.0, 12.0]
+ phones_range: [4, 512]
+ duration_range: [1.0, 24.0]
random_utterance: 1.0
- max_prompts: 3
- prompt_duration: 3.0
+ max_prompts: 6
+ prompt_duration: 6.0
@@ -69,7 +69,8 @@ evaluation:
size: 32
steps: 300
- temperature: 1.0
+ ar_temperature: 1.0
+ nar_temperature: 0.2
iterations: 100_000
@@ -91,7 +92,13 @@ trainer:
weight_dtype: bfloat16
- zero_optimization_level: 2
- use_compression_training: True
+ backend: deepspeed
+ deepspeed:
+ zero_optimization_level: 0
+ use_compression_training: True
-use_vocos: False
+ use_vocos: True
+ enabled: false
\ No newline at end of file
diff --git a/vall_e/config.py b/vall_e/config.py
index 0eca960..a27fa8b 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -232,7 +232,121 @@ class Evaluation:
size: int = 64
steps: int = 500
- temperature: float = 1.0
+ ar_temperature: float = 1.0
+ nar_temperature: float = 0.2
+class DeepSpeed:
+    """DeepSpeed-specific trainer options and the builder for the DeepSpeed config dict.
+
+    `get_ds_cfg` reads the remaining settings (batch size, optimizer, scheduler,
+    weight dtype, iteration count) from the module-level `cfg` object.
+    """
+    # ZeRO stage; when 0 the "zero_optimization" section is omitted entirely
+    zero_optimization_level: int = 0
+    # when True, the 8-bit weight/activation "compression_training" section is emitted
+    use_compression_training: bool = False
+
+    def get_ds_cfg(self, model):
+        """Build the configuration dict handed to DeepSpeed for `model`.
+
+        Args:
+            model: object exposing `named_parameters()`; every parameter name is
+                listed as a target module of the quantization groups.
+
+        Returns:
+            A dict of DeepSpeed settings. Sections whose value is falsy (None,
+            0, "", empty containers) are dropped, and the keys of
+            `./config/ds_config.json`, when that file exists, override the
+            generated ones.
+        """
+        weights = [name for name, _ in model.named_parameters()]
+        bits = 8
+
+        # copy so that adding total_num_steps below never mutates the user's hyperparameters
+        source_params = cfg.hyperparameters.scheduler_params
+        scheduler_params = {k: source_params[k] for k in source_params}
+        if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
+            scheduler_params['total_num_steps'] = cfg.trainer.iterations
+
+        ds_cfg = {
+            "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
+            "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
+            "optimizer": {
+                "type": cfg.hyperparameters.optimizer,
+                "params": {
+                    "lr": cfg.hyperparameters.learning_rate,
+                }
+            },
+            "scheduler": {
+                "type": cfg.hyperparameters.scheduler_type,
+                "params": scheduler_params,
+            } if cfg.hyperparameters.scheduler_type != "" else None,
+            "gradient_clipping": cfg.hyperparameters.gradient_clipping,
+            "fp16": {
+                "enabled": True,
+                "auto_cast": True,
+            } if cfg.trainer.weight_dtype.lower() == "float16" else None,
+            "bf16": {
+                "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
+            },
+            "compression_training": {
+                "weight_quantization": {
+                    "shared_parameters":{
+                        "enabled": True,
+                        "quantizer_kernel": True,
+                        "schedule_offset": 0,
+                        "quantize_groups": 64,
+                        "quantize_verbose": True,
+                        "quantization_type": "symmetric",
+                        "rounding": "nearest",
+                        "quantize_weight_in_forward": True,
+                        "fp16_mixed_quantize":{
+                            "enabled": False,
+                            "quantize_change_ratio": 1
+                        }
+                    },
+                    "different_groups": {
+                        "wq1": {
+                            "params": {
+                                "start_bits": bits,
+                                "target_bits": bits,
+                                "quantization_period": 0
+                            },
+                            "modules": weights
+                        }
+                    }
+                },
+                "activation_quantization": {
+                    "shared_parameters":{
+                        "enabled": True,
+                        "quantization_type": "symmetric",
+                        "range_calibration": "dynamic",
+                        "schedule_offset": 0
+                    },
+                    "different_groups": {
+                        "aq1": {
+                            "params": {
+                                "bits": bits
+                            },
+                            "modules": weights
+                        }
+                    }
+                }
+            } if self.use_compression_training else None,
+            "zero_optimization": {
+                "stage": self.zero_optimization_level,
+                "contiguous_gradients": True,
+                "overlap_comm": True,
+                "reduce_scatter": True,
+                "reduce_bucket_size": 5e8,
+                "allgather_bucket_size": 5e8,
+                "sub_group_size": 5e8,
+                "round_robin_gradients": True,
+                "offload_optimizer": {
+                    "device": "cpu",
+                    "pin_memory": True
+                },
+                "offload_param": {
+                    "device": "cpu",
+                    "pin_memory": True
+                }
+            } if self.zero_optimization_level > 0 else None,
+            "comms_logger": {
+                "enabled": False
+            }
+        }
+
+        # drop disabled sections (and any other falsy value) rather than passing them on
+        ds_cfg = {k: v for k, v in ds_cfg.items() if v}
+
+        # an optional on-disk override wins over everything generated above
+        override_path = "./config/ds_config.json"
+        if os.path.exists(override_path):
+            # context manager so the file handle is closed deterministically
+            with open(override_path, "r", encoding="utf-8") as f:
+                ds_cfg.update(json.load(f))
+
+        return ds_cfg
class Trainer:
@@ -256,8 +370,9 @@ class Trainer:
weight_dtype: str = "float16"
- zero_optimization_level: int = 0
- use_compression_training: bool = False
+ backend: str = "deepspeed"
+ deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
@@ -284,7 +399,6 @@ class Config(_Config):
inference: Inference = field(default_factory=lambda: Inference)
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
def sample_rate(self):
return 24_000
@@ -293,142 +407,6 @@ class Config(_Config):
def get_spkr(self):
return eval(self.dataset.speaker_name_getter)
- @property
- def scheduler(self):
- cfg = {
- "type": self.hyperparameters.scheduler_type,
- "params": {},
- }
- for k in self.hyperparameters.scheduler_params:
- cfg['params'][k] = self.hyperparameters.scheduler_params[k]
- if self.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in cfg['params']:
- cfg['params']['total_num_steps'] = self.trainer.iterations
- return cfg
- @property
- def fp16_cfg(self):
- if self.trainer.weight_dtype.lower() != "float16":
- return None
- return {
- "enabled": True,
- "auto_cast": True,
- }
- @property
- def bf16_cfg(self):
- return {
- "enabled": self.trainer.weight_dtype.lower() == "bfloat16"
- }
- def get_compression_cfg(self, model):
- if not self.trainer.use_compression_training:
- return None
- weights = [ name[0] for name in model.named_parameters() ]
- bits = 8
- return {
- "weight_quantization": {
- "shared_parameters":{
- "enabled": True,
- "quantizer_kernel": True,
- "schedule_offset": 0,
- "quantize_groups": 64,
- "quantize_verbose": True,
- "quantization_type": "symmetric",
- "rounding": "nearest",
- "quantize_weight_in_forward": True,
- "fp16_mixed_quantize":{
- "enabled": False,
- "quantize_change_ratio": 1
- }
- },
- "different_groups": {
- "wq1": {
- "params": {
- "start_bits": bits,
- "target_bits": bits,
- "quantization_period": 0
- },
- "modules": weights
- }
- }
- },
- "activation_quantization": {
- "shared_parameters":{
- "enabled": True,
- "quantization_type": "symmetric",
- "range_calibration": "dynamic",
- "schedule_offset": 0
- },
- "different_groups": {
- "aq1": {
- "params": {
- "bits": bits
- },
- "modules": weights
- }
- }
- }
- }
- @property
- def zero_cfg(self):
- if self.trainer.zero_optimization_level == 0:
- return None
- return {
- "stage": self.trainer.zero_optimization_level,
- "contiguous_gradients": True,
- "overlap_comm": True,
- "reduce_scatter": True,
- "reduce_bucket_size": 5e8,
- "allgather_bucket_size": 5e8,
- "sub_group_size": 5e8,
- "round_robin_gradients": True,
- "offload_optimizer": {
- "device": "cpu",
- "pin_memory": True
- },
- "offload_param": {
- "device": "cpu",
- "pin_memory": True
- }
- }
- def get_ds_cfg(self, model):
- cfg = {
- "train_micro_batch_size_per_gpu": self.hyperparameters.batch_size,
- "gradient_accumulation_steps": self.hyperparameters.gradient_accumulation_steps,
- "optimizer": {
- "type": self.hyperparameters.optimizer,
- "params": {
- "lr": self.hyperparameters.learning_rate,
- }
- },
- "scheduler": self.hyperparameters.scheduler if self.hyperparameters.scheduler_type != "" else None,
- "gradient_clipping": self.hyperparameters.gradient_clipping,
- "fp16": self.fp16_cfg,
- "bf16": self.bf16_cfg,
- "compression_training": self.get_compression_cfg(model),
- "zero_optimization": self.zero_cfg,
- "comms_logger": {
- "enabled": False
- }
- }
- null_keys = [ k for k in cfg if not cfg[k] ]
- for k in null_keys:
- del cfg[k]
- if os.path.exists("./config/ds_config.json"):
- ds_cfg = json.load(open("./config/ds_config.json", "r", encoding="utf-8"))
- cfg.update(ds_cfg)
- return cfg
def cache_dir(self):
return ".cache" / self.relpath
@@ -455,6 +433,8 @@ cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
+cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
# cached_property stopped working...
if cfg.dataset.use_hdf5:
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
new file mode 100644
index 0000000..3dcf775
--- /dev/null
+++ b/vall_e/engines/__init__.py
@@ -0,0 +1,9 @@
+from ..config import cfg
+# Pick the concrete Engine implementation from the configured trainer backend.
+# NOTE: a backend other than "deepspeed" or "local" leaves `Engine` unbound in this package.
+if cfg.trainer.backend == "deepspeed":
+ from .deepspeed import Engine
+elif cfg.trainer.backend == "local":
+ from .base import Engine
+# Engines and TrainFeeder always come from the base module, whichever backend is selected.
+from .base import Engines
+from .base import TrainFeeder
\ No newline at end of file
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
new file mode 100644
index 0000000..bec3656
--- /dev/null
+++ b/vall_e/engines/base.py
@@ -0,0 +1,374 @@
+from torch import Tensor
+from typing import Any, Protocol
+Stats = dict[str, float]
+class TrainFeeder(Protocol):
+ """Structural type of the callable that feeds one batch to one engine.
+
+ It is invoked with the keyword-only arguments `engine` and `batch` and
+ returns either None or a `(loss, stats)` tuple.
+ """
+ def __call__(
+ self, *, engine: "Engine", batch: Any
+ ) -> None | tuple[Tensor, Stats]:
+ ...
+def default_feed(engine, batch):
+    """Run `batch` through `engine` and collect its loss.
+
+    A dict batch is expanded as keyword arguments, a list batch as positional
+    arguments, and anything else is passed as a single argument.
+
+    Returns:
+        `(loss, stats)`: the sum of every value returned by
+        `engine.gather_attribute("loss")`, and a dict mapping each loss name
+        to its Python scalar.
+    """
+    if isinstance(batch, dict):
+        engine(**batch)
+    elif isinstance(batch, list):
+        engine(*batch)
+    else:
+        engine(batch)
+
+    per_name = engine.gather_attribute("loss")
+    total = torch.stack(list(per_name.values())).sum()
+    scalars = {name: value.item() for name, value in per_name.items()}
+    return total, scalars
+from ..config import cfg
+from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
+import logging
+import time
+import torch
+import torch.distributed
+import os
+from torch import Tensor
+from torch.distributed import all_reduce
+from typing import Any, Protocol
+from .base import TrainFeeder
+_logger = logging.getLogger(__name__)
+# A very naive engine implementation using barebones PyTorch
+class Engine():
+    """Minimal training engine: a module, an optional optimizer and LR scheduler, and step counters.
+
+    Keyword arguments:
+        model: the module to train (required).
+        optimizer: optional optimizer; required by `step`, `get_lr` and `set_lr`.
+        lr_scheduler: optional scheduler; its state is saved and restored with checkpoints.
+
+    NOTE(review): nothing in this class calls `lr_scheduler.step()` - confirm the
+    scheduler is advanced elsewhere.
+    """
+    def __init__(self, *args, **kwargs):
+        self.module = kwargs['model']
+        self.optimizer = kwargs.get('optimizer')
+        self.lr_scheduler = kwargs.get('lr_scheduler')
+
+        self.global_steps = 0
+        self.micro_steps = 0
+        self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
+
+        # parameters frozen by freeze(); must exist before freeze()/unfreeze() are called
+        self._frozen_params = set()
+
+    def freeze(self):
+        """Disable gradients on every trainable parameter, remembering which ones were changed."""
+        for p in self.module.parameters():
+            if p.requires_grad:
+                p.requires_grad_(False)
+                self._frozen_params.add(p)
+
+    def unfreeze(self):
+        """Re-enable gradients on exactly the parameters frozen by `freeze`."""
+        for p in self._frozen_params:
+            p.requires_grad_(True)
+        self._frozen_params.clear()
+
+    @property
+    def global_step(self):
+        """Number of optimizer steps taken (read-only view of `global_steps`)."""
+        return self.global_steps
+
+    @property
+    def micro_step(self):
+        """Number of `step()` calls made (read-only view of `micro_steps`)."""
+        return self.micro_steps
+
+    def train_batch_size(self):
+        """Return the configured batch size."""
+        return cfg.hyperparameters.batch_size
+
+    def gather_attribute(self, *args, **kwargs):
+        """Forward to the `gather_attribute` helper, applied to the wrapped module."""
+        return gather_attribute(self.module, *args, **kwargs)
+
+    def dispatch_attribute(self, *args, **kwargs):
+        """Forward to the `dispatch_attribute` helper, applied to the wrapped module."""
+        return dispatch_attribute(self.module, *args, **kwargs)
+
+    def save_checkpoint(self, save_dir, tag ):
+        """Write counters, module, optimizer and scheduler state to `save_dir / tag / "state.pth"`.
+
+        `save_dir` must support the `/` operator (e.g. a `pathlib.Path`).
+        """
+        save_path = save_dir / tag / "state.pth"
+        # the per-tag directory is not created anywhere else
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        torch.save({
+            "global_step": self.global_step,
+            "micro_step": self.micro_step,
+            "module": self.module.state_dict(),
+            "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
+            "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
+        }, save_path)
+
+    def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
+        """Restore the state written by `save_checkpoint` from `load_dir / tag / "state.pth"`.
+
+        Args:
+            load_dir: directory holding the per-tag checkpoint folders.
+            tag: checkpoint tag to load.
+            load_module_strict: passed to `load_state_dict` as `strict`.
+            load_optimizer_states: restore the optimizer when one is attached and was saved.
+            load_lr_scheduler_states: restore the scheduler when one is attached and was saved.
+        """
+        state = torch.load(load_dir / tag / "state.pth")
+
+        # global_step / micro_step are read-only properties; write the backing counters
+        self.global_steps = state['global_step']
+        self.micro_steps = state['micro_step']
+        self.module.load_state_dict(state['module'], strict=load_module_strict)
+
+        # save_checkpoint always writes these keys, storing None when the component
+        # was absent, so test the stored value rather than key membership
+        load_optimizer_states = load_optimizer_states and self.optimizer is not None and state.get('optimizer') is not None
+        load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and state.get('lr_scheduler') is not None
+
+        if load_optimizer_states:
+            self.optimizer.load_state_dict(state['optimizer'])
+
+        if load_lr_scheduler_states:
+            self.lr_scheduler.load_state_dict(state['lr_scheduler'])
+
+    def eval(self):
+        """Put the wrapped module in eval mode."""
+        return self.module.eval()
+
+    def train(self):
+        """Put the wrapped module in train mode."""
+        return self.module.train()
+
+    def to(self, *args, **kwargs):
+        """Move/cast the wrapped module; returns the module, not the engine."""
+        self.module = self.module.to(*args, **kwargs)
+        return self.module
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        """Call the wrapped module's `forward` directly."""
+        return self.module.forward(*args, **kwargs)
+
+    def backward(self, loss):
+        """Backpropagate `loss`, scaled down by the gradient accumulation factor."""
+        return (loss / self.gradient_accumulation_steps).backward()
+
+    def step(self):
+        """Count one micro-step and apply the optimizer once per accumulation window."""
+        with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
+            self.micro_steps += 1
+
+            # micro_steps is already incremented, so test it directly; adding 1
+            # again would fire the first optimizer step one micro-batch early
+            if self.micro_steps % max(1, self.gradient_accumulation_steps) == 0:
+                self.global_steps += 1
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+    def get_lr(self):
+        """Return the `lr` of every optimizer param group that defines one."""
+        return [
+            param_group['lr']
+            for param_group in self.optimizer.param_groups
+            if 'lr' in param_group
+        ]
+
+    def set_lr(self, lr):
+        """Overwrite the `lr` of every optimizer param group that defines one."""
+        for param_group in self.optimizer.param_groups:
+            if 'lr' in param_group:
+                param_group['lr'] = lr
+
+    def get_global_grad_norm(self):
+        """Gradient norm is not tracked by this engine; always 0.0."""
+        return 0.0
+
+    def traverse(self, *args, **kwargs):
+        """Run one full forward/backward/step cycle and return the gathered stats."""
+        self.forward(*args, **kwargs)
+        losses = self.gather_attribute("loss")
+        loss = torch.stack([*losses.values()]).sum()
+
+        stats = {}
+        stats |= {k: v.item() for k, v in losses.items()}
+        stats |= self.gather_attribute("scalar")
+
+        self.backward(loss)
+        self.step()
+
+        return stats
+# and now to ignore everything from the above
+class Engines(dict[str, Engine]):
+    """Mapping of name -> engine that checkpoints and steps every engine together."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.setup()
+
+    def setup(self):
+        """Reset the aggregate step counters."""
+        self._global_step = 0
+        self._micro_step = 0
+
+    @property
+    def global_step(self):
+        """Largest `global_step` seen across the engines."""
+        return self._global_step
+
+    @property
+    def micro_step(self):
+        """Largest `micro_step` seen across the engines."""
+        return self._micro_step
+
+    def gather_attribute(self, *args, **kwargs):
+        """Merge `gather_attribute` results of all engines; later engines win on key clashes."""
+        ret = {}
+        for engine in self.values():
+            ret |= engine.gather_attribute(*args, **kwargs)
+        return ret
+
+    def dispatch_attribute(self, *args, **kwargs):
+        """Call `dispatch_attribute` on every engine."""
+        for engine in self.values():
+            engine.dispatch_attribute(*args, **kwargs)
+
+    def save_checkpoint(self, tag=None):
+        """Save every engine under `cfg.ckpt_dir / <name>`.
+
+        `tag` defaults to `cfg.trainer.save_tag`; a tag starting with "it" or
+        "step" is replaced by the current global step.
+        """
+        if not tag:
+            tag = cfg.trainer.save_tag
+        tag = tag.lower()
+        if tag.startswith(("it", "step")):
+            tag = f'{self.global_step}'
+
+        cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
+        for name, engine in self.items():
+            engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)
+
+    def load_checkpoint(self, tag=None):
+        """Load every engine from `cfg.ckpt_dir / <name>`; `tag` defaults to `cfg.trainer.load_tag`."""
+        if not tag:
+            tag = cfg.trainer.load_tag
+
+        for name, engine in self.items():
+            load_dir = cfg.ckpt_dir / name
+            engine.load_checkpoint(
+                tag=tag,
+                load_dir=load_dir,
+                load_module_strict=cfg.trainer.strict_loading,
+                load_optimizer_states=cfg.trainer.load_states,
+                load_lr_scheduler_states=cfg.trainer.load_states,
+            )
+            if cfg.trainer.restart_step_count:
+                engine.global_steps = 0
+
+        # update the LR because it gets overwritten when loading from a checkpoint, but only when no scheduler is in use
+        if cfg.hyperparameters.scheduler_type == "":
+            self.set_lr(cfg.hyperparameters.learning_rate)
+
+        self._update_global_step()
+        self._update_micro_step()
+
+    def set_lr(self, lr):
+        """Set the learning rate on every engine."""
+        for engine in self.values():
+            engine.set_lr(lr)
+
+    def _update_global_step(self):
+        for engine in self.values():
+            self._global_step = max(self._global_step, engine.global_step)
+
+    def _update_micro_step(self):
+        for engine in self.values():
+            self._micro_step = max(self._micro_step, engine.micro_step)
+
+    def train_batch_size(self):
+        """Return the largest batch size reported by any engine (0 when empty)."""
+        batch_size = 0
+        for engine in self.values():
+            batch_size = max(batch_size, engine.train_batch_size())
+        # previously the computed value was never returned
+        return batch_size
+
+    def eval(self):
+        """Switch every engine to eval mode."""
+        for engine in self.values():
+            engine.eval()
+
+    def train(self):
+        """Switch every engine to train mode."""
+        for engine in self.values():
+            engine.train()
+
+    def traverse(self):
+        """Call `traverse()` on every engine and flatten the stats under each engine's name prefix.
+
+        NOTE(review): no arguments are forwarded to `engine.traverse()` - confirm
+        the engines can run a forward pass without inputs.
+        """
+        stats = {}
+        for name, engine in self.items():
+            stat = engine.traverse()
+            stats.update(flatten_dict({ name.split("-")[0]: stat }))
+        return stats
+
+    def step(self, batch, feeder: TrainFeeder = default_feed, device=None):
+        """Feed `batch` to every engine, then backpropagate and step each one.
+
+        Args:
+            batch: the batch, moved to `device` with `to_device`.
+            feeder: callable invoked as `feeder(engine=..., batch=...)`; an
+                engine is skipped when it returns None.
+            device: target device. When None it is resolved at call time to the
+                current CUDA device, or `cfg.device` if CUDA is unavailable.
+
+        Returns:
+            Flat stats dict: per-engine loss, lr, grad norm, timing and engine
+            stats, plus `batch_size`, `elapsed_time`, `wall_time` and
+            `global_step`.
+        """
+        # resolved per call: a default evaluated in the signature would touch
+        # CUDA at import time and freeze whichever device was current then
+        if device is None:
+            device = torch.cuda.current_device() if torch.cuda.is_available() else cfg.device
+
+        total_elapsed_time = 0
+        stats: Any = dict()
+
+        if cfg.trainer.gc_mode == 'step':
+            do_gc()
+
+        batch = to_device(batch, device)
+
+        for name, engine in self.items():
+            if cfg.trainer.gc_mode == 'substep':
+                do_gc()
+
+            start_time = time.time()
+
+            # aggressive mode parks the batch on the CPU during backward (below),
+            # so it has to be moved back before each engine's forward pass
+            if cfg.trainer.aggressive_optimizations:
+                batch = to_device(batch, device)
+
+            res = feeder( engine=engine, batch=batch )
+            if res is None:
+                continue
+
+            loss, engine_stats = res
+            engine_stats |= self.gather_attribute("scalar")
+
+            if cfg.trainer.aggressive_optimizations:
+                batch = to_device(batch, 'cpu')
+
+            engine.backward(loss)
+            engine.step()
+
+            elapsed_time = time.time() - start_time
+            total_elapsed_time += elapsed_time
+
+            stats.update(
+                flatten_dict(
+                    {
+                        name.split("-")[0]: dict(
+                            loss=loss.item(),
+                            lr=engine.get_lr()[0],
+                            grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
+                            elapsed_time=elapsed_time,
+                            engine_step=engine.global_step,
+                            **engine_stats,
+                        )
+                    }
+                ),
+            )
+
+        self._update_global_step()
+        self._update_micro_step()
+
+        stats["batch_size"] = self.train_batch_size() # len(batch["text"])
+        stats["elapsed_time"] = total_elapsed_time
+        stats["wall_time"] = time.time()
+        stats["global_step"] = self.global_step
+
+        return stats
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
new file mode 100644
index 0000000..850a77f
--- /dev/null
+++ b/vall_e/engines/deepspeed.py
@@ -0,0 +1,89 @@
+# https://github.com/enhuiz/pytorch-training-utilities
+# to-do: replace this
+# to-do: swap out deepspeed
+from ..config import cfg
+from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
+import logging
+import time
+import torch
+import torch.distributed
+from torch import Tensor
+from torch.distributed import all_reduce
+from typing import Any, Protocol
+from .base import TrainFeeder
+_logger = logging.getLogger(__name__)
+from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed
+from deepspeed.accelerator import get_accelerator
+initialized_dist = False
+if not initialized_dist:
+ initialized_dist = True
+ init_distributed()
+class Engine(DeepSpeedEngine):
+    """DeepSpeedEngine configured from `cfg.trainer.deepspeed`, with freeze and attribute helpers."""
+    def __init__(self, *args, **kwargs):
+        """Build the DeepSpeed config from the `model` keyword argument, then initialise the parent engine."""
+        kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
+        kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
+        super().__init__(None, *args, **kwargs)
+        self._frozen_params = set()
+
+    def freeze(self):
+        """Disable gradients on every trainable parameter, remembering which ones were changed."""
+        for p in self.module.parameters():
+            if p.requires_grad:
+                p.requires_grad_(False)
+                self._frozen_params.add(p)
+
+    def unfreeze(self):
+        """Re-enable gradients on exactly the parameters frozen by `freeze`."""
+        for p in self._frozen_params:
+            p.requires_grad_(True)
+        self._frozen_params.clear()
+
+    @property
+    def global_step(self):
+        """Alias of `global_steps`."""
+        return self.global_steps
+
+    @property
+    def micro_step(self):
+        """Alias of `micro_steps`."""
+        return self.micro_steps
+
+    def gather_attribute(self, *args, **kwargs):
+        """Forward to the `gather_attribute` helper, applied to the wrapped module."""
+        return gather_attribute(self.module, *args, **kwargs)
+
+    def dispatch_attribute(self, *args, **kwargs):
+        """Forward to the `dispatch_attribute` helper, applied to the wrapped module."""
+        return dispatch_attribute(self.module, *args, **kwargs)
+
+    def set_lr(self, lr):
+        """Best-effort learning-rate override.
+
+        Writes `lr` into each param group when the optimizer exposes
+        `param_groups`, otherwise calls `optimizer.set_lr(lr)`. Failures are
+        logged and swallowed so a failed override does not abort training.
+        """
+        try:
+            if hasattr(self.optimizer, 'param_groups'):
+                for param_group in self.optimizer.param_groups:
+                    param_group['lr'] = lr
+            else:
+                self.optimizer.set_lr(lr)
+        except Exception as e:
+            # deliberately broad: stays best-effort, but is reported through the module logger
+            _logger.warning("Failed to set learning rate: %s", e)
+
+    def traverse(self, *args, **kwargs):
+        """Run one full forward/backward/step cycle and return the gathered stats."""
+        self.forward(*args, **kwargs)
+        losses = self.gather_attribute("loss")
+        loss = torch.stack([*losses.values()]).sum()
+
+        stats = {}
+        stats |= {k: v.item() for k, v in losses.items()}
+        stats |= self.gather_attribute("scalar")
+
+        self.backward(loss)
+        self.step()
+
+        return stats
\ No newline at end of file
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 944daa0..8e7914a 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -72,8 +72,8 @@ class TTS():
prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
- resp_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
- resps_list = [r.unsqueeze(-1) for r in resp_list]
+ resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
+ resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
wav, sr = qnt.decode_to_file(resps_list[0], out_path)
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index dc33d91..93fe7a5 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -1,20 +1,22 @@
from .ar import AR
from .nar import NAR
-def get_model(model):
- if model.name == "ar":
+def get_model(cfg):
+ if cfg.name == "ar":
Model = AR
- elif model.name == "nar":
+ elif cfg.name == "nar":
Model = NAR
- raise f"invalid model name: {model.name}"
- name = model.name
+ raise f"invalid model name: {cfg.name}"
+ name = cfg.name
model = Model(
+ model._cfg = cfg
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index fabcd7b..4c5e293 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -50,112 +50,80 @@ class AR(Base):
text_list: list[Tensor],
proms_list: list[Tensor],
- resp_list: list[Tensor] | None = None,
+ resps_list: list[Tensor] | None = None,
max_steps: int = 1000,
sampling_temperature: float = 1.0,
naive: bool = True,
- if resp_list is not None:
+ if resps_list is not None:
+ resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
return super().forward(
- text_list,
- proms_list,
- self._unsqueeze_list(resp_list),
- resp_list,
+ text_list=text_list,
+ proms_list=proms_list,
+ resps_list=self._unsqueeze_list(resps_list),
+ targ_list=resps_list,
device = text_list[0].device
- resp_list: list[Tensor] = [
+ resps_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
stopped = torch.zeros(len(text_list), device=device).bool()
- if self.arch_type == "transformer":
- naive = True
chunk_size = 1 # don't really know what to do about this desu
state = None
start = 0
- # prefill
- if self.arch_type == "retnet/local":
- # pre-process
- state = [
- [
- torch.zeros(self.retnet.hidden_dim // self.retnet.heads, self.retnet.v_dim // self.retnet.heads, device=device).unsqueeze(0).repeat(len(text_list), 1, 1)
- for _ in range(self.retnet.heads)
- ] for _ in range(self.retnet.layers)
- ]
- resps_list = self._unsqueeze_list(resp_list)
- x_list = self._samplewise_merge_tensors(
- self.text_emb(text_list),
- self.proms_emb(proms_list),
- self.resps_emb(resps_list),
- sep=self.sep,
- )
- x, m = list_to_tensor(x_list)
- start = x.shape[1]
- for i in trange(start-1):
- _, state = self.retnet.forward_recurrent( x[:, i:i+1, :], state, i+1 )
for n in trange(max_steps // chunk_size):
# get next in sequence
r, state = super().forward(
- self._unsqueeze_list(resp_list),
+ self._unsqueeze_list(resps_list),
- state=state,
+ state=state if not naive else None,
# append outputted token
for i, ri in enumerate(r):
- resp_list[i] = torch.cat([resp_list[i], ri[None]])
+ resps_list[i] = torch.cat([resps_list[i], ri[None]])
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
- pruned = [self._prune(r) for r in resp_list]
+ pruned = [self._prune(r) for r in resps_list]
return pruned
def example_usage():
+ cfg.trainer.backend = "local"
from functools import partial
from einops import repeat
from ..emb.qnt import decode_to_file
- from ..utils import gather_attribute
+ from ..engines import Engine
+ from tqdm import tqdm
device = "cpu"
+ x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
phones = [f""] + [ " " if not p else p for p in split ] + [f""]
return torch.tensor([*map(symmap.get, phones)]).to()
- qnt = torch.load("data/qnt.pt")[0, 0].to(device)
- kwargs = {
- 'n_tokens': 1024,
- 'd_model': 1024,
- 'n_heads': 16,
- 'n_layers': 12,
- }
+ qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
- model = AR(**kwargs).to(device)
- x8 = partial(repeat, pattern="t -> t l", l=2)
text_list = [
#torch.tensor([1, 2, 3], device=device),
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
@@ -164,41 +132,41 @@ def example_usage():
x8(torch.tensor([1, 2, 3], device=device)),
- resp_list = [
+ resps_list = [
text_list = text_list[:1]
proms_list = proms_list[:1]
- resp_list = resp_list[:1]
+ resps_list = resps_list[:1]
- model.eval()
- out = model(text_list, proms_list, max_steps=75)[0]
- print("qnt:", qnt.shape, qnt)
- print("out:", out.shape, out)
- wav, sr = decode_to_file(out, "data/test/test.ar.init.wav", device=device)
+ kwargs = {
+ 'n_tokens': 1024,
+ 'd_model': 1024,
+ 'n_heads': 16,
+ 'n_layers': 12,
+ }
+ model = AR(**kwargs).to(device)
+ engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
+ def sample( name, steps=400 ):
+ engine.eval()
+ out = engine(text_list, proms_list, max_steps=steps)
+ for i, o in enumerate(out):
+ wav, sr = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
- optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+ def train():
+ engine.train()
+ t = trange(60)
+ for i in t:
+ stats = {"step": i}
+ stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
- model.train()
- for i in trange(60):
- optimizer.zero_grad()
- _ = model(text_list, proms_list, resp_list)
- losses = gather_attribute(model, "loss")
- loss = sum(losses.values())
- loss.backward()
- optimizer.step()
- if i % 20 == 0:
- print(f"iter={i}, {losses}.")
- model.eval()
- out = model(text_list, proms_list, max_steps=400)
- print("qnt:", qnt.shape, qnt)
- for i, o in enumerate(out):
- print("out:", i, o.shape, o)
- wav, sr = decode_to_file(o, f"data/test/test.ar.{i}.wav", device=device)
+ t.set_description(f"{stats}")
+ sample("init", 75)
+ train()
+ sample("final")
if __name__ == "__main__":
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 7eb3902..9eb0c36 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -176,17 +176,7 @@ class Base(nn.Module):
self.retnet = RetNetDecoder(
- elif self.arch_type == "retnet/local":
- self.retnet = RetNet(
- layers=n_layers,
- hidden_dim=d_model,
- ffn_size=d_model * 4,
- heads=n_heads,
- dropout=p_dropout,
- norm_type=self.norm_type,
- n_levels=self.n_resp_levels,
- double_v_dim=True
- )
self.classifier = nn.Linear(d_model, n_resp_tokens)
self.accuracy_metric = MulticlassAccuracy(
@@ -272,7 +262,7 @@ class Base(nn.Module):
text_list: [t] * b
proms_list: [t' l] * b, l quantization levels.
resps_list: [t'' l] * b, l quantization levels.
- targ_list: [t''] * b, one quantization level only, when given, loss will be computed
+ targ_list: [t''] * b, one quantization level only; when given, loss will be computed
quant_levels: specify which quant_levels to feed forward, used in NAR mode.
shift_targ_list: whether to shift target list when computing loss. True if AR.
return_all_resp: True if NAR.
@@ -298,24 +288,12 @@ class Base(nn.Module):
elif self.arch_type == "retnet":
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
state = self.retnet.get_incremental_state( state, 'prev_state' )
- elif self.arch_type == "retnet/local":
- # recurrent inferencing
- if self.causal and state is not None:
- last = x.shape[1]
- x, state = self.retnet.forward_recurrent(
- x[:, last-1:last, :], # nasty way to grab the last embedding to forward
- state,
- last
- )
- else:
- x = self.retnet( x, quant_levels )
x = self.classifier(x) * m
# Remove padding
h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]
# compute loss if the target is given
if targ_list is not None:
if any([l == 0 for l in map(len, targ_list)]):
@@ -337,20 +315,24 @@ class Base(nn.Module):
# the NAR doesn't need to compute the loss for it
if self.resp_loss_only:
text_prom_list[i][:] = self.ignore_index
# roll the text/prompt for loss computing
- # the AR benefits from this
+ # the AR benefits from this, for some reason I'll figure out later
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
text_prom_list[i][-1] = self.ignore_index
- # necessary to roll the target if recurrently/causally/autoregressively generating, or it won't be able to work
+ # for the AR, roll by one and mark the ending with a stop token
+ # this coerces the model into properly inferencing causally
+ # why we don't just append a stop token in the dataloader, who knows
if shift_targ_list:
targ_list = [*targ_list]
for i in range(len(targ_list)):
targ_list[i] = targ_list[i].roll(-1, dims=0)
targ_list[i][-1] = self.stop_token
- # generate the sequence
+ # create the new target sequence to compute the loss against
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
self.loss = dict(
@@ -367,6 +349,7 @@ class Base(nn.Module):
del prom_list
del text_prom_list
del y_list
# return the entire generated token string
if return_all:
@@ -387,124 +370,90 @@ class Base(nn.Module):
return ret, state
def example_usage():
+ from ..config import cfg
+ cfg.trainer.backend = "local"
from functools import partial
from einops import repeat
- from tqdm import trange
- from ..utils import gather_attribute
from ..emb.qnt import decode_to_file
+ from ..engines import Engine, Engines
+ from tqdm import tqdm, trange
from .ar import AR
from .nar import NAR
+ device = "cpu"
+ x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
phones = [f""] + [ " " if not p else p for p in split ] + [f""]
return torch.tensor([*map(symmap.get, phones)]).to()
- device = "cpu"
kwargs = {
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 12,
- model_ar = AR(**kwargs).to(device)
- model_nar = NAR(**kwargs).to(device)
+ models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
+ engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
train = True
- if train:
- qnt = torch.load("data/qnt.pt").to(device)
- text_list = [
- tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
- #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
- ]
- x8 = partial(repeat, pattern="t -> t l", l=2)
- proms_list = [
- qnt[0][:2,:].t().to(device),
- #x8(torch.tensor([1, 2, 3], device=device)),
- # x8(torch.tensor([2, 3], device=device)),
- ]
+ qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
+ text_list = [
+ tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
+ #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
+ ]
- resp_list_ar = [
- qnt[0,0].to(device),
- # qnt[0,0].to(device),
- ]
- resp_list_nar = [
- qnt[0][:2,:].t().to(device),
- # qnt[0][:2,:].t().to(device),
- ]
- model_ar.train()
- optimizer = torch.optim.AdamW(model_ar.parameters(), lr=1e-4)
- for i in trange(60):
- optimizer.zero_grad()
- _ = model_ar(text_list, proms_list, resp_list_ar)
- losses = gather_attribute(model_ar, "loss")
- loss = sum(losses.values())
- loss.backward()
- optimizer.step()
- if i % 20 == 0:
- print(f"iter={i}, {losses}.")
- model_nar.train()
- optimizer = torch.optim.AdamW(model_nar.parameters(), lr=1e-4)
- for i in trange(60):
- optimizer.zero_grad()
- _ = model_nar(text_list, proms_list, resps_list=resp_list_nar)
- losses = gather_attribute(model_nar, "loss")
- loss = sum(losses.values())
- loss.backward()
- optimizer.step()
- if i % 20 == 0:
- stats = {k: v.item() for k, v in losses.items()}
- stats["loss"] = loss.item()
- print(f"iter={i}, {stats}.")
- else:
- qnt = torch.load("data/test/test.qnt.pt")[0][:2,:].t().to(device)
- text_list = [
- #tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
- tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
- ]
- proms_list = [
- qnt.to(device),
- ]
- model_ar.load_state_dict(torch.load("data/test/ar.pth"))
- model_nar.load_state_dict(torch.load("data/test/nar.pth"))
- model_ar.eval()
- resp_list = model_ar(text_list, proms_list, max_steps=300, sampling_temperature=1.0)
- resps_list = [r.unsqueeze(-1) for r in resp_list]
+ proms_list = [
+ qnt.to(device),
+ ]
+ resps_list = [
+ qnt.to(device),
+ ]
- print("qnt:", qnt.shape, qnt)
- print("out:", resp_list[0].shape, resp_list[0])
- wav, sr = decode_to_file(resp_list[0], "data/test/test.ar.init.wav", device=device)
- print(wav, sr)
+ def sample( name, steps=400 ):
+ AR = None
+ NAR = None
- model_nar.eval()
- codes = model_nar(
- text_list,
- proms_list,
- resps_list=resps_list,
- sampling_temperature=1.0,
- )[0]
+ engines.eval()
+ for name, engine in engines.items():
+ if name[:2] == "ar":
+ AR = engine
+ elif name[:3] == "nar":
+ NAR = engine
+ resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
+ resps_list = [r.unsqueeze(-1) for r in resps_list]
+ codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
- print("qnt:", qnt.shape, qnt)
- print("codes:", codes.shape, codes)
+ decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
+ decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
+ if train:
+ sample("init", 15)
- wav, sr = decode_to_file(codes, "data/test/test.ar+nar.init.wav", device=device)
- print(wav, sr)
+ engines.train()
+ t = trange(60)
+ for i in t:
+ """
+ stats = {"step": i}
+ for name, engine in engines.items():
+ stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+ """
+ stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
+ t.set_description(f"{stats}")
+ else:
+ for name, engine in engines.items():
+ engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
+ sample("final")
if __name__ == "__main__":
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 2cba56e..57dc966 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -80,7 +80,6 @@ class NAR(Base):
- # Yes, just nothing as we are training
prev_list = []
prev_list = resps_list
@@ -93,7 +92,7 @@ class NAR(Base):
quant_levels = torch.full((len(text_list),), level, device=device)
- resp_list, _ = super().forward(
+ resps_list, _ = super().forward(
@@ -105,24 +104,22 @@ class NAR(Base):
prev_list = [
torch.cat([rs, r.unsqueeze(-1)], dim=-1)
- for rs, r in zip(prev_list, resp_list)
+ for rs, r in zip(prev_list, resps_list)
return prev_list
def example_usage():
+ cfg.trainer.backend = "local"
from functools import partial
- from pathlib import Path
from einops import repeat
from ..emb.qnt import decode_to_file
- from ..utils import gather_attribute
- from ..config import cfg
+ from ..engines import Engine
+ from tqdm import tqdm
device = "cpu"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
@@ -130,15 +127,8 @@ def example_usage():
phones = [f""] + [ " " if not p else p for p in split ] + [f""]
return torch.tensor([*map(symmap.get, phones)]).to()
- resps = torch.load("data/qnt.pt")[0][:2, :].to(device)
- kwargs = {
- 'n_tokens': 1024,
- 'd_model': 1024,
- 'n_heads': 16,
- 'n_layers': 12,
- }
- model = NAR(**kwargs).to(device)
+ # to-do: unmangle this and the resp shit
+ qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
text_list = [
#torch.tensor([1, 2, 3], device=device),
@@ -149,65 +139,36 @@ def example_usage():
x8(torch.tensor([2, 3], device=device)),
- resps_x1_list = [
- resps[:1].t().to(device),
+ resps_list = [
+ qnt.to(device),
- resps_x8_list = [
- resps.t().to(device),
- ]
+ kwargs = {
+ 'n_tokens': 1024,
+ 'd_model': 1024,
+ 'n_heads': 16,
+ 'n_layers': 12,
+ }
+ model = NAR(**kwargs).to(device)
+ engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
- model.eval()
- codes = model(
- text_list,
- proms_list,
- resps_list=resps_x1_list,
- sampling_temperature=0.2,
- )[0]
+ def sample( name ):
+ engine.eval()
+ codes = engine( text_list, proms_list, resps_list=[r[..., 0].unsqueeze(-1) for r in resps_list], sampling_temperature=0.2 )
+ decode_to_file( codes[0], f"data/nar.{name}.wav", device )
- decode_to_file(
- codes,
- Path("data/test/test.nar.init.wav"),
- device
- )
+ def train():
+ engine.train()
+ t = trange(60)
+ for i in t:
+ stats = {"step": i}
+ stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
- optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+ t.set_description(f"{stats}")
- model.train()
- for i in trange(50):
- optimizer.zero_grad()
- _ = model(text_list, proms_list, resps_list=resps_x8_list)
- losses = gather_attribute(model, "loss")
- loss = sum(losses.values())
- loss.backward()
- optimizer.step()
- if i % 20 == 0:
- stats = {k: v.item() for k, v in losses.items()}
- stats["loss"] = loss.item()
- print(f"iter={i}, {stats}.")
- model.eval()
- for i in trange(1, 2): # cfg.models.prom_levels):
- resps_list = [
- resps[:i].t().to(device),
- ]
- codes = model(
- text_list,
- proms_list,
- resps_list=resps_list,
- sampling_temperature=0.2,
- )[0]
- decode_to_file(
- codes,
- Path(f"data/test/test.nar.1-{i}.wav"),
- device
- )
+ sample("init")
+ train()
+ sample("final")
if __name__ == "__main__":
diff --git a/vall_e/train.py b/vall_e/train.py
index 9bf8686..748f22d 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -1,17 +1,13 @@
# todo: clean this mess up
-# todo: yank deepspeed dependent code out into its own thing
from .config import cfg
from .data import create_train_val_dataloader
from .emb import qnt
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
-from .utils import wrapper as ml
-from .models import get_models
import auraloss
-import deepspeed
import json
import logging
import random
@@ -21,10 +17,6 @@ import traceback
from collections import defaultdict
-from deepspeed import comm as dist
-from deepspeed import DeepSpeedConfig
-from deepspeed.accelerator import get_accelerator
from tqdm import tqdm
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
@@ -38,210 +30,120 @@ def left_crop(x, len):
return x[..., 0:len]
_logger = logging.getLogger(__name__)
-deepspeed._initialized_dist = False
-def load_engines():
- if not deepspeed._initialized_dist:
- deepspeed._initialized_dist = True
- deepspeed.init_distributed()
+def train_feeder(engine, batch):
+ engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
- models = get_models(cfg.models.get())
- engines = dict()
+ losses = engine.gather_attribute("loss")
- for name in models:
- model = models[name]
+ loss = torch.stack([*losses.values()]).sum()
- optimizer = None
- lr_scheduler = None
+ stats = {}
+ stats |= {k: v.item() for k, v in losses.items()}
- Adam = ml.Adam
- AdamW = ml.AdamW
+ return loss, stats
- if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
- optimizer = AdamW(
- model.parameters(),
- lr=cfg.hyperparameters.learning_rate,
- betas=(0.9, 0.96),
- eps=1e-07,
- weight_decay=0.01,
- )
+def run_eval(engines, eval_name, dl):
+ engines_stats = {
+ 'eval': eval_name
+ }
- if cfg.trainer.load_state_dict:
- load_dir = cfg.ckpt_dir / name / "fp32.pth"
- model.load_state_dict(torch.load(load_dir))
+ AR = None
+ NAR = None
- ds_cfg=cfg.get_ds_cfg(model=model)
- config_class=DeepSpeedConfig(ds_cfg)
- engines[name] = trainer.Engine(
- model=model,
- config=ds_cfg,
- config_class=config_class,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- )
+ names = []
+ for name, engine in engines.items():
+ names.append(name)
+ if name[:2] == "ar":
+ AR = engine
+ elif name[:3] == "nar":
+ NAR = engine
+ stats = defaultdict(list)
+ stats['loss'] = []
+ def process( name, batch, resps_list ):
+ for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
+ if len(hyp) == 0:
+ continue
+ filename = f'{speaker}_{path.parts[-1]}'
+ # to-do, refine the output dir to be sane-er
+ ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
+ hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
+ prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
+ hyp_path.parent.mkdir(parents=True, exist_ok=True)
+ ref_path.parent.mkdir(parents=True, exist_ok=True)
+ prom_path.parent.mkdir(parents=True, exist_ok=True)
+ ref_audio, sr = qnt.decode_to_file(ref, ref_path)
+ hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
+ prom_audio, sr = qnt.decode_to_file(prom, prom_path)
+ # pseudo loss calculation since we don't get the logits during eval
+ min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
+ ref_audio = ref_audio[..., 0:min_length]
+ hyp_audio = hyp_audio[..., 0:min_length]
+ try:
+ stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
+ except Exception as e:
+ print(str(e))
+ for batch in tqdm(dl):
+ batch: dict = to_device(batch, cfg.device)
+ # if we're training both models, provide output for both
+ if AR is not None and NAR is not None:
+ name = "+".join(names)
+ resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature)
+ resps_list = [ r.unsqueeze(-1) for r in resps_list ]
+ resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature)
+ process( name, batch, resps_list )
+ else:
+ for name in engines:
+ model = engines[name]
+ if name.startswith("ar"):
+ resps_list = model(
+ text_list=batch["text"],
+ proms_list=batch["proms"],
+ max_steps=cfg.evaluation.steps,
+ sampling_temperature=cfg.evaluation.ar_temperature,
+ )
+ resps_list = [r.unsqueeze(-1) for r in resps_list]
+ elif name.startswith("nar"):
+ resps_list = model(
+ text_list=batch["text"],
+ proms_list=batch["proms"],
+ resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
+ sampling_temperature=cfg.evaluation.nar_temperature,
+ )
+ else:
+ raise NotImplementedError(name)
+ process( name, batch, resps_list )
+ stats = {k: sum(v) / len(v) for k, v in stats.items()}
+ engines_stats.update(flatten_dict({ name: stats }))
+ iteration = engines.global_step
+ engines_stats['it'] = iteration
+ engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+ _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
- return trainer.load_engines(engines, cfg)
def main():
- #dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
- if not deepspeed._initialized_dist:
- deepspeed._initialized_dist = True
- deepspeed.init_distributed()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
- def train_feeder(engines, batch, name):
- stats = {}
- model = engines[name]
- if name.startswith("ar"):
- _ = model(
- text_list=batch["text"],
- proms_list=batch["proms"],
- resp_list=[r[..., 0] for r in batch["resps"]],
- )
- elif name.startswith("nar"):
- _ = model(
- text_list=batch["text"],
- proms_list=batch["proms"],
- resps_list=batch["resps"],
- )
- else:
- raise NotImplementedError(name)
- losses = model.gather_attribute("loss")
- loss = torch.stack([*losses.values()]).sum()
- stats = {}
- stats |= {k: v.item() for k, v in losses.items()}
- stats |= engines.gather_attribute("scalar")
- return loss, stats
- @torch.inference_mode()
- def run_eval(engines, eval_name, dl):
- engines_stats = {
- 'eval': eval_name
- }
- AR = None
- NAR = None
- names = []
- for name in engines:
- model = engines[name]
- names.append(name)
- if name[:2] == "ar":
- AR = model
- elif name[:3] == "nar":
- NAR = model
- stats = defaultdict(list)
- stats['loss'] = []
- for batch in tqdm(dl):
- batch: dict = to_device(batch, cfg.device)
- # if we're training both models, provide output for both
- if AR is not None and NAR is not None:
- name = "+".join(names)
- resp_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature)
- resps_list = [ r.unsqueeze(-1) for r in resp_list ]
- resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature)
- for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
- if len(hyp) == 0:
- continue
- filename = f'{speaker}_{path.parts[-1]}'
- # to-do, refine the output dir to be sane-er
- ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
- hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
- prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
- hyp_path.parent.mkdir(parents=True, exist_ok=True)
- ref_path.parent.mkdir(parents=True, exist_ok=True)
- prom_path.parent.mkdir(parents=True, exist_ok=True)
- ref_audio, sr = qnt.decode_to_file(ref, ref_path)
- hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
- prom_audio, sr = qnt.decode_to_file(prom, prom_path)
- min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
- ref_audio = ref_audio[..., 0:min_length]
- hyp_audio = hyp_audio[..., 0:min_length]
- stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
- else:
- for name in engines:
- model = engines[name]
- if name.startswith("ar"):
- resp_list = model(
- text_list=batch["text"],
- proms_list=batch["proms"],
- max_steps=cfg.evaluation.steps,
- sampling_temperature=cfg.evaluation.temperature,
- )
- resps_list = [r.unsqueeze(-1) for r in resp_list]
- elif name.startswith("nar"):
- resps_list = model(
- text_list=batch["text"],
- proms_list=batch["proms"],
- resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
- sampling_temperature=cfg.evaluation.temperature,
- )
- else:
- raise NotImplementedError(name)
- losses = model.gather_attribute("loss")
- batch_stats = {}
- batch_stats |= {k: v.item() for k, v in losses.items()}
- batch_stats |= engines.gather_attribute("scalar")
- for k, v in batch_stats.items():
- stats[k].append(v)
- for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
- if len(hyp) == 0:
- continue
- filename = f'{speaker}_{path.parts[-1]}'
- # to-do, refine the output dir to be sane-er
- ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
- hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
- prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
- hyp_path.parent.mkdir(parents=True, exist_ok=True)
- ref_path.parent.mkdir(parents=True, exist_ok=True)
- prom_path.parent.mkdir(parents=True, exist_ok=True)
- ref_audio, sr = qnt.decode_to_file(ref, ref_path)
- hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
- prom_audio, sr = qnt.decode_to_file(prom, prom_path)
- # pseudo loss calculation since we don't get the logits during eval
- min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
- ref_audio = ref_audio[..., 0:min_length]
- hyp_audio = hyp_audio[..., 0:min_length]
- stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
- stats = {k: sum(v) / len(v) for k, v in stats.items()}
- engines_stats.update(flatten_dict({ name: stats }))
- iteration = engines.global_step
- engines_stats['it'] = iteration
- engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
- _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def eval_fn(engines):
run_eval(engines, "subtrain", subtrain_dl)
@@ -256,12 +158,10 @@ def main():
- engines_loader=load_engines,
if __name__ == "__main__":
\ No newline at end of file
diff --git a/vall_e/utils/engines.py b/vall_e/utils/engines.py
deleted file mode 100755
index 6549b9a..0000000
--- a/vall_e/utils/engines.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# https://github.com/enhuiz/pytorch-training-utilities
-# to-do: replace this
-# to-do: swap out deepspeed
-from ..config import Config
-from .distributed import fix_unset_envs
-from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-import logging
-import time
-import torch
-import torch.distributed
-from deepspeed import DeepSpeedEngine
-from torch import Tensor
-from torch.distributed import all_reduce
-from typing import Any, Protocol
-Stats = dict[str, float]
-_logger = logging.getLogger(__name__)
-class Engine(DeepSpeedEngine):
- def __init__(self, *args, **kwargs):
- fix_unset_envs()
- super().__init__(None, *args, **kwargs)
- self._frozen_params = set()
- def freeze(self):
- for p in self.module.parameters():
- if p.requires_grad:
- p.requires_grad_(False)
- self._frozen_params.add(p)
- def unfreeze(self):
- for p in self._frozen_params:
- p.requires_grad_(True)
- self._frozen_params.clear()
- @property
- def global_step(self):
- return self.global_steps
- def gather_attribute(self, *args, **kwargs):
- return gather_attribute(self.module, *args, **kwargs)
- def dispatch_attribute(self, *args, **kwargs):
- return dispatch_attribute(self.module, *args, **kwargs)
-class TrainFeeder(Protocol):
- def __call__(
- self, *, engines: "Engines", batch: Any, name: str
- ) -> None | tuple[Tensor, Stats]:
- ...
-class Engines(dict[str, Engine]):
- def setup(self, cfg: Config):
- self._cfg = cfg
- self._global_step = 0
- @property
- def cfg(self) -> Config:
- return self._cfg
- @property
- def config(self):
- return self._cfg
- @property
- def global_step(self):
- return self._global_step
- def gather_attribute(self, *args, **kwargs):
- ret = {}
- for engine in self.values():
- ret |= engine.gather_attribute(*args, **kwargs)
- return ret
- def dispatch_attribute(self, *args, **kwargs):
- for engine in self.values():
- engine.dispatch_attribute(*args, **kwargs)
- def save_checkpoint(self, tag=None):
- if not tag:
- tag = self.cfg.trainer.save_tag
- tag = tag.lower()
- if tag[:2] == "it" or tag[:4] == "step":
- tag = self.global_step
- self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
- for name, engine in self.items():
- engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
- def load_checkpoint(self, tag=None):
- if not tag:
- tag = self.cfg.trainer.load_tag
- for name, engine in self.items():
- load_dir = self.cfg.ckpt_dir / name
- engine.load_checkpoint(
- tag=tag,
- load_dir=load_dir,
- load_module_strict=self.cfg.trainer.strict_loading,
- load_optimizer_states=self.cfg.trainer.load_states,
- load_lr_scheduler_states=self.cfg.trainer.load_states,
- load_module_only=False, # not self.cfg.trainer.load_states,
- )
- if self.cfg.trainer.restart_step_count:
- engine.global_steps = 0
- # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
- if self.cfg.hyperparameters.scheduler_type == "":
- self.set_lr(self.cfg.hyperparameters.learning_rate)
- self._update_global_step()
- def set_lr(self, lr):
- try:
- for engine in self.values():
- if hasattr(engine.optimizer, 'param_groups'):
- print(engine.optimizer.param_groups)
- for param_group in engine.optimizer.param_groups:
- param_group['lr'] = lr
- else:
- engine.optimizer.set_lr(lr)
- except Exception as e:
- print(str(e))
- def _update_global_step(self):
- for engine in self.values():
- self._global_step = max(self._global_step, engine.global_step)
- def eval(self):
- for engine in self.values():
- engine.eval()
- def train(self):
- for engine in self.values():
- engine.train()
- def step(self, feeder: TrainFeeder, batch):
- total_elapsed_time = 0
- stats: Any = dict()
- if self.cfg.trainer.gc_mode == 'step':
- do_gc()
- batch = to_device(batch, torch.cuda.current_device())
- for name, engine in self.items():
- torch.cuda.synchronize()
- if self.cfg.trainer.gc_mode == 'substep':
- do_gc()
- start_time = time.time()
- tries = 4
- n_ooms = torch.zeros([], device=self.cfg.device)
- if self.cfg.trainer.aggressive_optimizations:
- batch = to_device(batch, torch.cuda.current_device())
- # engine = engine.to(torch.cuda.current_device())
- while tries >= 0:
- try:
- res = feeder( engines=self, batch=batch, name=name )
- break
- except RuntimeError as e:
- print("Forward", str(e))
- if "out of memory" not in str(e):
- self.save_checkpoint()
- raise e
- # shrink batch size until it's happy
- for k in batch:
- batch[k] = batch[k][:-1]
- if tries <= 0:
- # trigger OOM
- n_ooms += 1
- else:
- # also do GC
- do_gc()
- continue
- all_reduce(n_ooms)
- if n_ooms.item() > 0:
- self.save_checkpoint()
- raise RuntimeError("Out of memory during forward pass!")
- if res is None:
- continue
- loss, engine_stats = res
- n_ooms = torch.zeros([], device=self.cfg.device)
- if self.cfg.trainer.aggressive_optimizations:
- batch = to_device(batch, 'cpu')
- try:
- engine.backward(loss)
- except RuntimeError as e:
- print("Backwards:", str(e))
- if "out of memory" not in str(e):
- self.save_checkpoint()
- raise e
- n_ooms += 1
- all_reduce(n_ooms)
- if n_ooms.item() > 0:
- self.save_checkpoint()
- raise RuntimeError("Out of memory during backwards pass!")
- engine.step()
- torch.cuda.synchronize()
- elapsed_time = time.time() - start_time
- total_elapsed_time += elapsed_time
- stats.update(
- flatten_dict(
- {
- name.split("-")[0]: dict(
- loss=loss.item(),
- lr=engine.get_lr()[0],
- grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
- elapsed_time=elapsed_time,
- engine_step=engine.global_step,
- **engine_stats,
- )
- }
- ),
- )
- del loss
- # engine = engine.to('cpu')
- self._update_global_step()
- stats["batch_size"] = len(batch["text"])
- stats["elapsed_time"] = total_elapsed_time
- stats["wall_time"] = time.time()
- stats["global_step"] = self.global_step
- return stats
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index de47763..c796025 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -17,239 +17,258 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Protocol
-from ..config import Config
+from ..config import cfg
from .distributed import (
- global_leader_only,
- global_rank,
- is_global_leader,
- is_local_leader,
- local_leader_only,
+ fix_unset_envs,
+ global_leader_only,
+ global_rank,
+ is_global_leader,
+ is_local_leader,
+ local_leader_only,
-from .engines import Engine, Engines, TrainFeeder
+from ..engines import Engine, Engines, TrainFeeder, default_feeder
+from ..models import get_models
from .utils import to_device, do_gc
+from ..utils import wrapper as ml
_logger = logging.getLogger(__name__)
_engines: Engines
_command: str
def get_global_step():
- try:
- return _engines.global_step
- except:
- return None
+ try:
+ return _engines.global_step
+ except:
+ return None
def get_micro_step():
- try:
- return _engines.micro_step
- except:
- return None
-def get_cfg():
- try:
- return _engines.cfg
- except:
- raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
+ try:
+ return _engines.micro_step
+ except:
+ return None
def get_cmd():
- try:
- return _command
- except:
- raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
+ try:
+ return _command
+ except:
+ raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
get_iteration = get_global_step
+def load_engines():
+ models = get_models(cfg.models.get())
+ engines = dict()
-class EnginesLoader(Protocol):
- def __call__(self) -> Engines:
- ...
+ for name in models:
+ model = models[name]
+ optimizer = None
+ lr_scheduler = None
-def load_engines(engines: dict[str, Engine], config: Config):
- engines = Engines(engines)
- engines.setup(config)
- if not engines.cfg.trainer.load_state_dict:
- engines.load_checkpoint()
- return engines
+ if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
+ optimizer = ml.AdamW(
+ model.parameters(),
+ lr=cfg.hyperparameters.learning_rate,
+ betas=(0.9, 0.96),
+ eps=1e-07,
+ weight_decay=0.01,
+ )
+ if cfg.trainer.load_state_dict:
+ load_path = cfg.ckpt_dir / name / "fp32.pth"
+ model.load_state_dict(torch.load(load_path))
+ engines[name] = Engine(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ )
+ engines = Engines(engines)
+ engines.setup()
+ if not cfg.trainer.load_state_dict:
+ engines.load_checkpoint()
+ return engines
class EvalFn(Protocol):
- def __call__(self, *, engines: Engines):
- ...
+ def __call__(self, *, engines: Engines):
+ ...
class Logger(Protocol):
- def __call__(self, *, data: dict):
- ...
+ def __call__(self, *, data: dict):
+ ...
def _get_stdin_selector():
- selector = selectors.DefaultSelector()
- selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
- return selector
+ selector = selectors.DefaultSelector()
+ selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
+ return selector
def _non_blocking_input():
- global _command
- l = [""]
- if is_global_leader():
- s = ""
- selector = _get_stdin_selector()
- events = selector.select(timeout=0)
- for key, _ in events:
- s: str = key.fileobj.readline().strip()
- _logger.info(f'Get stdin "{s}".')
- l[0] = s
- broadcast_object_list(l, src=0)
- _command = l[0]
- return _command
+ global _command
+ l = [""]
+ if is_global_leader():
+ s = ""
+ selector = _get_stdin_selector()
+ events = selector.select(timeout=0)
+ for key, _ in events:
+ s: str = key.fileobj.readline().strip()
+ _logger.info(f'Get stdin "{s}".')
+ l[0] = s
+ broadcast_object_list(l, src=0)
+ _command = l[0]
+ return _command
def _make_infinite_epochs(dl):
- while True:
- _logger.info("New epoch starts.")
- yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
+ while True:
+ _logger.info("New epoch starts.")
+ yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
def logger(data):
- return _logger.info(json.dumps(data, default=str))
+ return _logger.info(json.dumps(data, default=str))
def seed(seed):
- # Set up random seeds, after fork()
- random.seed(seed + global_rank())
- np.random.seed(seed + global_rank())
- torch.manual_seed(seed + global_rank())
+ # Set up random seeds, after fork()
+ random.seed(seed + global_rank())
+ np.random.seed(seed + global_rank())
+ torch.manual_seed(seed + global_rank())
def train(
- engines_loader: EnginesLoader,
- train_dl: DataLoader,
- train_feeder: TrainFeeder,
- eval_fn: EvalFn,
- logger: Logger = logger,
+ train_dl: DataLoader,  # train(): command-driven training loop; batches are drawn endlessly from this loader
+ train_feeder: TrainFeeder = default_feeder,  # forwarded to engines.step() for every batch
+ eval_fn: EvalFn = lambda *args, **kwargs: ...,  # no-op default; must accept the engines= keyword it is called with
+ logger: Logger = logger,  # NOTE(review): not referenced in the visible body, which logs through _logger
- engines = engines_loader()
- cfg = engines.cfg
+ fix_unset_envs()
- """
- if is_local_leader():
- cfg.dump()
- _logger.info(cfg)
- """
+ engines = load_engines()  # engines are now built here instead of being passed in through a loader
- # Setup global engines
- global _engines
- _engines = engines
+ """
+ if is_local_leader():
+ cfg.dump()
+ _logger.info(cfg)
+ """
- events = []
+ # Setup global engines
+ global _engines
+ _engines = engines
- eval_fn = global_leader_only(eval_fn)
+ events = []  # scheduled (command, step) pairs registered through "<command>@<step>"
- # Pre-loop command
- command = _non_blocking_input()
- if command in ["eval", "eval_quit"]:
- engines.eval()
- eval_fn(engines=engines)
- engines.train()
- if command in ["quit", "eval_quit"]:
- return
+ eval_fn = global_leader_only(eval_fn)
- last_save_step = engines.global_step
- last_eval_step = 0
+ # Pre-loop command
+ command = _non_blocking_input()
+ if command in ["eval", "eval_quit"]:  # evaluate once before any training step
+ engines.eval()
+ eval_fn(engines=engines)
+ engines.train()
+ if command in ["quit", "eval_quit"]:  # leave without training
+ return
- # Training loop
- for batch in _make_infinite_epochs(train_dl):
- if engines.global_step >= cfg.trainer.iterations:
- break
+ last_save_step = engines.global_step  # the starting step counts as already saved
+ last_eval_step = 0
- #batch = to_device(batch, torch.cuda.current_device())
- stats = engines.step(feeder=train_feeder, batch=batch)
+ # Training loop
+ for batch in _make_infinite_epochs(train_dl):
+ if engines.global_step >= cfg.trainer.iterations:  # stop once the configured iteration count is reached
+ break
- iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
- stats['it'] = iteration
- stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+ #batch = to_device(batch, torch.cuda.current_device())
+ stats = engines.step(batch=batch, feeder=train_feeder)
- stats['batch'] = {
- 'size': stats['batch_size'],
- 'id': batch['spkr_id'],
- 'index': [ index for index in batch['index'] ],
- 'text_len': [ text.shape[0] for text in batch['text'] ],
- 'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
- 'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
- }
+ iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
+ stats['it'] = iteration
+ stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)  # fractional epoch estimate - assumes len(train_dl) is batches per epoch; TODO confirm
- del stats['batch_size']
- del stats['wall_time']
- del stats['global_step']
+ stats['batch'] = {
+ 'size': stats['batch_size'],
+ 'id': batch['spkr_id'],
+ 'index': [ index for index in batch['index'] ],
+ 'text_len': [ text.shape[0] for text in batch['text'] ],
+ 'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
+ 'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
+ }
- elapsed_time = stats.get("elapsed_time", 0)
- _logger.info(f"Training Metrics: {json.dumps(stats)}.")
+ del stats['batch_size']  # already copied into stats['batch']['size']
+ del stats['wall_time']
+ del stats['global_step']  # already copied into stats['it']
- command = _non_blocking_input()
+ elapsed_time = stats.get("elapsed_time", 0)  # multiplied by the remaining iterations in the "time" command
+ _logger.info(f"Training Metrics: {json.dumps(stats)}.")
- if "@" in command:
- what, when = command.split("@")
- try:
- events.append((what, int(when)))
- _logger.info(f"Event {command} registered.")
- except Exception as e:
- _logger.error(e)
- command = ""
+ command = _non_blocking_input()  # poll for a console command after every step
- # Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
- events = [e for e in events if e[1] >= engines.global_step]
- commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
+ if "@" in command:  # "<what>@<step>" registers <what> to run at a given step
+ what, when = command.split("@")
+ try:
+ events.append((what, int(when)))
+ _logger.info(f"Event {command} registered.")
+ except Exception as e:
+ _logger.error(e)
+ command = ""
- for command in commands:
- if command in ["event show", "event"]:
- msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
- _logger.info(msg)
+ # Keep events whose trigger point is >= the current step; commands are the current command plus the events whose trigger point == the current step
+ events = [e for e in events if e[1] >= engines.global_step]
+ commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
- if command == "event clear":
- events.clear()
+ for command in commands:
+ if command in ["event show", "event"]:  # list the registered events as what@step
+ msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
+ _logger.info(msg)
- if "time" in command:
- target_iter = cfg.trainer.iterations
- if " to " in command:
- try:
- target_iter = int(command.split(" to ")[-1])
- except Exception as e:
- _logger.error(e)
- remaining_iters = target_iter - engines.global_step + 1
- remaining_time = int(remaining_iters * elapsed_time)
- _logger.info(humanize.precisedelta(remaining_time))
+ if command == "event clear":
+ events.clear()
- if "lr" in command:
- rate = float(command.split(" ")[-1])
- engines.set_lr(rate)
- print("Updating LR to:", rate)
+ if "time" in command:  # remaining-time estimate; "... to <iteration>" overrides the target
+ target_iter = cfg.trainer.iterations
+ if " to " in command:
+ try:
+ target_iter = int(command.split(" to ")[-1])
+ except Exception as e:
+ _logger.error(e)
+ remaining_iters = target_iter - engines.global_step + 1
+ remaining_time = int(remaining_iters * elapsed_time)
+ _logger.info(humanize.precisedelta(remaining_time))
- save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
+ if "lr" in command:  # NOTE(review): float() below is not guarded, unlike the int() parse in the "time" branch
+ rate = float(command.split(" ")[-1])
+ engines.set_lr(rate)
+ print("Updating LR to:", rate)
- saving_commands = ["save"]
+ save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency  # falls back to the evaluation frequency when save_frequency is falsy
- if cfg.trainer.save_on_quit:
- saving_commands.append("quit")
+ saving_commands = ["save"]
- if engines.global_step != last_save_step:
- if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
- engines.save_checkpoint()
- last_save_step = engines.global_step
+ if cfg.trainer.save_on_quit:
+ saving_commands.append("quit")  # "quit" also saves when save_on_quit is set
- if engines.global_step != last_eval_step:
- if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
- do_gc()
+ if engines.global_step != last_save_step:  # at most one checkpoint per step
+ if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
+ engines.save_checkpoint()
+ last_save_step = engines.global_step
- engines.eval()
- eval_fn(engines=engines)
- engines.train()
- last_eval_step = engines.global_step
+ if engines.global_step != last_eval_step:  # at most one evaluation per step
+ if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
+ do_gc()
- if command in ["quit"]:
- return
\ No newline at end of file
+ engines.eval()  # engines.eval() before the callback, engines.train() after it
+ eval_fn(engines=engines)
+ engines.train()
+ last_eval_step = engines.global_step
+ if command in ["quit"]:
+ return
\ No newline at end of file