big cleanup

mrq 2023-08-03 20:26:36 -05:00
parent 2e03e5ac93
commit c85101403f
13 changed files with 1037 additions and 1031 deletions

View File

@@ -15,12 +15,12 @@ dataset:
  workers: 8
  cache: True
-  phones_range: [4, 256]
+  phones_range: [4, 512]
-  duration_range: [1.0, 12.0]
+  duration_range: [1.0, 24.0]
  random_utterance: 1.0
-  max_prompts: 3
+  max_prompts: 6
-  prompt_duration: 3.0
+  prompt_duration: 6.0
models:
  _models:
@@ -69,7 +69,8 @@ evaluation:
  size: 32
  steps: 300
-  temperature: 1.0
+  ar_temperature: 1.0
+  nar_temperature: 0.2
trainer:
  iterations: 100_000
@@ -91,7 +92,13 @@ trainer:
  weight_dtype: bfloat16
-  zero_optimization_level: 2
-  use_compression_training: True
+  backend: deepspeed
+  deepspeed:
+    zero_optimization_level: 0
+    use_compression_training: True
inference:
-  use_vocos: False
+  use_vocos: True
+bitsandbytes:
+  enabled: false
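For orientation, a minimal sketch of how the reorganized trainer block is read back on the Python side, assuming cfg is the parsed global Config object from vall_e/config.py (see the dataclass changes below); the ZeRO and compression switches now live under trainer.deepspeed rather than directly under trainer:

from vall_e.config import cfg  # assumption: importing the parsed global config

if cfg.trainer.backend == "deepspeed":
	# nested DeepSpeed-specific knobs, mirroring the YAML above
	print(cfg.trainer.deepspeed.zero_optimization_level)   # 0
	print(cfg.trainer.deepspeed.use_compression_training)  # True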

View File

@@ -232,7 +232,121 @@ class Evaluation:
	size: int = 64
	steps: int = 500
-	temperature: float = 1.0
+	ar_temperature: float = 1.0
+	nar_temperature: float = 0.2
@dataclass()
class DeepSpeed:
zero_optimization_level: int = 0
use_compression_training: bool = False
def get_ds_cfg(self, model):
weights = [ name[0] for name in model.named_parameters() ]
bits = 8
scheduler_params = {}
for k in cfg.hyperparameters.scheduler_params:
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
scheduler_params['total_num_steps'] = cfg.trainer.iterations
ds_cfg = {
"train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
"gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
"optimizer": {
"type": cfg.hyperparameters.optimizer,
"params": {
"lr": cfg.hyperparameters.learning_rate,
}
},
"scheduler": {
"type": cfg.hyperparameters.scheduler_type,
"params": scheduler_params,
} if cfg.hyperparameters.scheduler_type != "" else None,
"gradient_clipping": cfg.hyperparameters.gradient_clipping,
"fp16": {
"enabled": True,
"auto_cast": True,
} if cfg.trainer.weight_dtype.lower() == "float16" else None,
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
},
"compression_training": {
"weight_quantization": {
"shared_parameters":{
"enabled": True,
"quantizer_kernel": True,
"schedule_offset": 0,
"quantize_groups": 64,
"quantize_verbose": True,
"quantization_type": "symmetric",
"rounding": "nearest",
"quantize_weight_in_forward": True,
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
}
},
"different_groups": {
"wq1": {
"params": {
"start_bits": bits,
"target_bits": bits,
"quantization_period": 0
},
"modules": weights
}
}
},
"activation_quantization": {
"shared_parameters":{
"enabled": True,
"quantization_type": "symmetric",
"range_calibration": "dynamic",
"schedule_offset": 0
},
"different_groups": {
"aq1": {
"params": {
"bits": bits
},
"modules": weights
}
}
}
} if self.use_compression_training else None,
"zero_optimization": {
"stage": self.zero_optimization_level,
"contiguous_gradients": True,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8,
"sub_group_size": 5e8,
"round_robin_gradients": True,
"offload_optimizer": {
"device": "cpu",
"pin_memory": True
},
"offload_param": {
"device": "cpu",
"pin_memory": True
}
} if self.zero_optimization_level > 0 else None,
"comms_logger": {
"enabled": False
}
}
null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
for k in null_keys:
del ds_cfg[k]
if os.path.exists("./config/ds_config.json"):
ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
return ds_cfg
@dataclass()
class Trainer:
@@ -256,8 +370,9 @@ class Trainer:
	weight_dtype: str = "float16"
-	zero_optimization_level: int = 0
-	use_compression_training: bool = False
+	backend: str = "deepspeed"
+	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
@dataclass() @dataclass()
@ -284,7 +399,6 @@ class Config(_Config):
inference: Inference = field(default_factory=lambda: Inference) inference: Inference = field(default_factory=lambda: Inference)
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes) bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
@property @property
def sample_rate(self): def sample_rate(self):
return 24_000 return 24_000
@ -293,142 +407,6 @@ class Config(_Config):
def get_spkr(self): def get_spkr(self):
return eval(self.dataset.speaker_name_getter) return eval(self.dataset.speaker_name_getter)
@property
def scheduler(self):
cfg = {
"type": self.hyperparameters.scheduler_type,
"params": {},
}
for k in self.hyperparameters.scheduler_params:
cfg['params'][k] = self.hyperparameters.scheduler_params[k]
if self.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in cfg['params']:
cfg['params']['total_num_steps'] = self.trainer.iterations
return cfg
@property
def fp16_cfg(self):
if self.trainer.weight_dtype.lower() != "float16":
return None
return {
"enabled": True,
"auto_cast": True,
}
@property
def bf16_cfg(self):
return {
"enabled": self.trainer.weight_dtype.lower() == "bfloat16"
}
def get_compression_cfg(self, model):
if not self.trainer.use_compression_training:
return None
weights = [ name[0] for name in model.named_parameters() ]
bits = 8
return {
"weight_quantization": {
"shared_parameters":{
"enabled": True,
"quantizer_kernel": True,
"schedule_offset": 0,
"quantize_groups": 64,
"quantize_verbose": True,
"quantization_type": "symmetric",
"rounding": "nearest",
"quantize_weight_in_forward": True,
"fp16_mixed_quantize":{
"enabled": False,
"quantize_change_ratio": 1
}
},
"different_groups": {
"wq1": {
"params": {
"start_bits": bits,
"target_bits": bits,
"quantization_period": 0
},
"modules": weights
}
}
},
"activation_quantization": {
"shared_parameters":{
"enabled": True,
"quantization_type": "symmetric",
"range_calibration": "dynamic",
"schedule_offset": 0
},
"different_groups": {
"aq1": {
"params": {
"bits": bits
},
"modules": weights
}
}
}
}
@property
def zero_cfg(self):
if self.trainer.zero_optimization_level == 0:
return None
return {
"stage": self.trainer.zero_optimization_level,
"contiguous_gradients": True,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8,
"sub_group_size": 5e8,
"round_robin_gradients": True,
"offload_optimizer": {
"device": "cpu",
"pin_memory": True
},
"offload_param": {
"device": "cpu",
"pin_memory": True
}
}
def get_ds_cfg(self, model):
cfg = {
"train_micro_batch_size_per_gpu": self.hyperparameters.batch_size,
"gradient_accumulation_steps": self.hyperparameters.gradient_accumulation_steps,
"optimizer": {
"type": self.hyperparameters.optimizer,
"params": {
"lr": self.hyperparameters.learning_rate,
}
},
"scheduler": self.hyperparameters.scheduler if self.hyperparameters.scheduler_type != "" else None,
"gradient_clipping": self.hyperparameters.gradient_clipping,
"fp16": self.fp16_cfg,
"bf16": self.bf16_cfg,
"compression_training": self.get_compression_cfg(model),
"zero_optimization": self.zero_cfg,
"comms_logger": {
"enabled": False
}
}
null_keys = [ k for k in cfg if not cfg[k] ]
for k in null_keys:
del cfg[k]
if os.path.exists("./config/ds_config.json"):
ds_cfg = json.load(open("./config/ds_config.json", "r", encoding="utf-8"))
cfg.update(ds_cfg)
return cfg
@property @property
def cache_dir(self): def cache_dir(self):
return ".cache" / self.relpath return ".cache" / self.relpath
@ -455,6 +433,8 @@ cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference) cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes) cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
# cached_property stopped working... # cached_property stopped working...
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5:
try: try:
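A rough usage sketch for the new DeepSpeed.get_ds_cfg() above (model is any torch nn.Module; this is illustrative, not part of the commit): the method builds the DeepSpeed JSON from cfg.hyperparameters, drops every top-level section that evaluates falsy (for example "fp16" when training in bfloat16, "scheduler" when no scheduler type is set, "zero_optimization" when the stage is 0), and finally overlays ./config/ds_config.json if that file exists.

ds_cfg = cfg.trainer.deepspeed.get_ds_cfg(model=model)
# sections generated as None are pruned wholesale, so this holds unless ds_config.json adds one back:
# ("zero_optimization" in ds_cfg) == (cfg.trainer.deepspeed.zero_optimization_level > 0)
# note: the pruning is on truthiness, so a falsy scalar like gradient_clipping == 0 is dropped too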

View File

@ -0,0 +1,9 @@
from ..config import cfg
if cfg.trainer.backend == "deepspeed":
from .deepspeed import Engine
elif cfg.trainer.backend == "local":
from .base import Engine
from .base import Engines
from .base import TrainFeeder

vall_e/engines/base.py (new file, +374 lines)
View File

@ -0,0 +1,374 @@
from torch import Tensor
from typing import Any, Protocol
Stats = dict[str, float]
class TrainFeeder(Protocol):
def __call__(
self, *, engine: "Engine", batch: Any
) -> None | tuple[Tensor, Stats]:
...
def default_feed(engine, batch):
if isinstance(batch, list):
engine( *batch )
elif isinstance(batch, dict):
engine( **batch )
else:
engine( batch )
losses = engine.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
return loss, stats
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
import logging
import time
import torch
import torch.distributed
import os
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from .base import TrainFeeder
_logger = logging.getLogger(__name__)
# A very naive engine implementation using barebones PyTorch
class Engine():
def __init__(self, *args, **kwargs):
self.module = kwargs['model']
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
self.global_steps = 0
self.micro_steps = 0
self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
def freeze(self):
for p in self.module.parameters():
if p.requires_grad:
p.requires_grad_(False)
self._frozen_params.add(p)
def unfreeze(self):
for p in self._frozen_params:
p.requires_grad_(True)
self._frozen_params.clear()
@property
def global_step(self):
return self.global_steps
@property
def micro_step(self):
return self.micro_steps
def train_batch_size(self):
return cfg.hyperparameters.batch_size
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
def dispatch_attribute(self, *args, **kwargs):
return dispatch_attribute(self.module, *args, **kwargs)
def save_checkpoint(self, save_dir, tag ):
torch.save({
"global_step": self.global_step,
"micro_step": self.micro_step,
"module": self.module.state_dict(),
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
}, save_dir / tag / "state.pth")
def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
state = torch.load(load_dir / tag / "state.pth")
self.global_step = state['global_step']
self.micro_step = state['micro_step']
self.module.load_state_dict(state['module'])
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
if load_optimizer_states:
self.optimizer.load_state_dict(state['optimizer'])
if load_lr_scheduler_states:
self.lr_scheduler.load_state_dict(state['lr_scheduler'])
def eval(self):
return self.module.eval()
def train(self):
return self.module.train()
def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs)
return self.module
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(self, *args, **kwargs):
return self.module.forward(*args, **kwargs)
def backward(self, loss):
return (loss / self.gradient_accumulation_steps).backward()
def step(self):
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
self.micro_steps += 1
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
self.global_steps += 1
self.optimizer.step()
self.optimizer.zero_grad()
def get_lr(self):
lrs = []
for param_group in self.optimizer.param_groups:
if 'lr' in param_group:
lrs.append(param_group['lr'])
return lrs
def set_lr(self, lr):
for param_group in self.optimizer.param_groups:
if 'lr' in param_group:
param_group['lr'] = lr
def get_global_grad_norm(self):
return 0.0
def traverse(self, *args, **kwargs):
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
self.backward(loss)
self.step()
return stats
# and now to ignore everything from the above
class Engines(dict[str, Engine]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup()
def setup(self):
self._global_step = 0
self._micro_step = 0
@property
def global_step(self):
return self._global_step
@property
def micro_step(self):
return self._micro_step
def gather_attribute(self, *args, **kwargs):
ret = {}
for engine in self.values():
ret |= engine.gather_attribute(*args, **kwargs)
return ret
def dispatch_attribute(self, *args, **kwargs):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def save_checkpoint(self, tag=None):
if not tag:
tag = cfg.trainer.save_tag
tag = tag.lower()
if tag[:2] == "it" or tag[:4] == "step":
tag = f'{self.global_step}'
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items():
engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)
def load_checkpoint(self, tag=None):
if not tag:
tag = cfg.trainer.load_tag
for name, engine in self.items():
load_dir = cfg.ckpt_dir / name
engine.load_checkpoint(
tag=tag,
load_dir=load_dir,
load_module_strict=cfg.trainer.strict_loading,
load_optimizer_states=cfg.trainer.load_states,
load_lr_scheduler_states=cfg.trainer.load_states,
)
if cfg.trainer.restart_step_count:
engine.global_steps = 0
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
if cfg.hyperparameters.scheduler_type == "":
self.set_lr(cfg.hyperparameters.learning_rate)
self._update_global_step()
self._update_micro_step()
def set_lr(self, lr):
for engine in self.values():
engine.set_lr(lr)
def _update_global_step(self):
for engine in self.values():
self._global_step = max(self._global_step, engine.global_step)
def _update_micro_step(self):
for engine in self.values():
self._micro_step = max(self._micro_step, engine.micro_step)
def train_batch_size(self):
batch_size = 0
for engine in self.values():
batch_size = max(batch_size, engine.train_batch_size())
def eval(self):
for engine in self.values():
engine.eval()
def train(self):
for engine in self.values():
engine.train()
def traverse(self):
stats = {}
for name, engine in self.items():
stat = engine.traverse()
stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats
def step(self, batch, feeder: TrainFeeder = default_feed, device=torch.cuda.current_device()):
total_elapsed_time = 0
stats: Any = dict()
if cfg.trainer.gc_mode == 'step':
do_gc()
batch = to_device(batch, device)
for name, engine in self.items():
#torch.cuda.synchronize()
if cfg.trainer.gc_mode == 'substep':
do_gc()
start_time = time.time()
tries = 4
n_ooms = torch.zeros([], device=cfg.device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, device)
res = feeder( engine=engine, batch=batch )
"""
while tries >= 0:
try:
res = feeder( engine=engine, batch=batch )
break
except RuntimeError as e:
print("Forward", str(e))
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
# shrink batch size until it's happy
for k in batch:
batch[k] = batch[k][:-1]
if tries <= 0:
# trigger OOM
n_ooms += 1
else:
# also do GC
do_gc()
continue
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!")
"""
if res is None:
continue
loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar")
n_ooms = torch.zeros([], device=cfg.device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')
engine.backward(loss)
"""
try:
engine.backward(loss)
except RuntimeError as e:
print("Backwards:", str(e))
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
n_ooms += 1
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!")
"""
engine.step()
#torch.cuda.synchronize()
elapsed_time = time.time() - start_time
total_elapsed_time += elapsed_time
stats.update(
flatten_dict(
{
name.split("-")[0]: dict(
loss=loss.item(),
lr=engine.get_lr()[0],
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
elapsed_time=elapsed_time,
engine_step=engine.global_step,
**engine_stats,
)
}
),
)
self._update_global_step()
self._update_micro_step()
stats["batch_size"] = self.train_batch_size() # len(batch["text"])
stats["elapsed_time"] = total_elapsed_time
stats["wall_time"] = time.time()
stats["global_step"] = self.global_step
return stats
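A condensed sketch of driving training through Engines.step() with the default feeder, mirroring example_usage() in models/base.py below; ar_model, text_list, proms_list and resps_list are placeholders, and the dict keys must match the wrapped model's forward() keyword arguments:

engines = Engines({
	"ar": Engine(model=ar_model, optimizer=torch.optim.AdamW(ar_model.parameters(), lr=1e-4)),
})
engines.train()
for i in range(60):
	stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")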

View File

@ -0,0 +1,89 @@
"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
# to-do: replace this
# to-do: swap out deepspeed
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
import logging
import time
import torch
import torch.distributed
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from .base import TrainFeeder
_logger = logging.getLogger(__name__)
from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed
from deepspeed.accelerator import get_accelerator
#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
initialized_dist = False
if not initialized_dist:
initialized_dist = True
init_distributed()
class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs):
kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
super().__init__(None, *args, **kwargs)
self._frozen_params = set()
def freeze(self):
for p in self.module.parameters():
if p.requires_grad:
p.requires_grad_(False)
self._frozen_params.add(p)
def unfreeze(self):
for p in self._frozen_params:
p.requires_grad_(True)
self._frozen_params.clear()
@property
def global_step(self):
return self.global_steps
@property
def micro_step(self):
return self.micro_steps
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
def dispatch_attribute(self, *args, **kwargs):
return dispatch_attribute(self.module, *args, **kwargs)
def set_lr(self, lr):
try:
if hasattr(self.optimizer, 'param_groups'):
print(self.optimizer.param_groups)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
else:
self.optimizer.set_lr(lr)
except Exception as e:
print(str(e))
def traverse(self, *args, **kwargs):
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
self.backward(loss)
self.step()
return stats
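Because this Engine builds its own DeepSpeed config from cfg.trainer.deepspeed.get_ds_cfg(), callers only hand it the model plus optional optimizer and scheduler, roughly as the old load_engines() in train.py did; a hedged sketch:

engine = Engine(
	model=model,           # required; the DeepSpeed config is generated from it
	optimizer=optimizer,   # may be None; DeepSpeed can build one from the generated "optimizer" section
	lr_scheduler=lr_scheduler,  # likewise optional
)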

View File

@@ -72,8 +72,8 @@ class TTS():
		prom = to_device(prom, self.device).to(torch.int16)
		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
-		resp_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
-		resps_list = [r.unsqueeze(-1) for r in resp_list]
+		resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
+		resps_list = [r.unsqueeze(-1) for r in resps_list]
		resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
		wav, sr = qnt.decode_to_file(resps_list[0], out_path)
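The rename above also makes the shape convention explicit: the AR emits the first quantization level only ([t] per sample), which is unsqueezed to [t, 1] so the NAR can fill in the remaining levels ([t, l]); schematically (ar and nar stand in for self.ar and self.nar):

first = ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)  # [t]
stacked = [r.unsqueeze(-1) for r in first]                                                             # [t, 1]
codes = nar(text_list=[phns], proms_list=[prom], resps_list=stacked, sampling_temperature=nar_temp)   # [t, l]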

View File

@@ -1,20 +1,22 @@
from .ar import AR
from .nar import NAR
-def get_model(model):
-	if model.name == "ar":
+def get_model(cfg):
+	if cfg.name == "ar":
		Model = AR
-	elif model.name == "nar":
+	elif cfg.name == "nar":
		Model = NAR
	else:
-		raise f"invalid model name: {model.name}"
+		raise f"invalid model name: {cfg.name}"
-	name = model.name
+	name = cfg.name
model = Model( model = Model(
n_tokens=model.tokens, n_tokens=model.tokens,
d_model=model.dim, d_model=model.dim,
n_heads=model.heads, n_heads=model.heads,
n_layers=model.layers, n_layers=model.layers,
) )
model._cfg = cfg
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

View File

@ -50,112 +50,80 @@ class AR(Base):
self, self,
text_list: list[Tensor], text_list: list[Tensor],
proms_list: list[Tensor], proms_list: list[Tensor],
-		resp_list: list[Tensor] | None = None,
+		resps_list: list[Tensor] | None = None,
max_steps: int = 1000, max_steps: int = 1000,
sampling_temperature: float = 1.0, sampling_temperature: float = 1.0,
naive: bool = True, naive: bool = True,
): ):
-		if resp_list is not None:
+		if resps_list is not None:
+			resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
return super().forward( return super().forward(
-				text_list,
-				proms_list,
-				self._unsqueeze_list(resp_list),
-				resp_list,
+				text_list=text_list,
+				proms_list=proms_list,
+				resps_list=self._unsqueeze_list(resps_list),
+				targ_list=resps_list,
quant_levels=None, quant_levels=None,
shift_targ_list=True, shift_targ_list=True,
return_all_resp=False, return_all_resp=False,
) )
device = text_list[0].device device = text_list[0].device
-		resp_list: list[Tensor] = [
+		resps_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list torch.zeros(0, device=device).to(torch.int16) for _ in text_list
] ]
stopped = torch.zeros(len(text_list), device=device).bool() stopped = torch.zeros(len(text_list), device=device).bool()
if self.arch_type == "transformer":
naive = True
chunk_size = 1 # don't really know what to do about this desu chunk_size = 1 # don't really know what to do about this desu
state = None state = None
start = 0 start = 0
# prefill
if self.arch_type == "retnet/local":
# pre-process
state = [
[
torch.zeros(self.retnet.hidden_dim // self.retnet.heads, self.retnet.v_dim // self.retnet.heads, device=device).unsqueeze(0).repeat(len(text_list), 1, 1)
for _ in range(self.retnet.heads)
] for _ in range(self.retnet.layers)
]
resps_list = self._unsqueeze_list(resp_list)
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(resps_list),
sep=self.sep,
)
x, m = list_to_tensor(x_list)
start = x.shape[1]
for i in trange(start-1):
_, state = self.retnet.forward_recurrent( x[:, i:i+1, :], state, i+1 )
for n in trange(max_steps // chunk_size): for n in trange(max_steps // chunk_size):
# get next in sequence # get next in sequence
r, state = super().forward( r, state = super().forward(
text_list, text_list,
proms_list, proms_list,
-				self._unsqueeze_list(resp_list),
+				self._unsqueeze_list(resps_list),
sampling_temperature=sampling_temperature, sampling_temperature=sampling_temperature,
-				state=state,
+				state=state if not naive else None,
) )
# append outputted token # append outputted token
for i, ri in enumerate(r): for i, ri in enumerate(r):
-				resp_list[i] = torch.cat([resp_list[i], ri[None]])
+				resps_list[i] = torch.cat([resps_list[i], ri[None]])
# stop token found # stop token found
stopped |= r == self.stop_token stopped |= r == self.stop_token
if stopped.all().item(): if stopped.all().item():
break break
-		pruned = [self._prune(r) for r in resp_list]
+		pruned = [self._prune(r) for r in resps_list]
return pruned return pruned
def example_usage(): def example_usage():
cfg.trainer.backend = "local"
from functools import partial from functools import partial
from einops import repeat from einops import repeat
from ..emb.qnt import decode_to_file from ..emb.qnt import decode_to_file
from ..utils import gather_attribute from ..engines import Engine
from tqdm import tqdm
device = "cpu" device = "cpu"
x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"): def tokenize(content, lang_marker="en"):
split = content.split(" ") split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"] phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to() return torch.tensor([*map(symmap.get, phones)]).to()
qnt = torch.load("data/qnt.pt")[0, 0].to(device) qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
kwargs = {
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 12,
}
model = AR(**kwargs).to(device)
x8 = partial(repeat, pattern="t -> t l", l=2)
text_list = [ text_list = [
#torch.tensor([1, 2, 3], device=device), #torch.tensor([1, 2, 3], device=device),
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
@ -164,41 +132,41 @@ def example_usage():
x8(torch.tensor([1, 2, 3], device=device)), x8(torch.tensor([1, 2, 3], device=device)),
#qnt.to(device), #qnt.to(device),
] ]
resp_list = [ resps_list = [
qnt.to(device), qnt.to(device),
] ]
text_list = text_list[:1] text_list = text_list[:1]
proms_list = proms_list[:1] proms_list = proms_list[:1]
resp_list = resp_list[:1] resps_list = resps_list[:1]
model.eval() kwargs = {
out = model(text_list, proms_list, max_steps=75)[0] 'n_tokens': 1024,
print("qnt:", qnt.shape, qnt) 'd_model': 1024,
print("out:", out.shape, out) 'n_heads': 16,
wav, sr = decode_to_file(out, "data/test/test.ar.init.wav", device=device) 'n_layers': 12,
}
model = AR(**kwargs).to(device)
engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
def sample( name, steps=400 ):
engine.eval()
out = engine(text_list, proms_list, max_steps=steps)
for i, o in enumerate(out):
wav, sr = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) def train():
engine.train()
t = trange(60)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
model.train() t.set_description(f"{stats}")
for i in trange(60):
optimizer.zero_grad()
_ = model(text_list, proms_list, resp_list)
losses = gather_attribute(model, "loss")
loss = sum(losses.values())
loss.backward()
optimizer.step()
if i % 20 == 0:
print(f"iter={i}, {losses}.")
model.eval()
out = model(text_list, proms_list, max_steps=400)
print("qnt:", qnt.shape, qnt)
for i, o in enumerate(out):
print("out:", i, o.shape, o)
wav, sr = decode_to_file(o, f"data/test/test.ar.{i}.wav", device=device)
sample("init", 75)
train()
sample("final")
if __name__ == "__main__": if __name__ == "__main__":
example_usage() example_usage()

View File

@ -176,17 +176,7 @@ class Base(nn.Module):
self.retnet = RetNetDecoder( self.retnet = RetNetDecoder(
self.retnet_config self.retnet_config
) )
elif self.arch_type == "retnet/local":
self.retnet = RetNet(
layers=n_layers,
hidden_dim=d_model,
ffn_size=d_model * 4,
heads=n_heads,
dropout=p_dropout,
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
double_v_dim=True
)
self.classifier = nn.Linear(d_model, n_resp_tokens) self.classifier = nn.Linear(d_model, n_resp_tokens)
self.accuracy_metric = MulticlassAccuracy( self.accuracy_metric = MulticlassAccuracy(
@ -272,7 +262,7 @@ class Base(nn.Module):
text_list: [t] * b text_list: [t] * b
proms_list: [t' l] * b, l quantization levels. proms_list: [t' l] * b, l quantization levels.
resps_list: [t'' l] * b, l quantization levels. resps_list: [t'' l] * b, l quantization levels.
-			targ_list: [t''] * b, one quantization level only, when given, loss will be computed
+			targ_list: [t''] * b, one quantization level only; when given, loss will be computed
quant_levels: specify which quant_levels to feed forward, used in NAR mode. quant_levels: specify which quant_levels to feed forward, used in NAR mode.
shift_targ_list: whether to shift target list when computing loss. True if AR. shift_targ_list: whether to shift target list when computing loss. True if AR.
return_all_resp: True if NAR. return_all_resp: True if NAR.
@ -298,24 +288,12 @@ class Base(nn.Module):
elif self.arch_type == "retnet": elif self.arch_type == "retnet":
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True) x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
state = self.retnet.get_incremental_state( state, 'prev_state' ) state = self.retnet.get_incremental_state( state, 'prev_state' )
elif self.arch_type == "retnet/local":
# recurrent inferencing
if self.causal and state is not None:
last = x.shape[1]
x, state = self.retnet.forward_recurrent(
x[:, last-1:last, :], # nasty way to grab the last embedding to forward
state,
last
)
else:
x = self.retnet( x, quant_levels )
x = self.classifier(x) * m x = self.classifier(x) * m
# Remove padding # Remove padding
h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))] h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]
# compute loss if the target is given # compute loss if the target is given
if targ_list is not None: if targ_list is not None:
if any([l == 0 for l in map(len, targ_list)]): if any([l == 0 for l in map(len, targ_list)]):
@ -337,20 +315,24 @@ class Base(nn.Module):
# the NAR doesn't need to compute the loss for it # the NAR doesn't need to compute the loss for it
if self.resp_loss_only: if self.resp_loss_only:
text_prom_list[i][:] = self.ignore_index text_prom_list[i][:] = self.ignore_index
# roll the text/prompt for loss computing # roll the text/prompt for loss computing
-				# the AR benefits from this
+				# the AR benefits from this, for some reason I'll figure out later
else: else:
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0) text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
text_prom_list[i][-1] = self.ignore_index text_prom_list[i][-1] = self.ignore_index
-			# necessary to roll the target if recurrently/causally/autoregressively generating, or it won't be able to work
+			# for the AR, roll by one and mark the ending with a stop token
+			# this coerces the model into properly inferencing causally
+			# why we don't just append a stop token in the dataloader, who knows
if shift_targ_list: if shift_targ_list:
targ_list = [*targ_list] targ_list = [*targ_list]
for i in range(len(targ_list)): for i in range(len(targ_list)):
targ_list[i] = targ_list[i].roll(-1, dims=0) targ_list[i] = targ_list[i].roll(-1, dims=0)
targ_list[i][-1] = self.stop_token targ_list[i][-1] = self.stop_token
-			# generate the sequence
+			# create the new target sequence to compute the loss against
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
self.loss = dict( self.loss = dict(
@ -367,6 +349,7 @@ class Base(nn.Module):
del prom_list del prom_list
del text_prom_list del text_prom_list
del y_list del y_list
# return the entire generated token string # return the entire generated token string
if return_all: if return_all:
@ -387,124 +370,90 @@ class Base(nn.Module):
return ret, state return ret, state
def example_usage(): def example_usage():
from ..config import cfg
cfg.trainer.backend = "local"
from functools import partial from functools import partial
from einops import repeat from einops import repeat
from tqdm import trange
from ..utils import gather_attribute
from ..emb.qnt import decode_to_file from ..emb.qnt import decode_to_file
from ..engines import Engine, Engines
from tqdm import tqdm, trange
from .ar import AR from .ar import AR
from .nar import NAR from .nar import NAR
device = "cpu"
x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"): def tokenize(content, lang_marker="en"):
split = content.split(" ") split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"] phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to() return torch.tensor([*map(symmap.get, phones)]).to()
device = "cpu"
kwargs = { kwargs = {
'n_tokens': 1024, 'n_tokens': 1024,
'd_model': 1024, 'd_model': 1024,
'n_heads': 16, 'n_heads': 16,
'n_layers': 12, 'n_layers': 12,
} }
model_ar = AR(**kwargs).to(device) models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
model_nar = NAR(**kwargs).to(device) engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
train = True train = True
if train:
qnt = torch.load("data/qnt.pt").to(device)
text_list = [
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
#tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
]
x8 = partial(repeat, pattern="t -> t l", l=2) qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
proms_list = [ text_list = [
qnt[0][:2,:].t().to(device), tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
#x8(torch.tensor([1, 2, 3], device=device)), #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
# x8(torch.tensor([2, 3], device=device)), ]
]
resp_list_ar = [ proms_list = [
qnt[0,0].to(device), qnt.to(device),
# qnt[0,0].to(device), ]
] resps_list = [
qnt.to(device),
resp_list_nar = [ ]
qnt[0][:2,:].t().to(device),
# qnt[0][:2,:].t().to(device),
]
model_ar.train()
optimizer = torch.optim.AdamW(model_ar.parameters(), lr=1e-4)
for i in trange(60):
optimizer.zero_grad()
_ = model_ar(text_list, proms_list, resp_list_ar)
losses = gather_attribute(model_ar, "loss")
loss = sum(losses.values())
loss.backward()
optimizer.step()
if i % 20 == 0:
print(f"iter={i}, {losses}.")
model_nar.train()
optimizer = torch.optim.AdamW(model_nar.parameters(), lr=1e-4)
for i in trange(60):
optimizer.zero_grad()
_ = model_nar(text_list, proms_list, resps_list=resp_list_nar)
losses = gather_attribute(model_nar, "loss")
loss = sum(losses.values())
loss.backward()
optimizer.step()
if i % 20 == 0:
stats = {k: v.item() for k, v in losses.items()}
stats["loss"] = loss.item()
print(f"iter={i}, {stats}.")
else:
qnt = torch.load("data/test/test.qnt.pt")[0][:2,:].t().to(device)
text_list = [
#tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
]
proms_list = [
qnt.to(device),
]
model_ar.load_state_dict(torch.load("data/test/ar.pth"))
model_nar.load_state_dict(torch.load("data/test/nar.pth"))
model_ar.eval()
resp_list = model_ar(text_list, proms_list, max_steps=300, sampling_temperature=1.0)
resps_list = [r.unsqueeze(-1) for r in resp_list]
print("qnt:", qnt.shape, qnt) def sample( name, steps=400 ):
print("out:", resp_list[0].shape, resp_list[0]) AR = None
wav, sr = decode_to_file(resp_list[0], "data/test/test.ar.init.wav", device=device) NAR = None
print(wav, sr)
model_nar.eval() engines.eval()
codes = model_nar( for name, engine in engines.items():
text_list, if name[:2] == "ar":
proms_list, AR = engine
resps_list=resps_list, elif name[:3] == "nar":
sampling_temperature=1.0, NAR = engine
)[0]
resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
resps_list = [r.unsqueeze(-1) for r in resps_list]
codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
print("qnt:", qnt.shape, qnt) decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
print("codes:", codes.shape, codes) decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
if train:
sample("init", 15)
wav, sr = decode_to_file(codes, "data/test/test.ar+nar.init.wav", device=device) engines.train()
print(wav, sr) t = trange(60)
for i in t:
"""
stats = {"step": i}
for name, engine in engines.items():
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
"""
stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
t.set_description(f"{stats}")
else:
for name, engine in engines.items():
engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
sample("final")
if __name__ == "__main__": if __name__ == "__main__":
example_usage() example_usage()

View File

@ -80,7 +80,6 @@ class NAR(Base):
quant_levels=quant_levels, quant_levels=quant_levels,
) )
# Yes, just nothing as we are training
prev_list = [] prev_list = []
else: else:
prev_list = resps_list prev_list = resps_list
@ -93,7 +92,7 @@ class NAR(Base):
quant_levels = torch.full((len(text_list),), level, device=device) quant_levels = torch.full((len(text_list),), level, device=device)
-			resp_list, _ = super().forward(
+			resps_list, _ = super().forward(
text_list, text_list,
proms_list, proms_list,
prev_list, prev_list,
@ -105,24 +104,22 @@ class NAR(Base):
prev_list = [ prev_list = [
torch.cat([rs, r.unsqueeze(-1)], dim=-1) torch.cat([rs, r.unsqueeze(-1)], dim=-1)
-				for rs, r in zip(prev_list, resp_list)
+				for rs, r in zip(prev_list, resps_list)
] ]
return prev_list return prev_list
def example_usage(): def example_usage():
cfg.trainer.backend = "local"
from functools import partial from functools import partial
from pathlib import Path
from einops import repeat from einops import repeat
from ..emb.qnt import decode_to_file from ..emb.qnt import decode_to_file
from ..utils import gather_attribute from ..engines import Engine
from ..config import cfg from tqdm import tqdm
device = "cpu" device = "cpu"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels) x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"): def tokenize(content, lang_marker="en"):
@ -130,15 +127,8 @@ def example_usage():
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"] phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to() return torch.tensor([*map(symmap.get, phones)]).to()
resps = torch.load("data/qnt.pt")[0][:2, :].to(device) # to-do: unmangle this and the resp shit
kwargs = { qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 12,
}
model = NAR(**kwargs).to(device)
text_list = [ text_list = [
#torch.tensor([1, 2, 3], device=device), #torch.tensor([1, 2, 3], device=device),
@ -149,65 +139,36 @@ def example_usage():
x8(torch.tensor([2, 3], device=device)), x8(torch.tensor([2, 3], device=device)),
] ]
resps_x1_list = [ resps_list = [
resps[:1].t().to(device), qnt.to(device),
] ]
resps_x8_list = [ kwargs = {
resps.t().to(device), 'n_tokens': 1024,
] 'd_model': 1024,
'n_heads': 16,
'n_layers': 12,
}
model = NAR(**kwargs).to(device)
engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
model.eval() def sample( name ):
codes = model( engine.eval()
text_list, codes = engine( text_list, proms_list, resps_list=[r[..., 0].unsqueeze(-1) for r in resps_list], sampling_temperature=0.2 )
proms_list, decode_to_file( codes[0], f"data/nar.{name}.wav", device )
resps_list=resps_x1_list,
sampling_temperature=0.2,
)[0]
decode_to_file( def train():
codes, engine.train()
Path("data/test/test.nar.init.wav"), t = trange(60)
device for i in t:
) stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) t.set_description(f"{stats}")
model.train() sample("init")
for i in trange(50): train()
optimizer.zero_grad() sample("final")
_ = model(text_list, proms_list, resps_list=resps_x8_list)
losses = gather_attribute(model, "loss")
loss = sum(losses.values())
loss.backward()
optimizer.step()
if i % 20 == 0:
stats = {k: v.item() for k, v in losses.items()}
stats["loss"] = loss.item()
print(f"iter={i}, {stats}.")
model.eval()
for i in trange(1, 2): # cfg.models.prom_levels):
resps_list = [
resps[:i].t().to(device),
]
codes = model(
text_list,
proms_list,
resps_list=resps_list,
sampling_temperature=0.2,
)[0]
decode_to_file(
codes,
Path(f"data/test/test.nar.1-{i}.wav"),
device
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,17 +1,13 @@
# todo: clean this mess up # todo: clean this mess up
# todo: yank deepspeed dependent code out into its own thing
from .config import cfg from .config import cfg
from .data import create_train_val_dataloader from .data import create_train_val_dataloader
from .emb import qnt from .emb import qnt
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils import wrapper as ml
from .models import get_models
import auraloss import auraloss
import deepspeed
import json import json
import logging import logging
import random import random
@ -21,10 +17,6 @@ import traceback
from collections import defaultdict from collections import defaultdict
from deepspeed import comm as dist
from deepspeed import DeepSpeedConfig
from deepspeed.accelerator import get_accelerator
from tqdm import tqdm from tqdm import tqdm
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda") mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
@ -38,210 +30,120 @@ def left_crop(x, len):
return x[..., 0:len] return x[..., 0:len]
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
deepspeed._initialized_dist = False
def load_engines(): def train_feeder(engine, batch):
if not deepspeed._initialized_dist: engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
deepspeed._initialized_dist = True
deepspeed.init_distributed()
	models = get_models(cfg.models.get())
	engines = dict()

	for name in models:
		model = models[name]

		optimizer = None
		lr_scheduler = None

		Adam = ml.Adam
		AdamW = ml.AdamW

		if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
			optimizer = AdamW(
				model.parameters(),
				lr=cfg.hyperparameters.learning_rate,
				betas=(0.9, 0.96),
				eps=1e-07,
				weight_decay=0.01,
			)

		if cfg.trainer.load_state_dict:
			load_dir = cfg.ckpt_dir / name / "fp32.pth"
			model.load_state_dict(torch.load(load_dir))

		ds_cfg = cfg.get_ds_cfg(model=model)
		config_class = DeepSpeedConfig(ds_cfg)

		engines[name] = trainer.Engine(
			model=model,
			config=ds_cfg,
			config_class=config_class,
			optimizer=optimizer,
			lr_scheduler=lr_scheduler,
		)

	return trainer.load_engines(engines, cfg)

	losses = engine.gather_attribute("loss")

	loss = torch.stack([*losses.values()]).sum()

	stats = {}
	stats |= {k: v.item() for k, v in losses.items()}

	return loss, stats

@torch.inference_mode()
def run_eval(engines, eval_name, dl):
	engines_stats = {
		'eval': eval_name
	}

	AR = None
	NAR = None

	names = []
	for name, engine in engines.items():
		names.append(name)
		if name[:2] == "ar":
			AR = engine
		elif name[:3] == "nar":
			NAR = engine

	stats = defaultdict(list)
	stats['loss'] = []

	def process( name, batch, resps_list ):
		for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
			if len(hyp) == 0:
				continue

			filename = f'{speaker}_{path.parts[-1]}'

			# to-do, refine the output dir to be sane-er
			ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
			hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
			prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")

			hyp_path.parent.mkdir(parents=True, exist_ok=True)
			ref_path.parent.mkdir(parents=True, exist_ok=True)
			prom_path.parent.mkdir(parents=True, exist_ok=True)

			ref_audio, sr = qnt.decode_to_file(ref, ref_path)
			hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
			prom_audio, sr = qnt.decode_to_file(prom, prom_path)

			# pseudo loss calculation since we don't get the logits during eval
			min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
			ref_audio = ref_audio[..., 0:min_length]
			hyp_audio = hyp_audio[..., 0:min_length]
			try:
				stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
			except Exception as e:
				print(str(e))

	for batch in tqdm(dl):
		batch: dict = to_device(batch, cfg.device)

		# if we're training both models, provide output for both
		if AR is not None and NAR is not None:
			name = "+".join(names)

			resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
			resps_list = [ r.unsqueeze(-1) for r in resps_list ]
			resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)

			process( name, batch, resps_list )
		else:
			for name in engines:
				model = engines[name]

				if name.startswith("ar"):
					resps_list = model(
						text_list=batch["text"],
						proms_list=batch["proms"],
						max_steps=cfg.evaluation.steps,
						sampling_temperature=cfg.evaluation.ar_temperature,
					)
					resps_list = [r.unsqueeze(-1) for r in resps_list]
				elif name.startswith("nar"):
					resps_list = model(
						text_list=batch["text"],
						proms_list=batch["proms"],
						resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
						sampling_temperature=cfg.evaluation.nar_temperature,
					)
				else:
					raise NotImplementedError(name)

				process( name, batch, resps_list )

	stats = {k: sum(v) / len(v) for k, v in stats.items()}
	engines_stats.update(flatten_dict({ name: stats }))

	iteration = engines.global_step
	engines_stats['it'] = iteration
	engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)

	_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
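The eval path above never sees logits: the reference codes and the generated responses are both decoded back to audio, trimmed to a common length, and compared with a mel-STFT distance that stands in for a loss. A minimal sketch of that comparison, assuming the auraloss package supplies the mel-STFT loss and a 24kHz (EnCodec) sample rate; the repo's actual mel_stft_loss helper is defined outside this diff:

# Hedged sketch: reproduces the pseudo-loss above outside the trainer.
# Assumes auraloss provides MelSTFTLoss and a 24kHz sample rate; the repo's
# own mel_stft_loss helper and sample rate are defined elsewhere, not here.
import torch
from auraloss.freq import MelSTFTLoss

def eval_audio_distance(hyp_audio: torch.Tensor, ref_audio: torch.Tensor, sr: int = 24_000) -> float:
	# trim both signals to the shorter clip so the STFT frames line up
	min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
	ref_audio = ref_audio[..., :min_length]
	hyp_audio = hyp_audio[..., :min_length]

	mel_stft_loss = MelSTFTLoss(sr)
	# auraloss losses expect batched (batch, channels, samples) tensors
	return mel_stft_loss(hyp_audio.unsqueeze(0), ref_audio.unsqueeze(0)).item()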
def main():
	setup_logging(cfg.log_dir)

	#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
	if not deepspeed._initialized_dist:
		deepspeed._initialized_dist = True
		deepspeed.init_distributed()

	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
def train_feeder(engines, batch, name):
stats = {}
model = engines[name]
if name.startswith("ar"):
_ = model(
text_list=batch["text"],
proms_list=batch["proms"],
resp_list=[r[..., 0] for r in batch["resps"]],
)
elif name.startswith("nar"):
_ = model(
text_list=batch["text"],
proms_list=batch["proms"],
resps_list=batch["resps"],
)
else:
raise NotImplementedError(name)
losses = model.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= engines.gather_attribute("scalar")
return loss, stats
@torch.inference_mode()
def run_eval(engines, eval_name, dl):
engines_stats = {
'eval': eval_name
}
AR = None
NAR = None
names = []
for name in engines:
model = engines[name]
names.append(name)
if name[:2] == "ar":
AR = model
elif name[:3] == "nar":
NAR = model
stats = defaultdict(list)
stats['loss'] = []
for batch in tqdm(dl):
batch: dict = to_device(batch, cfg.device)
# if we're training both models, provide output for both
if AR is not None and NAR is not None:
name = "+".join(names)
resp_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature)
resps_list = [ r.unsqueeze(-1) for r in resp_list ]
resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature)
for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
if len(hyp) == 0:
continue
filename = f'{speaker}_{path.parts[-1]}'
# to-do, refine the output dir to be sane-er
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
hyp_path.parent.mkdir(parents=True, exist_ok=True)
ref_path.parent.mkdir(parents=True, exist_ok=True)
prom_path.parent.mkdir(parents=True, exist_ok=True)
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
else:
for name in engines:
model = engines[name]
if name.startswith("ar"):
resp_list = model(
text_list=batch["text"],
proms_list=batch["proms"],
max_steps=cfg.evaluation.steps,
sampling_temperature=cfg.evaluation.temperature,
)
resps_list = [r.unsqueeze(-1) for r in resp_list]
elif name.startswith("nar"):
resps_list = model(
text_list=batch["text"],
proms_list=batch["proms"],
resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
sampling_temperature=cfg.evaluation.temperature,
)
else:
raise NotImplementedError(name)
losses = model.gather_attribute("loss")
batch_stats = {}
batch_stats |= {k: v.item() for k, v in losses.items()}
batch_stats |= engines.gather_attribute("scalar")
for k, v in batch_stats.items():
stats[k].append(v)
for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
if len(hyp) == 0:
continue
filename = f'{speaker}_{path.parts[-1]}'
# to-do, refine the output dir to be sane-er
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
hyp_path.parent.mkdir(parents=True, exist_ok=True)
ref_path.parent.mkdir(parents=True, exist_ok=True)
prom_path.parent.mkdir(parents=True, exist_ok=True)
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
# pseudo loss calculation since we don't get the logits during eval
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]
stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
stats = {k: sum(v) / len(v) for k, v in stats.items()}
engines_stats.update(flatten_dict({ name: stats }))
iteration = engines.global_step
engines_stats['it'] = iteration
engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def eval_fn(engines):
	try:
		run_eval(engines, "subtrain", subtrain_dl)

@@ -256,12 +158,10 @@ def main():
	qnt.unload_model()

	trainer.train(
		engines_loader=load_engines,
		train_dl=train_dl,
		train_feeder=train_feeder,
		eval_fn=eval_fn,
	)

if __name__ == "__main__":
	main()
@@ -1,252 +0,0 @@
"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
# to-do: replace this
# to-do: swap out deepspeed
from ..config import Config
from .distributed import fix_unset_envs
from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
import logging
import time
import torch
import torch.distributed
from deepspeed import DeepSpeedEngine
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
Stats = dict[str, float]
_logger = logging.getLogger(__name__)
class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs):
fix_unset_envs()
super().__init__(None, *args, **kwargs)
self._frozen_params = set()
def freeze(self):
for p in self.module.parameters():
if p.requires_grad:
p.requires_grad_(False)
self._frozen_params.add(p)
def unfreeze(self):
for p in self._frozen_params:
p.requires_grad_(True)
self._frozen_params.clear()
@property
def global_step(self):
return self.global_steps
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
def dispatch_attribute(self, *args, **kwargs):
return dispatch_attribute(self.module, *args, **kwargs)
class TrainFeeder(Protocol):
def __call__(
self, *, engines: "Engines", batch: Any, name: str
) -> None | tuple[Tensor, Stats]:
...
class Engines(dict[str, Engine]):
def setup(self, cfg: Config):
self._cfg = cfg
self._global_step = 0
@property
def cfg(self) -> Config:
return self._cfg
@property
def config(self):
return self._cfg
@property
def global_step(self):
return self._global_step
def gather_attribute(self, *args, **kwargs):
ret = {}
for engine in self.values():
ret |= engine.gather_attribute(*args, **kwargs)
return ret
def dispatch_attribute(self, *args, **kwargs):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def save_checkpoint(self, tag=None):
if not tag:
tag = self.cfg.trainer.save_tag
tag = tag.lower()
if tag[:2] == "it" or tag[:4] == "step":
tag = self.global_step
self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items():
engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
def load_checkpoint(self, tag=None):
if not tag:
tag = self.cfg.trainer.load_tag
for name, engine in self.items():
load_dir = self.cfg.ckpt_dir / name
engine.load_checkpoint(
tag=tag,
load_dir=load_dir,
load_module_strict=self.cfg.trainer.strict_loading,
load_optimizer_states=self.cfg.trainer.load_states,
load_lr_scheduler_states=self.cfg.trainer.load_states,
load_module_only=False, # not self.cfg.trainer.load_states,
)
if self.cfg.trainer.restart_step_count:
engine.global_steps = 0
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
if self.cfg.hyperparameters.scheduler_type == "":
self.set_lr(self.cfg.hyperparameters.learning_rate)
self._update_global_step()
def set_lr(self, lr):
try:
for engine in self.values():
if hasattr(engine.optimizer, 'param_groups'):
print(engine.optimizer.param_groups)
for param_group in engine.optimizer.param_groups:
param_group['lr'] = lr
else:
engine.optimizer.set_lr(lr)
except Exception as e:
print(str(e))
def _update_global_step(self):
for engine in self.values():
self._global_step = max(self._global_step, engine.global_step)
def eval(self):
for engine in self.values():
engine.eval()
def train(self):
for engine in self.values():
engine.train()
def step(self, feeder: TrainFeeder, batch):
total_elapsed_time = 0
stats: Any = dict()
if self.cfg.trainer.gc_mode == 'step':
do_gc()
batch = to_device(batch, torch.cuda.current_device())
for name, engine in self.items():
torch.cuda.synchronize()
if self.cfg.trainer.gc_mode == 'substep':
do_gc()
start_time = time.time()
tries = 4
n_ooms = torch.zeros([], device=self.cfg.device)
if self.cfg.trainer.aggressive_optimizations:
batch = to_device(batch, torch.cuda.current_device())
# engine = engine.to(torch.cuda.current_device())
while tries >= 0:
try:
res = feeder( engines=self, batch=batch, name=name )
break
except RuntimeError as e:
print("Forward", str(e))
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
# shrink batch size until it's happy
for k in batch:
batch[k] = batch[k][:-1]
if tries <= 0:
# trigger OOM
n_ooms += 1
else:
# also do GC
do_gc()
continue
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!")
if res is None:
continue
loss, engine_stats = res
n_ooms = torch.zeros([], device=self.cfg.device)
if self.cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')
try:
engine.backward(loss)
except RuntimeError as e:
print("Backwards:", str(e))
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
n_ooms += 1
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!")
engine.step()
torch.cuda.synchronize()
elapsed_time = time.time() - start_time
total_elapsed_time += elapsed_time
stats.update(
flatten_dict(
{
name.split("-")[0]: dict(
loss=loss.item(),
lr=engine.get_lr()[0],
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
elapsed_time=elapsed_time,
engine_step=engine.global_step,
**engine_stats,
)
}
),
)
del loss
# engine = engine.to('cpu')
self._update_global_step()
stats["batch_size"] = len(batch["text"])
stats["elapsed_time"] = total_elapsed_time
stats["wall_time"] = time.time()
stats["global_step"] = self.global_step
return stats
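Engines.step() above wraps the forward and backward passes in an OOM-retry loop: on a CUDA out-of-memory error it shrinks the batch and tries again, and an all-reduced counter lets every rank learn that some rank gave up, so they all abort together instead of deadlocking. A standalone sketch of that consensus pattern (function and argument names are illustrative, not from the repo; torch.distributed is assumed to already be initialized, e.g. by DeepSpeed):

# Hedged sketch of the cross-rank OOM consensus used in Engines.step() above.
# step_fn, batch, and device are illustrative names.
import torch
import torch.distributed as dist

def forward_with_oom_consensus(step_fn, batch: dict, device, retries: int = 4):
	n_ooms = torch.zeros([], device=device)
	result = None

	while retries >= 0:
		try:
			result = step_fn(batch)
			break
		except RuntimeError as e:
			if "out of memory" not in str(e):
				raise
			# shrink the batch and retry, mirroring the loop above
			batch = {k: v[:-1] for k, v in batch.items()}
			torch.cuda.empty_cache()
			retries -= 1

	if result is None:
		n_ooms += 1

	if dist.is_initialized():
		# sum OOM counts across ranks so everyone aborts together
		dist.all_reduce(n_ooms)

	if n_ooms.item() > 0:
		raise RuntimeError("Out of memory during forward pass!")

	return result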
@@ -17,239 +17,258 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Protocol

from ..config import Config
from ..config import cfg

from .distributed import (
	fix_unset_envs,
	global_leader_only,
	global_rank,
	is_global_leader,
	is_local_leader,
	local_leader_only,
)

from .engines import Engine, Engines, TrainFeeder
from ..engines import Engine, Engines, TrainFeeder, default_feeder
from ..models import get_models

from .utils import to_device, do_gc
from ..utils import wrapper as ml
_logger = logging.getLogger(__name__)

_engines: Engines
_command: str

def get_global_step():
	try:
		return _engines.global_step
	except:
		return None

def get_micro_step():
	try:
		return _engines.micro_step
	except:
		return None

def get_cfg():
	try:
		return _engines.cfg
	except:
		raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")

def get_cmd():
	try:
		return _command
	except:
		raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")

get_iteration = get_global_step
class EnginesLoader(Protocol):
	def __call__(self) -> Engines:
		...

def load_engines(engines: dict[str, Engine], config: Config):
	engines = Engines(engines)
	engines.setup(config)

	if not engines.cfg.trainer.load_state_dict:
		engines.load_checkpoint()

	return engines

def load_engines():
	models = get_models(cfg.models.get())
	engines = dict()

	for name in models:
		model = models[name]

		optimizer = None
		lr_scheduler = None

		if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
			optimizer = ml.AdamW(
				model.parameters(),
				lr=cfg.hyperparameters.learning_rate,
				betas=(0.9, 0.96),
				eps=1e-07,
				weight_decay=0.01,
			)

		if cfg.trainer.load_state_dict:
			load_path = cfg.ckpt_dir / name / "fp32.pth"
			model.load_state_dict(torch.load(load_path))

		engines[name] = Engine(
			model=model,
			optimizer=optimizer,
			lr_scheduler=lr_scheduler,
		)

	engines = Engines(engines)
	engines.setup()

	if not cfg.trainer.load_state_dict:
		engines.load_checkpoint()

	return engines
class EvalFn(Protocol):
	def __call__(self, *, engines: Engines):
		...

class Logger(Protocol):
	def __call__(self, *, data: dict):
		...

@cache
def _get_stdin_selector():
	selector = selectors.DefaultSelector()
	selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
	return selector

def _non_blocking_input():
	global _command
	l = [""]
	if is_global_leader():
		s = ""
		selector = _get_stdin_selector()
		events = selector.select(timeout=0)
		for key, _ in events:
			s: str = key.fileobj.readline().strip()
			_logger.info(f'Get stdin "{s}".')
		l[0] = s
	broadcast_object_list(l, src=0)
	_command = l[0]
	return _command

def _make_infinite_epochs(dl):
	while True:
		_logger.info("New epoch starts.")
		yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)

@local_leader_only(default=None)
def logger(data):
	return _logger.info(json.dumps(data, default=str))

def seed(seed):
	# Set up random seeds, after fork()
	random.seed(seed + global_rank())
	np.random.seed(seed + global_rank())
	torch.manual_seed(seed + global_rank())
def train(
	engines_loader: EnginesLoader,
	train_dl: DataLoader,
	train_feeder: TrainFeeder,
	eval_fn: EvalFn,
	logger: Logger = logger,
):
	engines = engines_loader()
	cfg = engines.cfg

def train(
	train_dl: DataLoader,
	train_feeder: TrainFeeder = default_feeder,
	eval_fn: EvalFn = lambda x: ...,
	logger: Logger = logger,
):
	fix_unset_envs()

	engines = load_engines()

	"""
	if is_local_leader():
		cfg.dump()
		_logger.info(cfg)
	"""

	# Setup global engines
	global _engines
	_engines = engines

	events = []

	eval_fn = global_leader_only(eval_fn)

	# Pre-loop command
	command = _non_blocking_input()
	if command in ["eval", "eval_quit"]:
		engines.eval()
		eval_fn(engines=engines)
		engines.train()
	if command in ["quit", "eval_quit"]:
		return

	last_save_step = engines.global_step
	last_eval_step = 0

	# Training loop
	for batch in _make_infinite_epochs(train_dl):
		if engines.global_step >= cfg.trainer.iterations:
			break

		#batch = to_device(batch, torch.cuda.current_device())
		stats = engines.step(batch=batch, feeder=train_feeder)

		iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
		stats['it'] = iteration
		stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)

		stats['batch'] = {
			'size': stats['batch_size'],
			'id': batch['spkr_id'],
			'index': [ index for index in batch['index'] ],
			'text_len': [ text.shape[0] for text in batch['text'] ],
			'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
			'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
		}

		del stats['batch_size']
		del stats['wall_time']
		del stats['global_step']

		elapsed_time = stats.get("elapsed_time", 0)

		_logger.info(f"Training Metrics: {json.dumps(stats)}.")

		command = _non_blocking_input()

		if "@" in command:
			what, when = command.split("@")
			try:
				events.append((what, int(when)))
				_logger.info(f"Event {command} registered.")
			except Exception as e:
				_logger.error(e)
			command = ""

		# Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
		events = [e for e in events if e[1] >= engines.global_step]
		commands = [command] + [e[0] for e in events if e[1] == engines.global_step]

		for command in commands:
			if command in ["event show", "event"]:
				msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
				_logger.info(msg)

			if command == "event clear":
				events.clear()

			if "time" in command:
				target_iter = cfg.trainer.iterations
				if " to " in command:
					try:
						target_iter = int(command.split(" to ")[-1])
					except Exception as e:
						_logger.error(e)
				remaining_iters = target_iter - engines.global_step + 1
				remaining_time = int(remaining_iters * elapsed_time)
				_logger.info(humanize.precisedelta(remaining_time))

			if "lr" in command:
				rate = float(command.split(" ")[-1])
				engines.set_lr(rate)
				print("Updating LR to:", rate)

			save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency

			saving_commands = ["save"]

			if cfg.trainer.save_on_quit:
				saving_commands.append("quit")

			if engines.global_step != last_save_step:
				if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
					engines.save_checkpoint()
					last_save_step = engines.global_step

			if engines.global_step != last_eval_step:
				if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
					do_gc()

					engines.eval()
					eval_fn(engines=engines)
					engines.train()
					last_eval_step = engines.global_step

			if command in ["quit"]:
				return
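The loop above polls stdin once per iteration, so a running job can be steered with plain-text commands: save, eval, quit, lr <rate>, time to <iteration>, and the scheduled form what@when (for example eval@10000), which registers an event to fire at that global step. A small sketch of the same what@when bookkeeping, using hypothetical names for illustration:

# Hedged sketch of the scheduled-command handling in the loop above;
# parse_command and its arguments are illustrative names, not part of the repo.
def parse_command(command: str, current_step: int, events: list[tuple[str, int]]) -> list[str]:
	# "what@when" registers a command to fire once global_step reaches `when`
	if "@" in command:
		what, when = command.split("@")
		events.append((what, int(when)))
		command = ""

	# drop events that have already passed, then trigger the ones due now
	events[:] = [e for e in events if e[1] >= current_step]
	return [command] + [what for what, when in events if when == current_step]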