big cleanup
This commit is contained in:
parent
2e03e5ac93
commit
c85101403f
|
@ -15,12 +15,12 @@ dataset:
|
|||
workers: 8
|
||||
cache: True
|
||||
|
||||
phones_range: [4, 256]
|
||||
duration_range: [1.0, 12.0]
|
||||
phones_range: [4, 512]
|
||||
duration_range: [1.0, 24.0]
|
||||
|
||||
random_utterance: 1.0
|
||||
max_prompts: 3
|
||||
prompt_duration: 3.0
|
||||
max_prompts: 6
|
||||
prompt_duration: 6.0
|
||||
|
||||
models:
|
||||
_models:
|
||||
|
@ -69,7 +69,8 @@ evaluation:
|
|||
size: 32
|
||||
|
||||
steps: 300
|
||||
temperature: 1.0
|
||||
ar_temperature: 1.0
|
||||
nar_temperature: 0.2
|
||||
|
||||
trainer:
|
||||
iterations: 100_000
|
||||
|
@ -91,7 +92,13 @@ trainer:
|
|||
|
||||
weight_dtype: bfloat16
|
||||
|
||||
zero_optimization_level: 2
|
||||
use_compression_training: True
|
||||
backend: deepspeed
|
||||
deepspeed:
|
||||
zero_optimization_level: 0
|
||||
use_compression_training: True
|
||||
|
||||
use_vocos: False
|
||||
inference:
|
||||
use_vocos: True
|
||||
|
||||
bitsandbytes:
|
||||
enabled: false
|
260
vall_e/config.py
260
vall_e/config.py
|
@ -232,7 +232,121 @@ class Evaluation:
|
|||
size: int = 64
|
||||
|
||||
steps: int = 500
|
||||
temperature: float = 1.0
|
||||
ar_temperature: float = 1.0
|
||||
nar_temperature: float = 0.2
|
||||
|
||||
@dataclass()
|
||||
class DeepSpeed:
|
||||
zero_optimization_level: int = 0
|
||||
use_compression_training: bool = False
|
||||
|
||||
def get_ds_cfg(self, model):
|
||||
weights = [ name[0] for name in model.named_parameters() ]
|
||||
bits = 8
|
||||
|
||||
scheduler_params = {}
|
||||
for k in cfg.hyperparameters.scheduler_params:
|
||||
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
|
||||
|
||||
if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
|
||||
scheduler_params['total_num_steps'] = cfg.trainer.iterations
|
||||
|
||||
ds_cfg = {
|
||||
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
|
||||
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
|
||||
"optimizer": {
|
||||
"type": cfg.hyperparameters.optimizer,
|
||||
"params": {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": cfg.hyperparameters.scheduler_type,
|
||||
"params": scheduler_params,
|
||||
} if cfg.hyperparameters.scheduler_type != "" else None,
|
||||
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
|
||||
"fp16": {
|
||||
"enabled": True,
|
||||
"auto_cast": True,
|
||||
} if cfg.trainer.weight_dtype.lower() == "float16" else None,
|
||||
"bf16": {
|
||||
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
|
||||
},
|
||||
"compression_training": {
|
||||
"weight_quantization": {
|
||||
"shared_parameters":{
|
||||
"enabled": True,
|
||||
"quantizer_kernel": True,
|
||||
"schedule_offset": 0,
|
||||
"quantize_groups": 64,
|
||||
"quantize_verbose": True,
|
||||
"quantization_type": "symmetric",
|
||||
"rounding": "nearest",
|
||||
"quantize_weight_in_forward": True,
|
||||
"fp16_mixed_quantize":{
|
||||
"enabled": False,
|
||||
"quantize_change_ratio": 1
|
||||
}
|
||||
},
|
||||
"different_groups": {
|
||||
"wq1": {
|
||||
"params": {
|
||||
"start_bits": bits,
|
||||
"target_bits": bits,
|
||||
"quantization_period": 0
|
||||
},
|
||||
"modules": weights
|
||||
}
|
||||
}
|
||||
},
|
||||
"activation_quantization": {
|
||||
"shared_parameters":{
|
||||
"enabled": True,
|
||||
"quantization_type": "symmetric",
|
||||
"range_calibration": "dynamic",
|
||||
"schedule_offset": 0
|
||||
},
|
||||
"different_groups": {
|
||||
"aq1": {
|
||||
"params": {
|
||||
"bits": bits
|
||||
},
|
||||
"modules": weights
|
||||
}
|
||||
}
|
||||
}
|
||||
} if self.use_compression_training else None,
|
||||
"zero_optimization": {
|
||||
"stage": self.zero_optimization_level,
|
||||
"contiguous_gradients": True,
|
||||
"overlap_comm": True,
|
||||
"reduce_scatter": True,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"sub_group_size": 5e8,
|
||||
"round_robin_gradients": True,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": True
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": True
|
||||
}
|
||||
} if self.zero_optimization_level > 0 else None,
|
||||
"comms_logger": {
|
||||
"enabled": False
|
||||
}
|
||||
}
|
||||
|
||||
null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
|
||||
for k in null_keys:
|
||||
del ds_cfg[k]
|
||||
|
||||
if os.path.exists("./config/ds_config.json"):
|
||||
ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
|
||||
|
||||
return ds_cfg
|
||||
|
||||
@dataclass()
|
||||
class Trainer:
|
||||
|
@ -256,8 +370,9 @@ class Trainer:
|
|||
|
||||
weight_dtype: str = "float16"
|
||||
|
||||
zero_optimization_level: int = 0
|
||||
use_compression_training: bool = False
|
||||
backend: str = "deepspeed"
|
||||
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||
|
||||
|
||||
@dataclass()
|
||||
|
@ -284,7 +399,6 @@ class Config(_Config):
|
|||
inference: Inference = field(default_factory=lambda: Inference)
|
||||
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
|
||||
|
||||
|
||||
@property
|
||||
def sample_rate(self):
|
||||
return 24_000
|
||||
|
@ -293,142 +407,6 @@ class Config(_Config):
|
|||
def get_spkr(self):
|
||||
return eval(self.dataset.speaker_name_getter)
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
cfg = {
|
||||
"type": self.hyperparameters.scheduler_type,
|
||||
"params": {},
|
||||
}
|
||||
|
||||
for k in self.hyperparameters.scheduler_params:
|
||||
cfg['params'][k] = self.hyperparameters.scheduler_params[k]
|
||||
|
||||
if self.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in cfg['params']:
|
||||
cfg['params']['total_num_steps'] = self.trainer.iterations
|
||||
return cfg
|
||||
|
||||
@property
|
||||
def fp16_cfg(self):
|
||||
if self.trainer.weight_dtype.lower() != "float16":
|
||||
return None
|
||||
return {
|
||||
"enabled": True,
|
||||
"auto_cast": True,
|
||||
}
|
||||
|
||||
@property
|
||||
def bf16_cfg(self):
|
||||
return {
|
||||
"enabled": self.trainer.weight_dtype.lower() == "bfloat16"
|
||||
}
|
||||
|
||||
def get_compression_cfg(self, model):
|
||||
if not self.trainer.use_compression_training:
|
||||
return None
|
||||
|
||||
weights = [ name[0] for name in model.named_parameters() ]
|
||||
bits = 8
|
||||
|
||||
return {
|
||||
"weight_quantization": {
|
||||
"shared_parameters":{
|
||||
"enabled": True,
|
||||
"quantizer_kernel": True,
|
||||
"schedule_offset": 0,
|
||||
"quantize_groups": 64,
|
||||
"quantize_verbose": True,
|
||||
"quantization_type": "symmetric",
|
||||
"rounding": "nearest",
|
||||
"quantize_weight_in_forward": True,
|
||||
"fp16_mixed_quantize":{
|
||||
"enabled": False,
|
||||
"quantize_change_ratio": 1
|
||||
}
|
||||
},
|
||||
"different_groups": {
|
||||
"wq1": {
|
||||
"params": {
|
||||
"start_bits": bits,
|
||||
"target_bits": bits,
|
||||
"quantization_period": 0
|
||||
},
|
||||
"modules": weights
|
||||
}
|
||||
}
|
||||
},
|
||||
"activation_quantization": {
|
||||
"shared_parameters":{
|
||||
"enabled": True,
|
||||
"quantization_type": "symmetric",
|
||||
"range_calibration": "dynamic",
|
||||
"schedule_offset": 0
|
||||
},
|
||||
"different_groups": {
|
||||
"aq1": {
|
||||
"params": {
|
||||
"bits": bits
|
||||
},
|
||||
"modules": weights
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@property
|
||||
def zero_cfg(self):
|
||||
if self.trainer.zero_optimization_level == 0:
|
||||
return None
|
||||
|
||||
return {
|
||||
"stage": self.trainer.zero_optimization_level,
|
||||
"contiguous_gradients": True,
|
||||
"overlap_comm": True,
|
||||
"reduce_scatter": True,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"sub_group_size": 5e8,
|
||||
"round_robin_gradients": True,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": True
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": True
|
||||
}
|
||||
}
|
||||
|
||||
def get_ds_cfg(self, model):
|
||||
cfg = {
|
||||
"train_micro_batch_size_per_gpu": self.hyperparameters.batch_size,
|
||||
"gradient_accumulation_steps": self.hyperparameters.gradient_accumulation_steps,
|
||||
"optimizer": {
|
||||
"type": self.hyperparameters.optimizer,
|
||||
"params": {
|
||||
"lr": self.hyperparameters.learning_rate,
|
||||
}
|
||||
},
|
||||
"scheduler": self.hyperparameters.scheduler if self.hyperparameters.scheduler_type != "" else None,
|
||||
"gradient_clipping": self.hyperparameters.gradient_clipping,
|
||||
"fp16": self.fp16_cfg,
|
||||
"bf16": self.bf16_cfg,
|
||||
"compression_training": self.get_compression_cfg(model),
|
||||
"zero_optimization": self.zero_cfg,
|
||||
"comms_logger": {
|
||||
"enabled": False
|
||||
}
|
||||
}
|
||||
|
||||
null_keys = [ k for k in cfg if not cfg[k] ]
|
||||
for k in null_keys:
|
||||
del cfg[k]
|
||||
|
||||
if os.path.exists("./config/ds_config.json"):
|
||||
ds_cfg = json.load(open("./config/ds_config.json", "r", encoding="utf-8"))
|
||||
cfg.update(ds_cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
@property
|
||||
def cache_dir(self):
|
||||
return ".cache" / self.relpath
|
||||
|
@ -455,6 +433,8 @@ cfg.trainer = Trainer(**cfg.trainer)
|
|||
cfg.inference = Inference(**cfg.inference)
|
||||
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
|
||||
|
||||
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
|
||||
|
||||
# cached_property stopped working...
|
||||
if cfg.dataset.use_hdf5:
|
||||
try:
|
||||
|
|
9
vall_e/engines/__init__.py
Normal file
9
vall_e/engines/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from ..config import cfg
|
||||
|
||||
if cfg.trainer.backend == "deepspeed":
|
||||
from .deepspeed import Engine
|
||||
elif cfg.trainer.backend == "local":
|
||||
from .base import Engine
|
||||
|
||||
from .base import Engines
|
||||
from .base import TrainFeeder
|
374
vall_e/engines/base.py
Normal file
374
vall_e/engines/base.py
Normal file
|
@ -0,0 +1,374 @@
|
|||
from torch import Tensor
|
||||
from typing import Any, Protocol
|
||||
|
||||
Stats = dict[str, float]
|
||||
|
||||
class TrainFeeder(Protocol):
|
||||
def __call__(
|
||||
self, *, engine: "Engine", batch: Any
|
||||
) -> None | tuple[Tensor, Stats]:
|
||||
...
|
||||
|
||||
def default_feed(engine, batch):
|
||||
if isinstance(batch, list):
|
||||
engine( *batch )
|
||||
elif isinstance(batch, dict):
|
||||
engine( **batch )
|
||||
else:
|
||||
engine( batch )
|
||||
|
||||
losses = engine.gather_attribute("loss")
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
|
||||
stats = {}
|
||||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
|
||||
return loss, stats
|
||||
|
||||
|
||||
from ..config import cfg
|
||||
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
||||
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed
|
||||
import os
|
||||
|
||||
from torch import Tensor
|
||||
from torch.distributed import all_reduce
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .base import TrainFeeder
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# A very naive engine implementation using barebones PyTorch
|
||||
class Engine():
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.module = kwargs['model']
|
||||
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
|
||||
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
|
||||
|
||||
self.global_steps = 0
|
||||
self.micro_steps = 0
|
||||
self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
|
||||
|
||||
def freeze(self):
|
||||
for p in self.module.parameters():
|
||||
if p.requires_grad:
|
||||
p.requires_grad_(False)
|
||||
self._frozen_params.add(p)
|
||||
|
||||
def unfreeze(self):
|
||||
for p in self._frozen_params:
|
||||
p.requires_grad_(True)
|
||||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self.global_steps
|
||||
|
||||
@property
|
||||
def micro_step(self):
|
||||
return self.micro_steps
|
||||
|
||||
def train_batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
||||
def dispatch_attribute(self, *args, **kwargs):
|
||||
return dispatch_attribute(self.module, *args, **kwargs)
|
||||
|
||||
def save_checkpoint(self, save_dir, tag ):
|
||||
torch.save({
|
||||
"global_step": self.global_step,
|
||||
"micro_step": self.micro_step,
|
||||
"module": self.module.state_dict(),
|
||||
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
|
||||
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
|
||||
}, save_dir / tag / "state.pth")
|
||||
|
||||
def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
|
||||
state = torch.load(load_dir / tag / "state.pth")
|
||||
self.global_step = state['global_step']
|
||||
self.micro_step = state['micro_step']
|
||||
self.module.load_state_dict(state['module'])
|
||||
|
||||
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
||||
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
|
||||
|
||||
if load_optimizer_states:
|
||||
self.optimizer.load_state_dict(state['optimizer'])
|
||||
|
||||
if load_lr_scheduler_states:
|
||||
self.lr_scheduler.load_state_dict(state['lr_scheduler'])
|
||||
|
||||
def eval(self):
|
||||
return self.module.eval()
|
||||
|
||||
def train(self):
|
||||
return self.module.train()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
self.module = self.module.to(*args, **kwargs)
|
||||
return self.module
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module.forward(*args, **kwargs)
|
||||
|
||||
def backward(self, loss):
|
||||
return (loss / self.gradient_accumulation_steps).backward()
|
||||
|
||||
def step(self):
|
||||
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
|
||||
self.micro_steps += 1
|
||||
|
||||
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
|
||||
self.global_steps += 1
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for param_group in self.optimizer.param_groups:
|
||||
if 'lr' in param_group:
|
||||
lrs.append(param_group['lr'])
|
||||
return lrs
|
||||
|
||||
def set_lr(self, lr):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
if 'lr' in param_group:
|
||||
param_group['lr'] = lr
|
||||
|
||||
def get_global_grad_norm(self):
|
||||
return 0.0
|
||||
|
||||
def traverse(self, *args, **kwargs):
|
||||
self.forward(*args, **kwargs)
|
||||
losses = self.gather_attribute("loss")
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
|
||||
stats = {}
|
||||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
stats |= self.gather_attribute("scalar")
|
||||
|
||||
self.backward(loss)
|
||||
self.step()
|
||||
|
||||
return stats
|
||||
|
||||
# and now to ignore everything from the above
|
||||
class Engines(dict[str, Engine]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self._global_step = 0
|
||||
self._micro_step = 0
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self._global_step
|
||||
|
||||
@property
|
||||
def micro_step(self):
|
||||
return self._micro_step
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
ret = {}
|
||||
for engine in self.values():
|
||||
ret |= engine.gather_attribute(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
def dispatch_attribute(self, *args, **kwargs):
|
||||
for engine in self.values():
|
||||
engine.dispatch_attribute(*args, **kwargs)
|
||||
|
||||
def save_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = cfg.trainer.save_tag
|
||||
tag = tag.lower()
|
||||
if tag[:2] == "it" or tag[:4] == "step":
|
||||
tag = f'{self.global_step}'
|
||||
|
||||
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
for name, engine in self.items():
|
||||
engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)
|
||||
|
||||
def load_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = cfg.trainer.load_tag
|
||||
|
||||
for name, engine in self.items():
|
||||
load_dir = cfg.ckpt_dir / name
|
||||
engine.load_checkpoint(
|
||||
tag=tag,
|
||||
load_dir=load_dir,
|
||||
load_module_strict=cfg.trainer.strict_loading,
|
||||
load_optimizer_states=cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=cfg.trainer.load_states,
|
||||
)
|
||||
if cfg.trainer.restart_step_count:
|
||||
engine.global_steps = 0
|
||||
|
||||
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
|
||||
if cfg.hyperparameters.scheduler_type == "":
|
||||
self.set_lr(cfg.hyperparameters.learning_rate)
|
||||
|
||||
self._update_global_step()
|
||||
self._update_micro_step()
|
||||
|
||||
def set_lr(self, lr):
|
||||
for engine in self.values():
|
||||
engine.set_lr(lr)
|
||||
|
||||
def _update_global_step(self):
|
||||
for engine in self.values():
|
||||
self._global_step = max(self._global_step, engine.global_step)
|
||||
|
||||
def _update_micro_step(self):
|
||||
for engine in self.values():
|
||||
self._micro_step = max(self._micro_step, engine.micro_step)
|
||||
|
||||
def train_batch_size(self):
|
||||
batch_size = 0
|
||||
for engine in self.values():
|
||||
batch_size = max(batch_size, engine.train_batch_size())
|
||||
|
||||
def eval(self):
|
||||
for engine in self.values():
|
||||
engine.eval()
|
||||
|
||||
def train(self):
|
||||
for engine in self.values():
|
||||
engine.train()
|
||||
|
||||
def traverse(self):
|
||||
stats = {}
|
||||
for name, engine in self.items():
|
||||
stat = engine.traverse()
|
||||
stats.update(flatten_dict({ name.split("-")[0]: stat }))
|
||||
return stats
|
||||
|
||||
def step(self, batch, feeder: TrainFeeder = default_feed, device=torch.cuda.current_device()):
|
||||
total_elapsed_time = 0
|
||||
|
||||
stats: Any = dict()
|
||||
|
||||
if cfg.trainer.gc_mode == 'step':
|
||||
do_gc()
|
||||
|
||||
batch = to_device(batch, device)
|
||||
|
||||
for name, engine in self.items():
|
||||
#torch.cuda.synchronize()
|
||||
|
||||
if cfg.trainer.gc_mode == 'substep':
|
||||
do_gc()
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
tries = 4
|
||||
n_ooms = torch.zeros([], device=cfg.device)
|
||||
|
||||
if cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, device)
|
||||
|
||||
res = feeder( engine=engine, batch=batch )
|
||||
"""
|
||||
while tries >= 0:
|
||||
try:
|
||||
res = feeder( engine=engine, batch=batch )
|
||||
break
|
||||
except RuntimeError as e:
|
||||
print("Forward", str(e))
|
||||
|
||||
if "out of memory" not in str(e):
|
||||
self.save_checkpoint()
|
||||
raise e
|
||||
|
||||
# shrink batch size until it's happy
|
||||
for k in batch:
|
||||
batch[k] = batch[k][:-1]
|
||||
|
||||
if tries <= 0:
|
||||
# trigger OOM
|
||||
n_ooms += 1
|
||||
else:
|
||||
# also do GC
|
||||
do_gc()
|
||||
continue
|
||||
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
raise RuntimeError("Out of memory during forward pass!")
|
||||
"""
|
||||
|
||||
if res is None:
|
||||
continue
|
||||
|
||||
loss, engine_stats = res
|
||||
engine_stats |= self.gather_attribute("scalar")
|
||||
|
||||
n_ooms = torch.zeros([], device=cfg.device)
|
||||
|
||||
if cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, 'cpu')
|
||||
|
||||
engine.backward(loss)
|
||||
"""
|
||||
try:
|
||||
engine.backward(loss)
|
||||
except RuntimeError as e:
|
||||
print("Backwards:", str(e))
|
||||
|
||||
if "out of memory" not in str(e):
|
||||
self.save_checkpoint()
|
||||
raise e
|
||||
|
||||
n_ooms += 1
|
||||
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
raise RuntimeError("Out of memory during backwards pass!")
|
||||
"""
|
||||
|
||||
engine.step()
|
||||
|
||||
#torch.cuda.synchronize()
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
total_elapsed_time += elapsed_time
|
||||
|
||||
stats.update(
|
||||
flatten_dict(
|
||||
{
|
||||
name.split("-")[0]: dict(
|
||||
loss=loss.item(),
|
||||
lr=engine.get_lr()[0],
|
||||
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
|
||||
elapsed_time=elapsed_time,
|
||||
engine_step=engine.global_step,
|
||||
**engine_stats,
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
self._update_global_step()
|
||||
self._update_micro_step()
|
||||
stats["batch_size"] = self.train_batch_size() # len(batch["text"])
|
||||
stats["elapsed_time"] = total_elapsed_time
|
||||
stats["wall_time"] = time.time()
|
||||
stats["global_step"] = self.global_step
|
||||
|
||||
return stats
|
89
vall_e/engines/deepspeed.py
Normal file
89
vall_e/engines/deepspeed.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
# https://github.com/enhuiz/pytorch-training-utilities
|
||||
"""
|
||||
|
||||
# to-do: replace this
|
||||
# to-do: swap out deepspeed
|
||||
|
||||
from ..config import cfg
|
||||
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
||||
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import Tensor
|
||||
from torch.distributed import all_reduce
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .base import TrainFeeder
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
|
||||
initialized_dist = False
|
||||
if not initialized_dist:
|
||||
initialized_dist = True
|
||||
init_distributed()
|
||||
|
||||
class Engine(DeepSpeedEngine):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
|
||||
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
|
||||
|
||||
super().__init__(None, *args, **kwargs)
|
||||
self._frozen_params = set()
|
||||
|
||||
def freeze(self):
|
||||
for p in self.module.parameters():
|
||||
if p.requires_grad:
|
||||
p.requires_grad_(False)
|
||||
self._frozen_params.add(p)
|
||||
|
||||
def unfreeze(self):
|
||||
for p in self._frozen_params:
|
||||
p.requires_grad_(True)
|
||||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self.global_steps
|
||||
|
||||
@property
|
||||
def micro_step(self):
|
||||
return self.micro_steps
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
||||
def dispatch_attribute(self, *args, **kwargs):
|
||||
return dispatch_attribute(self.module, *args, **kwargs)
|
||||
|
||||
def set_lr(self, lr):
|
||||
try:
|
||||
if hasattr(self.optimizer, 'param_groups'):
|
||||
print(self.optimizer.param_groups)
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
else:
|
||||
self.optimizer.set_lr(lr)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
||||
def traverse(self, *args, **kwargs):
|
||||
self.forward(*args, **kwargs)
|
||||
losses = self.gather_attribute("loss")
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
|
||||
stats = {}
|
||||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
stats |= self.gather_attribute("scalar")
|
||||
|
||||
self.backward(loss)
|
||||
self.step()
|
||||
|
||||
return stats
|
|
@ -72,8 +72,8 @@ class TTS():
|
|||
prom = to_device(prom, self.device).to(torch.int16)
|
||||
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
||||
|
||||
resp_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
|
||||
resps_list = [r.unsqueeze(-1) for r in resp_list]
|
||||
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
|
||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
|
||||
|
||||
wav, sr = qnt.decode_to_file(resps_list[0], out_path)
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
from .ar import AR
|
||||
from .nar import NAR
|
||||
|
||||
def get_model(model):
|
||||
if model.name == "ar":
|
||||
def get_model(cfg):
|
||||
if cfg.name == "ar":
|
||||
Model = AR
|
||||
elif model.name == "nar":
|
||||
elif cfg.name == "nar":
|
||||
Model = NAR
|
||||
else:
|
||||
raise f"invalid model name: {model.name}"
|
||||
name = model.name
|
||||
raise f"invalid model name: {cfg.name}"
|
||||
name = cfg.name
|
||||
|
||||
model = Model(
|
||||
n_tokens=model.tokens,
|
||||
d_model=model.dim,
|
||||
n_heads=model.heads,
|
||||
n_layers=model.layers,
|
||||
)
|
||||
model._cfg = cfg
|
||||
|
||||
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
|
|
|
@ -50,112 +50,80 @@ class AR(Base):
|
|||
self,
|
||||
text_list: list[Tensor],
|
||||
proms_list: list[Tensor],
|
||||
resp_list: list[Tensor] | None = None,
|
||||
resps_list: list[Tensor] | None = None,
|
||||
max_steps: int = 1000,
|
||||
sampling_temperature: float = 1.0,
|
||||
|
||||
naive: bool = True,
|
||||
):
|
||||
if resp_list is not None:
|
||||
if resps_list is not None:
|
||||
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
|
||||
|
||||
return super().forward(
|
||||
text_list,
|
||||
proms_list,
|
||||
self._unsqueeze_list(resp_list),
|
||||
resp_list,
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=self._unsqueeze_list(resps_list),
|
||||
targ_list=resps_list,
|
||||
quant_levels=None,
|
||||
shift_targ_list=True,
|
||||
return_all_resp=False,
|
||||
)
|
||||
|
||||
device = text_list[0].device
|
||||
resp_list: list[Tensor] = [
|
||||
resps_list: list[Tensor] = [
|
||||
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
|
||||
]
|
||||
stopped = torch.zeros(len(text_list), device=device).bool()
|
||||
|
||||
if self.arch_type == "transformer":
|
||||
naive = True
|
||||
|
||||
chunk_size = 1 # don't really know what to do about this desu
|
||||
|
||||
state = None
|
||||
start = 0
|
||||
|
||||
# prefill
|
||||
if self.arch_type == "retnet/local":
|
||||
# pre-process
|
||||
state = [
|
||||
[
|
||||
torch.zeros(self.retnet.hidden_dim // self.retnet.heads, self.retnet.v_dim // self.retnet.heads, device=device).unsqueeze(0).repeat(len(text_list), 1, 1)
|
||||
for _ in range(self.retnet.heads)
|
||||
] for _ in range(self.retnet.layers)
|
||||
]
|
||||
resps_list = self._unsqueeze_list(resp_list)
|
||||
x_list = self._samplewise_merge_tensors(
|
||||
self.text_emb(text_list),
|
||||
self.proms_emb(proms_list),
|
||||
self.resps_emb(resps_list),
|
||||
sep=self.sep,
|
||||
)
|
||||
|
||||
x, m = list_to_tensor(x_list)
|
||||
|
||||
start = x.shape[1]
|
||||
|
||||
for i in trange(start-1):
|
||||
_, state = self.retnet.forward_recurrent( x[:, i:i+1, :], state, i+1 )
|
||||
|
||||
for n in trange(max_steps // chunk_size):
|
||||
# get next in sequence
|
||||
|
||||
r, state = super().forward(
|
||||
text_list,
|
||||
proms_list,
|
||||
self._unsqueeze_list(resp_list),
|
||||
self._unsqueeze_list(resps_list),
|
||||
sampling_temperature=sampling_temperature,
|
||||
state=state,
|
||||
state=state if not naive else None,
|
||||
)
|
||||
|
||||
# append outputted token
|
||||
for i, ri in enumerate(r):
|
||||
resp_list[i] = torch.cat([resp_list[i], ri[None]])
|
||||
resps_list[i] = torch.cat([resps_list[i], ri[None]])
|
||||
|
||||
# stop token found
|
||||
stopped |= r == self.stop_token
|
||||
if stopped.all().item():
|
||||
break
|
||||
|
||||
pruned = [self._prune(r) for r in resp_list]
|
||||
pruned = [self._prune(r) for r in resps_list]
|
||||
return pruned
|
||||
|
||||
|
||||
def example_usage():
|
||||
cfg.trainer.backend = "local"
|
||||
from functools import partial
|
||||
|
||||
from einops import repeat
|
||||
|
||||
from ..emb.qnt import decode_to_file
|
||||
from ..utils import gather_attribute
|
||||
from ..engines import Engine
|
||||
from tqdm import tqdm
|
||||
|
||||
device = "cpu"
|
||||
|
||||
x8 = partial(repeat, pattern="t -> t l", l=2)
|
||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
||||
def tokenize(content, lang_marker="en"):
|
||||
split = content.split(" ")
|
||||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||
return torch.tensor([*map(symmap.get, phones)]).to()
|
||||
|
||||
qnt = torch.load("data/qnt.pt")[0, 0].to(device)
|
||||
kwargs = {
|
||||
'n_tokens': 1024,
|
||||
'd_model': 1024,
|
||||
'n_heads': 16,
|
||||
'n_layers': 12,
|
||||
}
|
||||
qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
|
||||
|
||||
model = AR(**kwargs).to(device)
|
||||
|
||||
x8 = partial(repeat, pattern="t -> t l", l=2)
|
||||
text_list = [
|
||||
#torch.tensor([1, 2, 3], device=device),
|
||||
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
||||
|
@ -164,41 +132,41 @@ def example_usage():
|
|||
x8(torch.tensor([1, 2, 3], device=device)),
|
||||
#qnt.to(device),
|
||||
]
|
||||
resp_list = [
|
||||
resps_list = [
|
||||
qnt.to(device),
|
||||
]
|
||||
|
||||
text_list = text_list[:1]
|
||||
proms_list = proms_list[:1]
|
||||
resp_list = resp_list[:1]
|
||||
resps_list = resps_list[:1]
|
||||
|
||||
model.eval()
|
||||
out = model(text_list, proms_list, max_steps=75)[0]
|
||||
print("qnt:", qnt.shape, qnt)
|
||||
print("out:", out.shape, out)
|
||||
wav, sr = decode_to_file(out, "data/test/test.ar.init.wav", device=device)
|
||||
kwargs = {
|
||||
'n_tokens': 1024,
|
||||
'd_model': 1024,
|
||||
'n_heads': 16,
|
||||
'n_layers': 12,
|
||||
}
|
||||
model = AR(**kwargs).to(device)
|
||||
engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
|
||||
|
||||
def sample( name, steps=400 ):
|
||||
engine.eval()
|
||||
out = engine(text_list, proms_list, max_steps=steps)
|
||||
for i, o in enumerate(out):
|
||||
wav, sr = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
def train():
|
||||
engine.train()
|
||||
t = trange(60)
|
||||
for i in t:
|
||||
stats = {"step": i}
|
||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||
|
||||
model.train()
|
||||
for i in trange(60):
|
||||
optimizer.zero_grad()
|
||||
_ = model(text_list, proms_list, resp_list)
|
||||
|
||||
losses = gather_attribute(model, "loss")
|
||||
loss = sum(losses.values())
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if i % 20 == 0:
|
||||
print(f"iter={i}, {losses}.")
|
||||
model.eval()
|
||||
out = model(text_list, proms_list, max_steps=400)
|
||||
print("qnt:", qnt.shape, qnt)
|
||||
for i, o in enumerate(out):
|
||||
print("out:", i, o.shape, o)
|
||||
wav, sr = decode_to_file(o, f"data/test/test.ar.{i}.wav", device=device)
|
||||
t.set_description(f"{stats}")
|
||||
|
||||
sample("init", 75)
|
||||
train()
|
||||
sample("final")
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
||||
|
|
|
@ -176,17 +176,7 @@ class Base(nn.Module):
|
|||
self.retnet = RetNetDecoder(
|
||||
self.retnet_config
|
||||
)
|
||||
elif self.arch_type == "retnet/local":
|
||||
self.retnet = RetNet(
|
||||
layers=n_layers,
|
||||
hidden_dim=d_model,
|
||||
ffn_size=d_model * 4,
|
||||
heads=n_heads,
|
||||
dropout=p_dropout,
|
||||
norm_type=self.norm_type,
|
||||
n_levels=self.n_resp_levels,
|
||||
double_v_dim=True
|
||||
)
|
||||
|
||||
self.classifier = nn.Linear(d_model, n_resp_tokens)
|
||||
|
||||
self.accuracy_metric = MulticlassAccuracy(
|
||||
|
@ -272,7 +262,7 @@ class Base(nn.Module):
|
|||
text_list: [t] * b
|
||||
proms_list: [t' l] * b, l quantization levels.
|
||||
resps_list: [t'' l] * b, l quantization levels.
|
||||
targ_list: [t''] * b, one quantization level only, when given, loss will be computed
|
||||
targ_list: [t''] * b, one quantization level only; when given, loss will be computed
|
||||
quant_levels: specify which quant_levels to feed forward, used in NAR mode.
|
||||
shift_targ_list: whether to shift target list when computing loss. True if AR.
|
||||
return_all_resp: True if NAR.
|
||||
|
@ -298,24 +288,12 @@ class Base(nn.Module):
|
|||
elif self.arch_type == "retnet":
|
||||
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
|
||||
state = self.retnet.get_incremental_state( state, 'prev_state' )
|
||||
elif self.arch_type == "retnet/local":
|
||||
# recurrent inferencing
|
||||
if self.causal and state is not None:
|
||||
last = x.shape[1]
|
||||
x, state = self.retnet.forward_recurrent(
|
||||
x[:, last-1:last, :], # nasty way to grab the last embedding to forward
|
||||
state,
|
||||
last
|
||||
)
|
||||
else:
|
||||
x = self.retnet( x, quant_levels )
|
||||
|
||||
x = self.classifier(x) * m
|
||||
|
||||
# Remove padding
|
||||
h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]
|
||||
|
||||
|
||||
# compute loss if the target is given
|
||||
if targ_list is not None:
|
||||
if any([l == 0 for l in map(len, targ_list)]):
|
||||
|
@ -337,20 +315,24 @@ class Base(nn.Module):
|
|||
# the NAR doesn't need to compute the loss for it
|
||||
if self.resp_loss_only:
|
||||
text_prom_list[i][:] = self.ignore_index
|
||||
|
||||
# roll the text/prompt for loss computing
|
||||
# the AR benefits from this
|
||||
# the AR benefits from this, for some reason I'll figure out later
|
||||
else:
|
||||
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
|
||||
text_prom_list[i][-1] = self.ignore_index
|
||||
|
||||
# necessary to roll the target if recurrently/causally/autoregressively generating, or it won't be able to work
|
||||
# for the AR, roll by one and mark the ending with a stop token
|
||||
# this coerces the model into properly inferencing causally
|
||||
|
||||
# why we don't just append a stop token in the dataloader, who knows
|
||||
if shift_targ_list:
|
||||
targ_list = [*targ_list]
|
||||
for i in range(len(targ_list)):
|
||||
targ_list[i] = targ_list[i].roll(-1, dims=0)
|
||||
targ_list[i][-1] = self.stop_token
|
||||
|
||||
# generate the sequence
|
||||
# create the new target sequence to compute the loss against
|
||||
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
|
||||
|
||||
self.loss = dict(
|
||||
|
@ -367,6 +349,7 @@ class Base(nn.Module):
|
|||
del prom_list
|
||||
del text_prom_list
|
||||
del y_list
|
||||
|
||||
|
||||
# return the entire generated token string
|
||||
if return_all:
|
||||
|
@ -387,124 +370,90 @@ class Base(nn.Module):
|
|||
return ret, state
|
||||
|
||||
def example_usage():
|
||||
from ..config import cfg
|
||||
cfg.trainer.backend = "local"
|
||||
|
||||
from functools import partial
|
||||
|
||||
from einops import repeat
|
||||
from tqdm import trange
|
||||
|
||||
from ..utils import gather_attribute
|
||||
|
||||
from ..emb.qnt import decode_to_file
|
||||
from ..engines import Engine, Engines
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from .ar import AR
|
||||
from .nar import NAR
|
||||
|
||||
device = "cpu"
|
||||
x8 = partial(repeat, pattern="t -> t l", l=2)
|
||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
||||
def tokenize(content, lang_marker="en"):
|
||||
split = content.split(" ")
|
||||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||
return torch.tensor([*map(symmap.get, phones)]).to()
|
||||
|
||||
device = "cpu"
|
||||
|
||||
kwargs = {
|
||||
'n_tokens': 1024,
|
||||
'd_model': 1024,
|
||||
'n_heads': 16,
|
||||
'n_layers': 12,
|
||||
}
|
||||
model_ar = AR(**kwargs).to(device)
|
||||
model_nar = NAR(**kwargs).to(device)
|
||||
models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
|
||||
engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
|
||||
|
||||
train = True
|
||||
|
||||
if train:
|
||||
qnt = torch.load("data/qnt.pt").to(device)
|
||||
text_list = [
|
||||
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
||||
#tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
|
||||
]
|
||||
|
||||
x8 = partial(repeat, pattern="t -> t l", l=2)
|
||||
proms_list = [
|
||||
qnt[0][:2,:].t().to(device),
|
||||
#x8(torch.tensor([1, 2, 3], device=device)),
|
||||
# x8(torch.tensor([2, 3], device=device)),
|
||||
]
|
||||
qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
|
||||
text_list = [
|
||||
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
||||
#tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
|
||||
]
|
||||
|
||||
resp_list_ar = [
|
||||
qnt[0,0].to(device),
|
||||
# qnt[0,0].to(device),
|
||||
]
|
||||
|
||||
resp_list_nar = [
|
||||
qnt[0][:2,:].t().to(device),
|
||||
# qnt[0][:2,:].t().to(device),
|
||||
]
|
||||
|
||||
model_ar.train()
|
||||
optimizer = torch.optim.AdamW(model_ar.parameters(), lr=1e-4)
|
||||
for i in trange(60):
|
||||
optimizer.zero_grad()
|
||||
_ = model_ar(text_list, proms_list, resp_list_ar)
|
||||
|
||||
losses = gather_attribute(model_ar, "loss")
|
||||
loss = sum(losses.values())
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if i % 20 == 0:
|
||||
print(f"iter={i}, {losses}.")
|
||||
|
||||
model_nar.train()
|
||||
optimizer = torch.optim.AdamW(model_nar.parameters(), lr=1e-4)
|
||||
for i in trange(60):
|
||||
optimizer.zero_grad()
|
||||
|
||||
_ = model_nar(text_list, proms_list, resps_list=resp_list_nar)
|
||||
|
||||
losses = gather_attribute(model_nar, "loss")
|
||||
loss = sum(losses.values())
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if i % 20 == 0:
|
||||
stats = {k: v.item() for k, v in losses.items()}
|
||||
stats["loss"] = loss.item()
|
||||
print(f"iter={i}, {stats}.")
|
||||
else:
|
||||
qnt = torch.load("data/test/test.qnt.pt")[0][:2,:].t().to(device)
|
||||
text_list = [
|
||||
#tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
||||
tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
|
||||
]
|
||||
proms_list = [
|
||||
qnt.to(device),
|
||||
]
|
||||
model_ar.load_state_dict(torch.load("data/test/ar.pth"))
|
||||
model_nar.load_state_dict(torch.load("data/test/nar.pth"))
|
||||
|
||||
model_ar.eval()
|
||||
resp_list = model_ar(text_list, proms_list, max_steps=300, sampling_temperature=1.0)
|
||||
resps_list = [r.unsqueeze(-1) for r in resp_list]
|
||||
proms_list = [
|
||||
qnt.to(device),
|
||||
]
|
||||
resps_list = [
|
||||
qnt.to(device),
|
||||
]
|
||||
|
||||
print("qnt:", qnt.shape, qnt)
|
||||
print("out:", resp_list[0].shape, resp_list[0])
|
||||
wav, sr = decode_to_file(resp_list[0], "data/test/test.ar.init.wav", device=device)
|
||||
print(wav, sr)
|
||||
def sample( name, steps=400 ):
|
||||
AR = None
|
||||
NAR = None
|
||||
|
||||
model_nar.eval()
|
||||
codes = model_nar(
|
||||
text_list,
|
||||
proms_list,
|
||||
resps_list=resps_list,
|
||||
sampling_temperature=1.0,
|
||||
)[0]
|
||||
engines.eval()
|
||||
for name, engine in engines.items():
|
||||
if name[:2] == "ar":
|
||||
AR = engine
|
||||
elif name[:3] == "nar":
|
||||
NAR = engine
|
||||
|
||||
resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
|
||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||
codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
|
||||
|
||||
print("qnt:", qnt.shape, qnt)
|
||||
print("codes:", codes.shape, codes)
|
||||
decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
|
||||
decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
|
||||
|
||||
if train:
|
||||
sample("init", 15)
|
||||
|
||||
wav, sr = decode_to_file(codes, "data/test/test.ar+nar.init.wav", device=device)
|
||||
print(wav, sr)
|
||||
engines.train()
|
||||
t = trange(60)
|
||||
for i in t:
|
||||
"""
|
||||
stats = {"step": i}
|
||||
for name, engine in engines.items():
|
||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||
"""
|
||||
stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
|
||||
t.set_description(f"{stats}")
|
||||
else:
|
||||
for name, engine in engines.items():
|
||||
engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
|
||||
|
||||
sample("final")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
||||
|
|
|
@ -80,7 +80,6 @@ class NAR(Base):
|
|||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
# Yes, just nothing as we are training
|
||||
prev_list = []
|
||||
else:
|
||||
prev_list = resps_list
|
||||
|
@ -93,7 +92,7 @@ class NAR(Base):
|
|||
|
||||
quant_levels = torch.full((len(text_list),), level, device=device)
|
||||
|
||||
resp_list, _ = super().forward(
|
||||
resps_list, _ = super().forward(
|
||||
text_list,
|
||||
proms_list,
|
||||
prev_list,
|
||||
|
@ -105,24 +104,22 @@ class NAR(Base):
|
|||
|
||||
prev_list = [
|
||||
torch.cat([rs, r.unsqueeze(-1)], dim=-1)
|
||||
for rs, r in zip(prev_list, resp_list)
|
||||
for rs, r in zip(prev_list, resps_list)
|
||||
]
|
||||
|
||||
return prev_list
|
||||
|
||||
|
||||
def example_usage():
|
||||
cfg.trainer.backend = "local"
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from einops import repeat
|
||||
|
||||
from ..emb.qnt import decode_to_file
|
||||
from ..utils import gather_attribute
|
||||
from ..config import cfg
|
||||
from ..engines import Engine
|
||||
from tqdm import tqdm
|
||||
|
||||
device = "cpu"
|
||||
|
||||
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
|
||||
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
|
||||
def tokenize(content, lang_marker="en"):
|
||||
|
@ -130,15 +127,8 @@ def example_usage():
|
|||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||
return torch.tensor([*map(symmap.get, phones)]).to()
|
||||
|
||||
resps = torch.load("data/qnt.pt")[0][:2, :].to(device)
|
||||
kwargs = {
|
||||
'n_tokens': 1024,
|
||||
'd_model': 1024,
|
||||
'n_heads': 16,
|
||||
'n_layers': 12,
|
||||
}
|
||||
|
||||
model = NAR(**kwargs).to(device)
|
||||
# to-do: unmangle this and the resp shit
|
||||
qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
|
||||
|
||||
text_list = [
|
||||
#torch.tensor([1, 2, 3], device=device),
|
||||
|
@ -149,65 +139,36 @@ def example_usage():
|
|||
x8(torch.tensor([2, 3], device=device)),
|
||||
]
|
||||
|
||||
resps_x1_list = [
|
||||
resps[:1].t().to(device),
|
||||
resps_list = [
|
||||
qnt.to(device),
|
||||
]
|
||||
|
||||
resps_x8_list = [
|
||||
resps.t().to(device),
|
||||
]
|
||||
kwargs = {
|
||||
'n_tokens': 1024,
|
||||
'd_model': 1024,
|
||||
'n_heads': 16,
|
||||
'n_layers': 12,
|
||||
}
|
||||
model = NAR(**kwargs).to(device)
|
||||
engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
|
||||
|
||||
model.eval()
|
||||
codes = model(
|
||||
text_list,
|
||||
proms_list,
|
||||
resps_list=resps_x1_list,
|
||||
sampling_temperature=0.2,
|
||||
)[0]
|
||||
def sample( name ):
|
||||
engine.eval()
|
||||
codes = engine( text_list, proms_list, resps_list=[r[..., 0].unsqueeze(-1) for r in resps_list], sampling_temperature=0.2 )
|
||||
decode_to_file( codes[0], f"data/nar.{name}.wav", device )
|
||||
|
||||
decode_to_file(
|
||||
codes,
|
||||
Path("data/test/test.nar.init.wav"),
|
||||
device
|
||||
)
|
||||
def train():
|
||||
engine.train()
|
||||
t = trange(60)
|
||||
for i in t:
|
||||
stats = {"step": i}
|
||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
t.set_description(f"{stats}")
|
||||
|
||||
model.train()
|
||||
for i in trange(50):
|
||||
optimizer.zero_grad()
|
||||
|
||||
_ = model(text_list, proms_list, resps_list=resps_x8_list)
|
||||
|
||||
losses = gather_attribute(model, "loss")
|
||||
loss = sum(losses.values())
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
if i % 20 == 0:
|
||||
stats = {k: v.item() for k, v in losses.items()}
|
||||
stats["loss"] = loss.item()
|
||||
print(f"iter={i}, {stats}.")
|
||||
|
||||
model.eval()
|
||||
for i in trange(1, 2): # cfg.models.prom_levels):
|
||||
resps_list = [
|
||||
resps[:i].t().to(device),
|
||||
]
|
||||
|
||||
codes = model(
|
||||
text_list,
|
||||
proms_list,
|
||||
resps_list=resps_list,
|
||||
sampling_temperature=0.2,
|
||||
)[0]
|
||||
|
||||
decode_to_file(
|
||||
codes,
|
||||
Path(f"data/test/test.nar.1-{i}.wav"),
|
||||
device
|
||||
)
|
||||
sample("init")
|
||||
train()
|
||||
sample("final")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
300
vall_e/train.py
300
vall_e/train.py
|
@ -1,17 +1,13 @@
|
|||
# todo: clean this mess up
|
||||
# todo: yank deepspeed dependent code out into its own thing
|
||||
|
||||
from .config import cfg
|
||||
from .data import create_train_val_dataloader
|
||||
from .emb import qnt
|
||||
|
||||
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
||||
from .utils import wrapper as ml
|
||||
|
||||
from .models import get_models
|
||||
|
||||
import auraloss
|
||||
import deepspeed
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
@ -21,10 +17,6 @@ import traceback
|
|||
|
||||
from collections import defaultdict
|
||||
|
||||
from deepspeed import comm as dist
|
||||
from deepspeed import DeepSpeedConfig
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
|
||||
|
@ -38,210 +30,120 @@ def left_crop(x, len):
|
|||
return x[..., 0:len]
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
deepspeed._initialized_dist = False
|
||||
|
||||
def load_engines():
|
||||
if not deepspeed._initialized_dist:
|
||||
deepspeed._initialized_dist = True
|
||||
deepspeed.init_distributed()
|
||||
def train_feeder(engine, batch):
|
||||
engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
|
||||
|
||||
models = get_models(cfg.models.get())
|
||||
engines = dict()
|
||||
losses = engine.gather_attribute("loss")
|
||||
|
||||
for name in models:
|
||||
model = models[name]
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
stats = {}
|
||||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
|
||||
Adam = ml.Adam
|
||||
AdamW = ml.AdamW
|
||||
return loss, stats
|
||||
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
|
||||
optimizer = AdamW(
|
||||
model.parameters(),
|
||||
lr=cfg.hyperparameters.learning_rate,
|
||||
betas=(0.9, 0.96),
|
||||
eps=1e-07,
|
||||
weight_decay=0.01,
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def run_eval(engines, eval_name, dl):
|
||||
engines_stats = {
|
||||
'eval': eval_name
|
||||
}
|
||||
|
||||
if cfg.trainer.load_state_dict:
|
||||
load_dir = cfg.ckpt_dir / name / "fp32.pth"
|
||||
model.load_state_dict(torch.load(load_dir))
|
||||
AR = None
|
||||
NAR = None
|
||||
|
||||
ds_cfg=cfg.get_ds_cfg(model=model)
|
||||
config_class=DeepSpeedConfig(ds_cfg)
|
||||
engines[name] = trainer.Engine(
|
||||
model=model,
|
||||
config=ds_cfg,
|
||||
config_class=config_class,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
names = []
|
||||
for name, engine in engines.items():
|
||||
names.append(name)
|
||||
if name[:2] == "ar":
|
||||
AR = engine
|
||||
elif name[:3] == "nar":
|
||||
NAR = engine
|
||||
|
||||
stats = defaultdict(list)
|
||||
stats['loss'] = []
|
||||
|
||||
def process( name, batch, resps_list ):
|
||||
for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
|
||||
if len(hyp) == 0:
|
||||
continue
|
||||
|
||||
filename = f'{speaker}_{path.parts[-1]}'
|
||||
|
||||
# to-do, refine the output dir to be sane-er
|
||||
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
|
||||
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
|
||||
prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
|
||||
|
||||
hyp_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
ref_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
prom_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
|
||||
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
|
||||
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
|
||||
|
||||
# pseudo loss calculation since we don't get the logits during eval
|
||||
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
|
||||
ref_audio = ref_audio[..., 0:min_length]
|
||||
hyp_audio = hyp_audio[..., 0:min_length]
|
||||
|
||||
try:
|
||||
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
||||
for batch in tqdm(dl):
|
||||
batch: dict = to_device(batch, cfg.device)
|
||||
|
||||
# if we're training both models, provide output for both
|
||||
if AR is not None and NAR is not None:
|
||||
name = "+".join(names)
|
||||
|
||||
resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature)
|
||||
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
||||
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature)
|
||||
|
||||
process( name, batch, resps_list )
|
||||
else:
|
||||
for name in engines:
|
||||
model = engines[name]
|
||||
|
||||
if name.startswith("ar"):
|
||||
resps_list = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
max_steps=cfg.evaluation.steps,
|
||||
sampling_temperature=cfg.evaluation.ar_temperature,
|
||||
)
|
||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||
elif name.startswith("nar"):
|
||||
resps_list = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
|
||||
sampling_temperature=cfg.evaluation.nar_temperature,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(name)
|
||||
|
||||
process( name, batch, resps_list )
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
engines_stats.update(flatten_dict({ name: stats }))
|
||||
|
||||
iteration = engines.global_step
|
||||
engines_stats['it'] = iteration
|
||||
engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||
|
||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
|
||||
return trainer.load_engines(engines, cfg)
|
||||
|
||||
def main():
|
||||
setup_logging(cfg.log_dir)
|
||||
|
||||
#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
|
||||
if not deepspeed._initialized_dist:
|
||||
deepspeed._initialized_dist = True
|
||||
deepspeed.init_distributed()
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
def train_feeder(engines, batch, name):
|
||||
stats = {}
|
||||
model = engines[name]
|
||||
if name.startswith("ar"):
|
||||
_ = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
resp_list=[r[..., 0] for r in batch["resps"]],
|
||||
)
|
||||
elif name.startswith("nar"):
|
||||
_ = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
resps_list=batch["resps"],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(name)
|
||||
|
||||
losses = model.gather_attribute("loss")
|
||||
|
||||
loss = torch.stack([*losses.values()]).sum()
|
||||
|
||||
stats = {}
|
||||
stats |= {k: v.item() for k, v in losses.items()}
|
||||
stats |= engines.gather_attribute("scalar")
|
||||
|
||||
return loss, stats
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_eval(engines, eval_name, dl):
|
||||
engines_stats = {
|
||||
'eval': eval_name
|
||||
}
|
||||
|
||||
AR = None
|
||||
NAR = None
|
||||
|
||||
names = []
|
||||
for name in engines:
|
||||
model = engines[name]
|
||||
names.append(name)
|
||||
if name[:2] == "ar":
|
||||
AR = model
|
||||
elif name[:3] == "nar":
|
||||
NAR = model
|
||||
|
||||
stats = defaultdict(list)
|
||||
stats['loss'] = []
|
||||
|
||||
for batch in tqdm(dl):
|
||||
batch: dict = to_device(batch, cfg.device)
|
||||
|
||||
# if we're training both models, provide output for both
|
||||
if AR is not None and NAR is not None:
|
||||
name = "+".join(names)
|
||||
|
||||
resp_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature)
|
||||
resps_list = [ r.unsqueeze(-1) for r in resp_list ]
|
||||
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature)
|
||||
|
||||
for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
|
||||
if len(hyp) == 0:
|
||||
continue
|
||||
|
||||
filename = f'{speaker}_{path.parts[-1]}'
|
||||
|
||||
# to-do, refine the output dir to be sane-er
|
||||
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
|
||||
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
|
||||
prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
|
||||
|
||||
hyp_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
ref_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
prom_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
|
||||
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
|
||||
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
|
||||
|
||||
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
|
||||
ref_audio = ref_audio[..., 0:min_length]
|
||||
hyp_audio = hyp_audio[..., 0:min_length]
|
||||
|
||||
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
|
||||
else:
|
||||
for name in engines:
|
||||
model = engines[name]
|
||||
|
||||
if name.startswith("ar"):
|
||||
resp_list = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
max_steps=cfg.evaluation.steps,
|
||||
sampling_temperature=cfg.evaluation.temperature,
|
||||
)
|
||||
resps_list = [r.unsqueeze(-1) for r in resp_list]
|
||||
elif name.startswith("nar"):
|
||||
resps_list = model(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
|
||||
sampling_temperature=cfg.evaluation.temperature,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(name)
|
||||
|
||||
losses = model.gather_attribute("loss")
|
||||
|
||||
batch_stats = {}
|
||||
batch_stats |= {k: v.item() for k, v in losses.items()}
|
||||
batch_stats |= engines.gather_attribute("scalar")
|
||||
|
||||
for k, v in batch_stats.items():
|
||||
stats[k].append(v)
|
||||
|
||||
for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
|
||||
if len(hyp) == 0:
|
||||
continue
|
||||
|
||||
filename = f'{speaker}_{path.parts[-1]}'
|
||||
|
||||
# to-do, refine the output dir to be sane-er
|
||||
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
|
||||
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
|
||||
prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
|
||||
|
||||
hyp_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
ref_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
prom_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
|
||||
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
|
||||
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
|
||||
|
||||
# pseudo loss calculation since we don't get the logits during eval
|
||||
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
|
||||
ref_audio = ref_audio[..., 0:min_length]
|
||||
hyp_audio = hyp_audio[..., 0:min_length]
|
||||
|
||||
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
|
||||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items()}
|
||||
engines_stats.update(flatten_dict({ name: stats }))
|
||||
|
||||
iteration = engines.global_step
|
||||
engines_stats['it'] = iteration
|
||||
engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||
|
||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
|
||||
|
||||
def eval_fn(engines):
|
||||
try:
|
||||
run_eval(engines, "subtrain", subtrain_dl)
|
||||
|
@ -256,12 +158,10 @@ def main():
|
|||
qnt.unload_model()
|
||||
|
||||
trainer.train(
|
||||
engines_loader=load_engines,
|
||||
train_dl=train_dl,
|
||||
train_feeder=train_feeder,
|
||||
eval_fn=eval_fn,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,252 +0,0 @@
|
|||
"""
|
||||
# https://github.com/enhuiz/pytorch-training-utilities
|
||||
"""
|
||||
|
||||
# to-do: replace this
|
||||
# to-do: swap out deepspeed
|
||||
|
||||
from ..config import Config
|
||||
from .distributed import fix_unset_envs
|
||||
from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
||||
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from deepspeed import DeepSpeedEngine
|
||||
from torch import Tensor
|
||||
from torch.distributed import all_reduce
|
||||
from typing import Any, Protocol
|
||||
|
||||
Stats = dict[str, float]
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Engine(DeepSpeedEngine):
|
||||
def __init__(self, *args, **kwargs):
|
||||
fix_unset_envs()
|
||||
super().__init__(None, *args, **kwargs)
|
||||
self._frozen_params = set()
|
||||
|
||||
def freeze(self):
|
||||
for p in self.module.parameters():
|
||||
if p.requires_grad:
|
||||
p.requires_grad_(False)
|
||||
self._frozen_params.add(p)
|
||||
|
||||
def unfreeze(self):
|
||||
for p in self._frozen_params:
|
||||
p.requires_grad_(True)
|
||||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self.global_steps
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
||||
def dispatch_attribute(self, *args, **kwargs):
|
||||
return dispatch_attribute(self.module, *args, **kwargs)
|
||||
|
||||
|
||||
class TrainFeeder(Protocol):
|
||||
def __call__(
|
||||
self, *, engines: "Engines", batch: Any, name: str
|
||||
) -> None | tuple[Tensor, Stats]:
|
||||
...
|
||||
|
||||
|
||||
class Engines(dict[str, Engine]):
|
||||
def setup(self, cfg: Config):
|
||||
self._cfg = cfg
|
||||
self._global_step = 0
|
||||
|
||||
@property
|
||||
def cfg(self) -> Config:
|
||||
return self._cfg
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._cfg
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self._global_step
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
ret = {}
|
||||
for engine in self.values():
|
||||
ret |= engine.gather_attribute(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
def dispatch_attribute(self, *args, **kwargs):
|
||||
for engine in self.values():
|
||||
engine.dispatch_attribute(*args, **kwargs)
|
||||
|
||||
def save_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = self.cfg.trainer.save_tag
|
||||
tag = tag.lower()
|
||||
if tag[:2] == "it" or tag[:4] == "step":
|
||||
tag = self.global_step
|
||||
|
||||
self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
for name, engine in self.items():
|
||||
engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
|
||||
|
||||
def load_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = self.cfg.trainer.load_tag
|
||||
|
||||
for name, engine in self.items():
|
||||
load_dir = self.cfg.ckpt_dir / name
|
||||
engine.load_checkpoint(
|
||||
tag=tag,
|
||||
load_dir=load_dir,
|
||||
load_module_strict=self.cfg.trainer.strict_loading,
|
||||
load_optimizer_states=self.cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=self.cfg.trainer.load_states,
|
||||
load_module_only=False, # not self.cfg.trainer.load_states,
|
||||
)
|
||||
if self.cfg.trainer.restart_step_count:
|
||||
engine.global_steps = 0
|
||||
|
||||
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
|
||||
if self.cfg.hyperparameters.scheduler_type == "":
|
||||
self.set_lr(self.cfg.hyperparameters.learning_rate)
|
||||
|
||||
self._update_global_step()
|
||||
|
||||
def set_lr(self, lr):
|
||||
try:
|
||||
for engine in self.values():
|
||||
if hasattr(engine.optimizer, 'param_groups'):
|
||||
print(engine.optimizer.param_groups)
|
||||
for param_group in engine.optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
else:
|
||||
engine.optimizer.set_lr(lr)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
||||
def _update_global_step(self):
|
||||
for engine in self.values():
|
||||
self._global_step = max(self._global_step, engine.global_step)
|
||||
|
||||
def eval(self):
|
||||
for engine in self.values():
|
||||
engine.eval()
|
||||
|
||||
def train(self):
|
||||
for engine in self.values():
|
||||
engine.train()
|
||||
|
||||
def step(self, feeder: TrainFeeder, batch):
|
||||
total_elapsed_time = 0
|
||||
|
||||
stats: Any = dict()
|
||||
|
||||
if self.cfg.trainer.gc_mode == 'step':
|
||||
do_gc()
|
||||
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
|
||||
for name, engine in self.items():
|
||||
torch.cuda.synchronize()
|
||||
if self.cfg.trainer.gc_mode == 'substep':
|
||||
do_gc()
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
tries = 4
|
||||
n_ooms = torch.zeros([], device=self.cfg.device)
|
||||
if self.cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
# engine = engine.to(torch.cuda.current_device())
|
||||
|
||||
while tries >= 0:
|
||||
try:
|
||||
res = feeder( engines=self, batch=batch, name=name )
|
||||
break
|
||||
except RuntimeError as e:
|
||||
print("Forward", str(e))
|
||||
|
||||
if "out of memory" not in str(e):
|
||||
self.save_checkpoint()
|
||||
raise e
|
||||
|
||||
# shrink batch size until it's happy
|
||||
for k in batch:
|
||||
batch[k] = batch[k][:-1]
|
||||
|
||||
if tries <= 0:
|
||||
# trigger OOM
|
||||
n_ooms += 1
|
||||
else:
|
||||
# also do GC
|
||||
do_gc()
|
||||
continue
|
||||
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
raise RuntimeError("Out of memory during forward pass!")
|
||||
|
||||
if res is None:
|
||||
continue
|
||||
|
||||
loss, engine_stats = res
|
||||
|
||||
n_ooms = torch.zeros([], device=self.cfg.device)
|
||||
|
||||
if self.cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, 'cpu')
|
||||
|
||||
try:
|
||||
engine.backward(loss)
|
||||
except RuntimeError as e:
|
||||
print("Backwards:", str(e))
|
||||
|
||||
if "out of memory" not in str(e):
|
||||
self.save_checkpoint()
|
||||
raise e
|
||||
|
||||
n_ooms += 1
|
||||
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
raise RuntimeError("Out of memory during backwards pass!")
|
||||
|
||||
engine.step()
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time = time.time() - start_time
|
||||
total_elapsed_time += elapsed_time
|
||||
|
||||
stats.update(
|
||||
flatten_dict(
|
||||
{
|
||||
name.split("-")[0]: dict(
|
||||
loss=loss.item(),
|
||||
lr=engine.get_lr()[0],
|
||||
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
|
||||
elapsed_time=elapsed_time,
|
||||
engine_step=engine.global_step,
|
||||
**engine_stats,
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
del loss
|
||||
# engine = engine.to('cpu')
|
||||
|
||||
self._update_global_step()
|
||||
stats["batch_size"] = len(batch["text"])
|
||||
stats["elapsed_time"] = total_elapsed_time
|
||||
stats["wall_time"] = time.time()
|
||||
stats["global_step"] = self.global_step
|
||||
|
||||
return stats
|
|
@ -17,239 +17,258 @@ from torch.utils.data import DataLoader
|
|||
from tqdm import tqdm
|
||||
from typing import Protocol
|
||||
|
||||
from ..config import Config
|
||||
from ..config import cfg
|
||||
from .distributed import (
|
||||
global_leader_only,
|
||||
global_rank,
|
||||
is_global_leader,
|
||||
is_local_leader,
|
||||
local_leader_only,
|
||||
fix_unset_envs,
|
||||
global_leader_only,
|
||||
global_rank,
|
||||
is_global_leader,
|
||||
is_local_leader,
|
||||
local_leader_only,
|
||||
)
|
||||
|
||||
from .engines import Engine, Engines, TrainFeeder
|
||||
from ..engines import Engine, Engines, TrainFeeder, default_feeder
|
||||
from ..models import get_models
|
||||
|
||||
from .utils import to_device, do_gc
|
||||
from ..utils import wrapper as ml
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_engines: Engines
|
||||
_command: str
|
||||
|
||||
def get_global_step():
|
||||
try:
|
||||
return _engines.global_step
|
||||
except:
|
||||
return None
|
||||
try:
|
||||
return _engines.global_step
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_micro_step():
|
||||
try:
|
||||
return _engines.micro_step
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def get_cfg():
|
||||
try:
|
||||
return _engines.cfg
|
||||
except:
|
||||
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||
|
||||
try:
|
||||
return _engines.micro_step
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_cmd():
|
||||
try:
|
||||
return _command
|
||||
except:
|
||||
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||
try:
|
||||
return _command
|
||||
except:
|
||||
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||
|
||||
|
||||
get_iteration = get_global_step
|
||||
|
||||
def load_engines():
|
||||
models = get_models(cfg.models.get())
|
||||
engines = dict()
|
||||
|
||||
class EnginesLoader(Protocol):
|
||||
def __call__(self) -> Engines:
|
||||
...
|
||||
for name in models:
|
||||
model = models[name]
|
||||
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
def load_engines(engines: dict[str, Engine], config: Config):
|
||||
engines = Engines(engines)
|
||||
engines.setup(config)
|
||||
if not engines.cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
return engines
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
|
||||
optimizer = ml.AdamW(
|
||||
model.parameters(),
|
||||
lr=cfg.hyperparameters.learning_rate,
|
||||
betas=(0.9, 0.96),
|
||||
eps=1e-07,
|
||||
weight_decay=0.01,
|
||||
)
|
||||
|
||||
if cfg.trainer.load_state_dict:
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
model.load_state_dict(torch.load(load_path))
|
||||
|
||||
engines[name] = Engine(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
|
||||
engines = Engines(engines)
|
||||
engines.setup()
|
||||
|
||||
if not cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
|
||||
return engines
|
||||
|
||||
class EvalFn(Protocol):
|
||||
def __call__(self, *, engines: Engines):
|
||||
...
|
||||
def __call__(self, *, engines: Engines):
|
||||
...
|
||||
|
||||
|
||||
class Logger(Protocol):
|
||||
def __call__(self, *, data: dict):
|
||||
...
|
||||
def __call__(self, *, data: dict):
|
||||
...
|
||||
|
||||
|
||||
@cache
|
||||
def _get_stdin_selector():
|
||||
selector = selectors.DefaultSelector()
|
||||
selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
|
||||
return selector
|
||||
selector = selectors.DefaultSelector()
|
||||
selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
|
||||
return selector
|
||||
|
||||
|
||||
def _non_blocking_input():
|
||||
global _command
|
||||
l = [""]
|
||||
if is_global_leader():
|
||||
s = ""
|
||||
selector = _get_stdin_selector()
|
||||
events = selector.select(timeout=0)
|
||||
for key, _ in events:
|
||||
s: str = key.fileobj.readline().strip()
|
||||
_logger.info(f'Get stdin "{s}".')
|
||||
l[0] = s
|
||||
broadcast_object_list(l, src=0)
|
||||
_command = l[0]
|
||||
return _command
|
||||
global _command
|
||||
l = [""]
|
||||
if is_global_leader():
|
||||
s = ""
|
||||
selector = _get_stdin_selector()
|
||||
events = selector.select(timeout=0)
|
||||
for key, _ in events:
|
||||
s: str = key.fileobj.readline().strip()
|
||||
_logger.info(f'Get stdin "{s}".')
|
||||
l[0] = s
|
||||
broadcast_object_list(l, src=0)
|
||||
_command = l[0]
|
||||
return _command
|
||||
|
||||
|
||||
def _make_infinite_epochs(dl):
|
||||
while True:
|
||||
_logger.info("New epoch starts.")
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
|
||||
while True:
|
||||
_logger.info("New epoch starts.")
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
|
||||
|
||||
|
||||
@local_leader_only(default=None)
|
||||
def logger(data):
|
||||
return _logger.info(json.dumps(data, default=str))
|
||||
return _logger.info(json.dumps(data, default=str))
|
||||
|
||||
|
||||
def seed(seed):
|
||||
# Set up random seeds, after fork()
|
||||
random.seed(seed + global_rank())
|
||||
np.random.seed(seed + global_rank())
|
||||
torch.manual_seed(seed + global_rank())
|
||||
# Set up random seeds, after fork()
|
||||
random.seed(seed + global_rank())
|
||||
np.random.seed(seed + global_rank())
|
||||
torch.manual_seed(seed + global_rank())
|
||||
|
||||
|
||||
def train(
|
||||
engines_loader: EnginesLoader,
|
||||
train_dl: DataLoader,
|
||||
train_feeder: TrainFeeder,
|
||||
eval_fn: EvalFn,
|
||||
logger: Logger = logger,
|
||||
train_dl: DataLoader,
|
||||
train_feeder: TrainFeeder = default_feeder,
|
||||
eval_fn: EvalFn = lambda x: ...,
|
||||
logger: Logger = logger,
|
||||
):
|
||||
engines = engines_loader()
|
||||
cfg = engines.cfg
|
||||
fix_unset_envs()
|
||||
|
||||
"""
|
||||
if is_local_leader():
|
||||
cfg.dump()
|
||||
_logger.info(cfg)
|
||||
"""
|
||||
engines = load_engines()
|
||||
|
||||
# Setup global engines
|
||||
global _engines
|
||||
_engines = engines
|
||||
"""
|
||||
if is_local_leader():
|
||||
cfg.dump()
|
||||
_logger.info(cfg)
|
||||
"""
|
||||
|
||||
events = []
|
||||
# Setup global engines
|
||||
global _engines
|
||||
_engines = engines
|
||||
|
||||
eval_fn = global_leader_only(eval_fn)
|
||||
events = []
|
||||
|
||||
# Pre-loop command
|
||||
command = _non_blocking_input()
|
||||
if command in ["eval", "eval_quit"]:
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
if command in ["quit", "eval_quit"]:
|
||||
return
|
||||
eval_fn = global_leader_only(eval_fn)
|
||||
|
||||
last_save_step = engines.global_step
|
||||
last_eval_step = 0
|
||||
# Pre-loop command
|
||||
command = _non_blocking_input()
|
||||
if command in ["eval", "eval_quit"]:
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
if command in ["quit", "eval_quit"]:
|
||||
return
|
||||
|
||||
# Training loop
|
||||
for batch in _make_infinite_epochs(train_dl):
|
||||
if engines.global_step >= cfg.trainer.iterations:
|
||||
break
|
||||
last_save_step = engines.global_step
|
||||
last_eval_step = 0
|
||||
|
||||
#batch = to_device(batch, torch.cuda.current_device())
|
||||
stats = engines.step(feeder=train_feeder, batch=batch)
|
||||
# Training loop
|
||||
for batch in _make_infinite_epochs(train_dl):
|
||||
if engines.global_step >= cfg.trainer.iterations:
|
||||
break
|
||||
|
||||
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
|
||||
stats['it'] = iteration
|
||||
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||
#batch = to_device(batch, torch.cuda.current_device())
|
||||
stats = engines.step(batch=batch, feeder=train_feeder)
|
||||
|
||||
stats['batch'] = {
|
||||
'size': stats['batch_size'],
|
||||
'id': batch['spkr_id'],
|
||||
'index': [ index for index in batch['index'] ],
|
||||
'text_len': [ text.shape[0] for text in batch['text'] ],
|
||||
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
|
||||
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
||||
}
|
||||
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
|
||||
stats['it'] = iteration
|
||||
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||
|
||||
del stats['batch_size']
|
||||
del stats['wall_time']
|
||||
del stats['global_step']
|
||||
stats['batch'] = {
|
||||
'size': stats['batch_size'],
|
||||
'id': batch['spkr_id'],
|
||||
'index': [ index for index in batch['index'] ],
|
||||
'text_len': [ text.shape[0] for text in batch['text'] ],
|
||||
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
|
||||
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
||||
}
|
||||
|
||||
elapsed_time = stats.get("elapsed_time", 0)
|
||||
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
||||
del stats['batch_size']
|
||||
del stats['wall_time']
|
||||
del stats['global_step']
|
||||
|
||||
command = _non_blocking_input()
|
||||
elapsed_time = stats.get("elapsed_time", 0)
|
||||
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
||||
|
||||
if "@" in command:
|
||||
what, when = command.split("@")
|
||||
try:
|
||||
events.append((what, int(when)))
|
||||
_logger.info(f"Event {command} registered.")
|
||||
except Exception as e:
|
||||
_logger.error(e)
|
||||
command = ""
|
||||
command = _non_blocking_input()
|
||||
|
||||
# Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
|
||||
events = [e for e in events if e[1] >= engines.global_step]
|
||||
commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
|
||||
if "@" in command:
|
||||
what, when = command.split("@")
|
||||
try:
|
||||
events.append((what, int(when)))
|
||||
_logger.info(f"Event {command} registered.")
|
||||
except Exception as e:
|
||||
_logger.error(e)
|
||||
command = ""
|
||||
|
||||
for command in commands:
|
||||
if command in ["event show", "event"]:
|
||||
msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
|
||||
_logger.info(msg)
|
||||
# Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
|
||||
events = [e for e in events if e[1] >= engines.global_step]
|
||||
commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
|
||||
|
||||
if command == "event clear":
|
||||
events.clear()
|
||||
for command in commands:
|
||||
if command in ["event show", "event"]:
|
||||
msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
|
||||
_logger.info(msg)
|
||||
|
||||
if "time" in command:
|
||||
target_iter = cfg.trainer.iterations
|
||||
if " to " in command:
|
||||
try:
|
||||
target_iter = int(command.split(" to ")[-1])
|
||||
except Exception as e:
|
||||
_logger.error(e)
|
||||
remaining_iters = target_iter - engines.global_step + 1
|
||||
remaining_time = int(remaining_iters * elapsed_time)
|
||||
_logger.info(humanize.precisedelta(remaining_time))
|
||||
if command == "event clear":
|
||||
events.clear()
|
||||
|
||||
if "lr" in command:
|
||||
rate = float(command.split(" ")[-1])
|
||||
engines.set_lr(rate)
|
||||
print("Updating LR to:", rate)
|
||||
if "time" in command:
|
||||
target_iter = cfg.trainer.iterations
|
||||
if " to " in command:
|
||||
try:
|
||||
target_iter = int(command.split(" to ")[-1])
|
||||
except Exception as e:
|
||||
_logger.error(e)
|
||||
remaining_iters = target_iter - engines.global_step + 1
|
||||
remaining_time = int(remaining_iters * elapsed_time)
|
||||
_logger.info(humanize.precisedelta(remaining_time))
|
||||
|
||||
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
|
||||
if "lr" in command:
|
||||
rate = float(command.split(" ")[-1])
|
||||
engines.set_lr(rate)
|
||||
print("Updating LR to:", rate)
|
||||
|
||||
saving_commands = ["save"]
|
||||
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
|
||||
|
||||
if cfg.trainer.save_on_quit:
|
||||
saving_commands.append("quit")
|
||||
saving_commands = ["save"]
|
||||
|
||||
if engines.global_step != last_save_step:
|
||||
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
if cfg.trainer.save_on_quit:
|
||||
saving_commands.append("quit")
|
||||
|
||||
if engines.global_step != last_eval_step:
|
||||
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
|
||||
do_gc()
|
||||
if engines.global_step != last_save_step:
|
||||
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
last_eval_step = engines.global_step
|
||||
if engines.global_step != last_eval_step:
|
||||
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
|
||||
do_gc()
|
||||
|
||||
if command in ["quit"]:
|
||||
return
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
last_eval_step = engines.global_step
|
||||
|
||||
if command in ["quit"]:
|
||||
return
|
Loading…
Reference in New Issue
Block a user