big cleanup

mrq 2023-08-03 20:26:36 -05:00
parent 2e03e5ac93
commit c85101403f
13 changed files with 1037 additions and 1031 deletions

View File

@@ -15,12 +15,12 @@ dataset:
   workers: 8
   cache: True

-  phones_range: [4, 256]
-  duration_range: [1.0, 12.0]
+  phones_range: [4, 512]
+  duration_range: [1.0, 24.0]

   random_utterance: 1.0
-  max_prompts: 3
-  prompt_duration: 3.0
+  max_prompts: 6
+  prompt_duration: 6.0

 models:
   _models:
@@ -69,7 +69,8 @@ evaluation:
   size: 32
   steps: 300
-  temperature: 1.0
+  ar_temperature: 1.0
+  nar_temperature: 0.2

 trainer:
   iterations: 100_000
@@ -91,7 +92,13 @@ trainer:
   weight_dtype: bfloat16

-  zero_optimization_level: 2
-  use_compression_training: True
+  backend: deepspeed
+  deepspeed:
+    zero_optimization_level: 0
+    use_compression_training: True

-use_vocos: False
+inference:
+  use_vocos: True
+
+bitsandbytes:
+  enabled: false
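For reference, the new `trainer.deepspeed` block is hydrated in two passes: the outer dataclass is built first with the nested mapping still a plain dict, then the nested dataclass is built from it (the `DeepSpeed(**cfg.trainer.deepspeed)` call added in config.py below). A minimal, self-contained sketch of that pattern, with simplified field sets rather than the project's full config:

    from dataclasses import dataclass, field

    @dataclass()
    class DeepSpeed:
        zero_optimization_level: int = 0
        use_compression_training: bool = False

    @dataclass()
    class Trainer:
        weight_dtype: str = "float16"
        backend: str = "deepspeed"
        deepspeed: DeepSpeed = field(default_factory=DeepSpeed)

    raw = {
        "weight_dtype": "bfloat16",
        "backend": "deepspeed",
        "deepspeed": {"zero_optimization_level": 0, "use_compression_training": True},
    }
    trainer = Trainer(**raw)                            # deepspeed is still a plain dict here
    trainer.deepspeed = DeepSpeed(**trainer.deepspeed)  # second pass hydrates the nested block
    assert trainer.deepspeed.use_compression_training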

View File

@@ -232,104 +232,47 @@ class Evaluation:
     size: int = 64
     steps: int = 500
-    temperature: float = 1.0
+    ar_temperature: float = 1.0
+    nar_temperature: float = 0.2

 @dataclass()
-class Trainer:
-    iterations: int = 100_000
-
-    save_tag: str = "step"
-    load_tag: str | None = None
-
-    save_on_oom: bool = True
-    save_on_quit: bool = True
-    save_frequency: int = 100
-
-    load_state_dict: bool = False
-    load_states: bool = True
-    strict_loading: bool = True
-    restart_step_count: bool = False
-
-    aggressive_optimizations: bool = False
-    gc_mode: str | None = None
-
-    weight_dtype: str = "float16"
-
+class DeepSpeed:
     zero_optimization_level: int = 0
     use_compression_training: bool = False

-@dataclass()
-class Inference:
-    use_vocos: bool = True
-
-@dataclass()
-class BitsAndBytes:
-    enabled: bool = False
-
-    injects: bool = False
-    linear: bool = False
-    embedding: bool = False
-
-@dataclass()
-class Config(_Config):
-    device: str = "cuda"
-
-    dataset: Dataset = field(default_factory=lambda: Dataset)
-    models: Models = field(default_factory=lambda: Models)
-    hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
-    evaluation: Evaluation = field(default_factory=lambda: Evaluation)
-    trainer: Trainer = field(default_factory=lambda: Trainer)
-    inference: Inference = field(default_factory=lambda: Inference)
-    bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
-
-    @property
-    def sample_rate(self):
-        return 24_000
-
-    @cached_property
-    def get_spkr(self):
-        return eval(self.dataset.speaker_name_getter)
-
-    @property
-    def scheduler(self):
-        cfg = {
-            "type": self.hyperparameters.scheduler_type,
-            "params": {},
-        }
-        for k in self.hyperparameters.scheduler_params:
-            cfg['params'][k] = self.hyperparameters.scheduler_params[k]
-        if self.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in cfg['params']:
-            cfg['params']['total_num_steps'] = self.trainer.iterations
-        return cfg
-
-    @property
-    def fp16_cfg(self):
-        if self.trainer.weight_dtype.lower() != "float16":
-            return None
-        return {
-            "enabled": True,
-            "auto_cast": True,
-        }
-
-    @property
-    def bf16_cfg(self):
-        return {
-            "enabled": self.trainer.weight_dtype.lower() == "bfloat16"
-        }
-
-    def get_compression_cfg(self, model):
-        if not self.trainer.use_compression_training:
-            return None
-
+    def get_ds_cfg(self, model):
         weights = [ name[0] for name in model.named_parameters() ]
         bits = 8

-        return {
+        scheduler_params = {}
+        for k in cfg.hyperparameters.scheduler_params:
+            scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
+
+        if cfg.hyperparameters.scheduler_type == "WarmupDecayLR" and 'total_num_steps' not in scheduler_params:
+            scheduler_params['total_num_steps'] = cfg.trainer.iterations
+
+        ds_cfg = {
+            "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
+            "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
+            "optimizer": {
+                "type": cfg.hyperparameters.optimizer,
+                "params": {
+                    "lr": cfg.hyperparameters.learning_rate,
+                }
+            },
+            "scheduler": {
+                "type": cfg.hyperparameters.scheduler_type,
+                "params": scheduler_params,
+            } if cfg.hyperparameters.scheduler_type != "" else None,
+            "gradient_clipping": cfg.hyperparameters.gradient_clipping,
+            "fp16": {
+                "enabled": True,
+                "auto_cast": True,
+            } if cfg.trainer.weight_dtype.lower() == "float16" else None,
+            "bf16": {
+                "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
+            },
+            "compression_training": {
                 "weight_quantization": {
                     "shared_parameters":{
                         "enabled": True,
@@ -372,15 +315,9 @@ class Config(_Config):
                     }
                 }
             }
-        }
-
-    @property
-    def zero_cfg(self):
-        if self.trainer.zero_optimization_level == 0:
-            return None
-        return {
-            "stage": self.trainer.zero_optimization_level,
+            } if self.use_compression_training else None,
+            "zero_optimization": {
+                "stage": self.zero_optimization_level,
                 "contiguous_gradients": True,
                 "overlap_comm": True,
                 "reduce_scatter": True,
@@ -396,38 +333,79 @@ class Config(_Config):
                 "offload_optimizer": {
                     "device": "cpu",
                     "pin_memory": True
                 }
-            }
-
-    def get_ds_cfg(self, model):
-        cfg = {
-            "train_micro_batch_size_per_gpu": self.hyperparameters.batch_size,
-            "gradient_accumulation_steps": self.hyperparameters.gradient_accumulation_steps,
-            "optimizer": {
-                "type": self.hyperparameters.optimizer,
-                "params": {
-                    "lr": self.hyperparameters.learning_rate,
-                }
-            },
-            "scheduler": self.hyperparameters.scheduler if self.hyperparameters.scheduler_type != "" else None,
-            "gradient_clipping": self.hyperparameters.gradient_clipping,
-            "fp16": self.fp16_cfg,
-            "bf16": self.bf16_cfg,
-            "compression_training": self.get_compression_cfg(model),
-            "zero_optimization": self.zero_cfg,
+            } if self.zero_optimization_level > 0 else None,
             "comms_logger": {
                 "enabled": False
             }
         }

-        null_keys = [ k for k in cfg if not cfg[k] ]
+        null_keys = [ k for k in ds_cfg if not ds_cfg[k] ]
         for k in null_keys:
-            del cfg[k]
+            del ds_cfg[k]

         if os.path.exists("./config/ds_config.json"):
-            ds_cfg = json.load(open("./config/ds_config.json", "r", encoding="utf-8"))
-            cfg.update(ds_cfg)
+            ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))

-        return cfg
+        return ds_cfg
+
+@dataclass()
+class Trainer:
+    iterations: int = 100_000
+
+    save_tag: str = "step"
+    load_tag: str | None = None
+
+    save_on_oom: bool = True
+    save_on_quit: bool = True
+    save_frequency: int = 100
+
+    load_state_dict: bool = False
+    load_states: bool = True
+    strict_loading: bool = True
+    restart_step_count: bool = False
+
+    aggressive_optimizations: bool = False
+    gc_mode: str | None = None
+
+    weight_dtype: str = "float16"
+
+    backend: str = "deepspeed"
+    deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
+
+@dataclass()
+class Inference:
+    use_vocos: bool = True
+
+@dataclass()
+class BitsAndBytes:
+    enabled: bool = False
+
+    injects: bool = False
+    linear: bool = False
+    embedding: bool = False
+
+@dataclass()
+class Config(_Config):
+    device: str = "cuda"
+
+    dataset: Dataset = field(default_factory=lambda: Dataset)
+    models: Models = field(default_factory=lambda: Models)
+    hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
+    evaluation: Evaluation = field(default_factory=lambda: Evaluation)
+    trainer: Trainer = field(default_factory=lambda: Trainer)
+    inference: Inference = field(default_factory=lambda: Inference)
+    bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
+
+    @property
+    def sample_rate(self):
+        return 24_000
+
+    @cached_property
+    def get_spkr(self):
+        return eval(self.dataset.speaker_name_getter)

     @property
     def cache_dir(self):
@@ -455,6 +433,8 @@ cfg.trainer = Trainer(**cfg.trainer)
 cfg.inference = Inference(**cfg.inference)
 cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
+cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)

 # cached_property stopped working...
 if cfg.dataset.use_hdf5:
     try:
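One detail of `get_ds_cfg` worth noting: after the dict is assembled, every top-level key with a falsy value (`None`, `False`, `0`, `{}`) is pruned before the config reaches DeepSpeed, and anything in ./config/ds_config.json then overrides the generated values. A standalone sketch of the pruning step:

    ds_cfg = {
        "train_micro_batch_size_per_gpu": 8,
        "scheduler": None,           # scheduler_type was ""
        "fp16": None,                # weight_dtype was not float16
        "bf16": {"enabled": True},
        "compression_training": None,
        "zero_optimization": None,   # zero_optimization_level == 0
    }

    null_keys = [k for k in ds_cfg if not ds_cfg[k]]
    for k in null_keys:
        del ds_cfg[k]

    print(ds_cfg)  # {'train_micro_batch_size_per_gpu': 8, 'bf16': {'enabled': True}}

Note that the same truthiness test would also drop a legitimately falsy value such as "gradient_clipping": 0, which then falls back to DeepSpeed's own default.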

View File

@@ -0,0 +1,9 @@
from ..config import cfg

if cfg.trainer.backend == "deepspeed":
    from .deepspeed import Engine
elif cfg.trainer.backend == "local":
    from .base import Engine

from .base import Engines
from .base import TrainFeeder
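In other words, callers pick the backend through config before importing, and the rest of the codebase just imports `Engine`. A sketch with a hypothetical stand-in model (the real trainer passes the VALL-E models here):

    import torch
    from vall_e.config import cfg
    cfg.trainer.backend = "local"      # select the barebones PyTorch engine before importing

    from vall_e.engines import Engine  # resolves to vall_e.engines.base.Engine

    model = torch.nn.Linear(4, 4)      # hypothetical stand-in model
    engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))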

vall_e/engines/base.py Normal file (374 additions)
View File

@@ -0,0 +1,374 @@
from torch import Tensor
from typing import Any, Protocol

Stats = dict[str, float]

class TrainFeeder(Protocol):
    def __call__(
        self, *, engine: "Engine", batch: Any
    ) -> None | tuple[Tensor, Stats]:
        ...

def default_feed(engine, batch):
    # torch is imported below; it is available by the time this is called
    if isinstance(batch, list):
        engine( *batch )
    elif isinstance(batch, dict):
        engine( **batch )
    else:
        engine( batch )

    losses = engine.gather_attribute("loss")
    loss = torch.stack([*losses.values()]).sum()

    stats = {}
    stats |= {k: v.item() for k, v in losses.items()}

    return loss, stats


from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device

import logging
import time
import torch
import torch.distributed
import os

from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol

from .base import TrainFeeder

_logger = logging.getLogger(__name__)

# A very naive engine implementation using barebones PyTorch
class Engine():
    def __init__(self, *args, **kwargs):
        self.module = kwargs['model']
        self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
        self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None

        self.global_steps = 0
        self.micro_steps = 0
        self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps
        self._frozen_params = set() # tracked so unfreeze() can restore requires_grad

    def freeze(self):
        for p in self.module.parameters():
            if p.requires_grad:
                p.requires_grad_(False)
                self._frozen_params.add(p)

    def unfreeze(self):
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def global_step(self):
        return self.global_steps

    @property
    def micro_step(self):
        return self.micro_steps

    def train_batch_size(self):
        return cfg.hyperparameters.batch_size

    def gather_attribute(self, *args, **kwargs):
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        return dispatch_attribute(self.module, *args, **kwargs)

    def save_checkpoint(self, save_dir, tag ):
        torch.save({
            "global_step": self.global_step,
            "micro_step": self.micro_step,
            "module": self.module.state_dict(),
            "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
            "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
        }, save_dir / tag / "state.pth")

    def load_checkpoint(self, load_dir, tag, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
        state = torch.load(load_dir / tag / "state.pth")
        # global_step / micro_step are read-only properties; assign the backing attributes
        self.global_steps = state['global_step']
        self.micro_steps = state['micro_step']
        self.module.load_state_dict(state['module'])

        load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
        load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state

        if load_optimizer_states:
            self.optimizer.load_state_dict(state['optimizer'])

        if load_lr_scheduler_states:
            self.lr_scheduler.load_state_dict(state['lr_scheduler'])

    def eval(self):
        return self.module.eval()

    def train(self):
        return self.module.train()

    def to(self, *args, **kwargs):
        self.module = self.module.to(*args, **kwargs)
        return self.module

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)

    def backward(self, loss):
        return (loss / self.gradient_accumulation_steps).backward()

    def step(self):
        with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
            self.micro_steps += 1

            if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
                self.global_steps += 1
                self.optimizer.step()
                self.optimizer.zero_grad()

    def get_lr(self):
        lrs = []
        for param_group in self.optimizer.param_groups:
            if 'lr' in param_group:
                lrs.append(param_group['lr'])
        return lrs

    def set_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            if 'lr' in param_group:
                param_group['lr'] = lr

    def get_global_grad_norm(self):
        return 0.0

    def traverse(self, *args, **kwargs):
        self.forward(*args, **kwargs)
        losses = self.gather_attribute("loss")
        loss = torch.stack([*losses.values()]).sum()

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= self.gather_attribute("scalar")

        self.backward(loss)
        self.step()

        return stats

# and now to ignore everything from the above
class Engines(dict[str, Engine]):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setup()

    def setup(self):
        self._global_step = 0
        self._micro_step = 0

    @property
    def global_step(self):
        return self._global_step

    @property
    def micro_step(self):
        return self._micro_step

    def gather_attribute(self, *args, **kwargs):
        ret = {}
        for engine in self.values():
            ret |= engine.gather_attribute(*args, **kwargs)
        return ret

    def dispatch_attribute(self, *args, **kwargs):
        for engine in self.values():
            engine.dispatch_attribute(*args, **kwargs)

    def save_checkpoint(self, tag=None):
        if not tag:
            tag = cfg.trainer.save_tag
        tag = tag.lower()
        if tag[:2] == "it" or tag[:4] == "step":
            tag = f'{self.global_step}'

        cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
        for name, engine in self.items():
            engine.save_checkpoint(cfg.ckpt_dir / name, tag=tag)

    def load_checkpoint(self, tag=None):
        if not tag:
            tag = cfg.trainer.load_tag

        for name, engine in self.items():
            load_dir = cfg.ckpt_dir / name
            engine.load_checkpoint(
                tag=tag,
                load_dir=load_dir,
                load_module_strict=cfg.trainer.strict_loading,
                load_optimizer_states=cfg.trainer.load_states,
                load_lr_scheduler_states=cfg.trainer.load_states,
            )
            if cfg.trainer.restart_step_count:
                engine.global_steps = 0

        # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
        if cfg.hyperparameters.scheduler_type == "":
            self.set_lr(cfg.hyperparameters.learning_rate)

        self._update_global_step()
        self._update_micro_step()

    def set_lr(self, lr):
        for engine in self.values():
            engine.set_lr(lr)

    def _update_global_step(self):
        for engine in self.values():
            self._global_step = max(self._global_step, engine.global_step)

    def _update_micro_step(self):
        for engine in self.values():
            self._micro_step = max(self._micro_step, engine.micro_step)

    def train_batch_size(self):
        batch_size = 0
        for engine in self.values():
            batch_size = max(batch_size, engine.train_batch_size())
        return batch_size # was missing; step() logs this as stats["batch_size"]

    def eval(self):
        for engine in self.values():
            engine.eval()

    def train(self):
        for engine in self.values():
            engine.train()

    def traverse(self):
        stats = {}
        for name, engine in self.items():
            stat = engine.traverse()
            stats.update(flatten_dict({ name.split("-")[0]: stat }))
        return stats

    def step(self, batch, feeder: TrainFeeder = default_feed, device=torch.cuda.current_device()):
        total_elapsed_time = 0

        stats: Any = dict()

        if cfg.trainer.gc_mode == 'step':
            do_gc()

        batch = to_device(batch, device)

        for name, engine in self.items():
            #torch.cuda.synchronize()
            if cfg.trainer.gc_mode == 'substep':
                do_gc()

            start_time = time.time()

            tries = 4
            n_ooms = torch.zeros([], device=cfg.device)

            if cfg.trainer.aggressive_optimizations:
                batch = to_device(batch, device)

            res = feeder( engine=engine, batch=batch )
            """
            while tries >= 0:
                try:
                    res = feeder( engine=engine, batch=batch )
                    break
                except RuntimeError as e:
                    print("Forward", str(e))

                    if "out of memory" not in str(e):
                        self.save_checkpoint()
                        raise e

                    # shrink batch size until it's happy
                    for k in batch:
                        batch[k] = batch[k][:-1]

                    if tries <= 0:
                        # trigger OOM
                        n_ooms += 1
                    else:
                        # also do GC
                        do_gc()
                    continue

            all_reduce(n_ooms)
            if n_ooms.item() > 0:
                self.save_checkpoint()
                raise RuntimeError("Out of memory during forward pass!")
            """

            if res is None:
                continue

            loss, engine_stats = res
            engine_stats |= self.gather_attribute("scalar")

            n_ooms = torch.zeros([], device=cfg.device)

            if cfg.trainer.aggressive_optimizations:
                batch = to_device(batch, 'cpu')

            engine.backward(loss)
            """
            try:
                engine.backward(loss)
            except RuntimeError as e:
                print("Backwards:", str(e))

                if "out of memory" not in str(e):
                    self.save_checkpoint()
                    raise e

                n_ooms += 1

            all_reduce(n_ooms)
            if n_ooms.item() > 0:
                self.save_checkpoint()
                raise RuntimeError("Out of memory during backwards pass!")
            """

            engine.step()

            #torch.cuda.synchronize()

            elapsed_time = time.time() - start_time
            total_elapsed_time += elapsed_time

            stats.update(
                flatten_dict(
                    {
                        name.split("-")[0]: dict(
                            loss=loss.item(),
                            lr=engine.get_lr()[0],
                            grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
                            elapsed_time=elapsed_time,
                            engine_step=engine.global_step,
                            **engine_stats,
                        )
                    }
                ),
            )

        self._update_global_step()
        self._update_micro_step()
        stats["batch_size"] = self.train_batch_size() # len(batch["text"])
        stats["elapsed_time"] = total_elapsed_time
        stats["wall_time"] = time.time()
        stats["global_step"] = self.global_step

        return stats
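A minimal sketch of driving this local backend end to end. The toy module and its `loss` dict attribute are assumptions standing in for the real models, which expose `loss` the same way so `gather_attribute("loss")` can collect it; `cfg.device` is forced to CPU because `step` allocates bookkeeping tensors on it. Note that `step`'s default `device` argument touches `torch.cuda` at import time, so even this CPU run assumes a CUDA-capable machine:

    import torch
    from vall_e.config import cfg
    cfg.trainer.backend = "local"
    cfg.device = "cpu"

    from vall_e.engines import Engine, Engines

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(4, 1)
            self.loss = {}  # gathered by engine.gather_attribute("loss")

        def forward(self, x, y):
            self.loss = dict(mse=torch.nn.functional.mse_loss(self.net(x), y))
            return self.loss

    model = ToyModel()
    engines = Engines({"toy": Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-3))})

    batch = {"x": torch.randn(8, 4), "y": torch.randn(8, 1)}
    stats = engines.step(batch, device="cpu")  # default_feed unpacks the dict: engine(**batch)
    print(stats)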

View File

@@ -0,0 +1,89 @@
"""
# https://github.com/enhuiz/pytorch-training-utilities
"""

# to-do: replace this
# to-do: swap out deepspeed

from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device

import logging
import time
import torch
import torch.distributed

from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol

from .base import TrainFeeder

_logger = logging.getLogger(__name__)

from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed
from deepspeed.accelerator import get_accelerator

#dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
initialized_dist = False
if not initialized_dist:
    initialized_dist = True
    init_distributed()

class Engine(DeepSpeedEngine):
    def __init__(self, *args, **kwargs):
        kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
        kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
        super().__init__(None, *args, **kwargs)
        self._frozen_params = set()

    def freeze(self):
        for p in self.module.parameters():
            if p.requires_grad:
                p.requires_grad_(False)
                self._frozen_params.add(p)

    def unfreeze(self):
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def global_step(self):
        return self.global_steps

    @property
    def micro_step(self):
        return self.micro_steps

    def gather_attribute(self, *args, **kwargs):
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        return dispatch_attribute(self.module, *args, **kwargs)

    def set_lr(self, lr):
        try:
            if hasattr(self.optimizer, 'param_groups'):
                print(self.optimizer.param_groups)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
            else:
                self.optimizer.set_lr(lr)
        except Exception as e:
            print(str(e))

    def traverse(self, *args, **kwargs):
        self.forward(*args, **kwargs)
        losses = self.gather_attribute("loss")
        loss = torch.stack([*losses.values()]).sum()

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= self.gather_attribute("scalar")

        self.backward(loss)
        self.step()

        return stats
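Constructing the DeepSpeed-backed engine looks the same from the caller's side; the difference is that the engine builds its own DeepSpeed config from `cfg.trainer.deepspeed`. A sketch, assuming DeepSpeed is installed and `init_distributed()` can complete (it runs at import):

    import torch
    from vall_e.config import cfg
    cfg.trainer.backend = "deepspeed"

    from vall_e.engines import Engine  # resolves to the DeepSpeedEngine subclass above

    model = torch.nn.Linear(4, 4)  # hypothetical stand-in; ds_cfg is generated from this model
    engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))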

View File

@@ -72,8 +72,8 @@ class TTS():
         prom = to_device(prom, self.device).to(torch.int16)
         phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

-        resp_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
-        resps_list = [r.unsqueeze(-1) for r in resp_list]
+        resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
+        resps_list = [r.unsqueeze(-1) for r in resps_list]

         resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)

         wav, sr = qnt.decode_to_file(resps_list[0], out_path)
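The `unsqueeze` between the two stages is load-bearing: the AR emits one code level of shape [t], while the NAR consumes and refines [t, l]. A toy illustration of how the level axis grows, mirroring the NAR's `prev_list` concatenation:

    import torch

    ar_out = torch.randint(0, 1024, (75,))       # AR output: one sequence, first level, shape [t]
    resps = ar_out.unsqueeze(-1)                 # [t, 1]: restore the quantization-level axis

    for level in range(1, 8):                    # the NAR appends one refined level per pass
        nar_out = torch.randint(0, 1024, (75,))  # stand-in for one NAR sampling pass
        resps = torch.cat([resps, nar_out.unsqueeze(-1)], dim=-1)

    print(resps.shape)  # torch.Size([75, 8])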

View File

@@ -1,20 +1,22 @@
 from .ar import AR
 from .nar import NAR

-def get_model(model):
-    if model.name == "ar":
+def get_model(cfg):
+    if cfg.name == "ar":
         Model = AR
-    elif model.name == "nar":
+    elif cfg.name == "nar":
         Model = NAR
     else:
-        raise f"invalid model name: {model.name}"
+        raise ValueError(f"invalid model name: {cfg.name}") # a bare f-string cannot be raised

-    name = model.name
+    name = cfg.name

     model = Model(
-        n_tokens=model.tokens,
-        d_model=model.dim,
-        n_heads=model.heads,
-        n_layers=model.layers,
+        n_tokens=cfg.tokens,
+        d_model=cfg.dim,
+        n_heads=cfg.heads,
+        n_layers=cfg.layers,
     )
+    model._cfg = cfg

     print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

View File

@@ -50,112 +50,80 @@ class AR(Base):
         self,
         text_list: list[Tensor],
         proms_list: list[Tensor],
-        resp_list: list[Tensor] | None = None,
+        resps_list: list[Tensor] | None = None,
         max_steps: int = 1000,
         sampling_temperature: float = 1.0,
         naive: bool = True,
     ):
-        if resp_list is not None:
+        if resps_list is not None:
+            resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
             return super().forward(
-                text_list,
-                proms_list,
-                self._unsqueeze_list(resp_list),
-                resp_list,
+                text_list=text_list,
+                proms_list=proms_list,
+                resps_list=self._unsqueeze_list(resps_list),
+                targ_list=resps_list,
                 quant_levels=None,
                 shift_targ_list=True,
                 return_all_resp=False,
             )

         device = text_list[0].device
-        resp_list: list[Tensor] = [
+        resps_list: list[Tensor] = [
             torch.zeros(0, device=device).to(torch.int16) for _ in text_list
         ]
         stopped = torch.zeros(len(text_list), device=device).bool()

+        if self.arch_type == "transformer":
+            naive = True

         chunk_size = 1 # don't really know what to do about this desu

         state = None
         start = 0

-        # prefill
-        if self.arch_type == "retnet/local":
-            # pre-process
-            state = [
-                [
-                    torch.zeros(self.retnet.hidden_dim // self.retnet.heads, self.retnet.v_dim // self.retnet.heads, device=device).unsqueeze(0).repeat(len(text_list), 1, 1)
-                    for _ in range(self.retnet.heads)
-                ] for _ in range(self.retnet.layers)
-            ]
-
-            resps_list = self._unsqueeze_list(resp_list)
-            x_list = self._samplewise_merge_tensors(
-                self.text_emb(text_list),
-                self.proms_emb(proms_list),
-                self.resps_emb(resps_list),
-                sep=self.sep,
-            )
-
-            x, m = list_to_tensor(x_list)
-            start = x.shape[1]
-
-            for i in trange(start-1):
-                _, state = self.retnet.forward_recurrent( x[:, i:i+1, :], state, i+1 )
-
         for n in trange(max_steps // chunk_size):
             # get next in sequence
             r, state = super().forward(
                 text_list,
                 proms_list,
-                self._unsqueeze_list(resp_list),
+                self._unsqueeze_list(resps_list),
                 sampling_temperature=sampling_temperature,
-                state=state,
+                state=state if not naive else None,
             )

             # append outputted token
             for i, ri in enumerate(r):
-                resp_list[i] = torch.cat([resp_list[i], ri[None]])
+                resps_list[i] = torch.cat([resps_list[i], ri[None]])

             # stop token found
             stopped |= r == self.stop_token
             if stopped.all().item():
                 break

-        pruned = [self._prune(r) for r in resp_list]
+        pruned = [self._prune(r) for r in resps_list]
         return pruned

 def example_usage():
+    cfg.trainer.backend = "local"
+
     from functools import partial
     from einops import repeat
     from ..emb.qnt import decode_to_file
-    from ..utils import gather_attribute
+    from ..engines import Engine
+    from tqdm import tqdm

     device = "cpu"
+    x8 = partial(repeat, pattern="t -> t l", l=2)

     symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
     def tokenize(content, lang_marker="en"):
         split = content.split(" ")
         phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
         return torch.tensor([*map(symmap.get, phones)]).to()

-    qnt = torch.load("data/qnt.pt")[0, 0].to(device)
-
-    kwargs = {
-        'n_tokens': 1024,
-        'd_model': 1024,
-        'n_heads': 16,
-        'n_layers': 12,
-    }
-    model = AR(**kwargs).to(device)
-
-    x8 = partial(repeat, pattern="t -> t l", l=2)
+    qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)

     text_list = [
         #torch.tensor([1, 2, 3], device=device),
         tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
@@ -164,41 +132,41 @@ def example_usage():
         x8(torch.tensor([1, 2, 3], device=device)),
         #qnt.to(device),
     ]
-    resp_list = [
+    resps_list = [
         qnt.to(device),
     ]

     text_list = text_list[:1]
     proms_list = proms_list[:1]
-    resp_list = resp_list[:1]
+    resps_list = resps_list[:1]

-    model.eval()
-    out = model(text_list, proms_list, max_steps=75)[0]
-    print("qnt:", qnt.shape, qnt)
-    print("out:", out.shape, out)
-    wav, sr = decode_to_file(out, "data/test/test.ar.init.wav", device=device)
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
-
-    model.train()
-    for i in trange(60):
-        optimizer.zero_grad()
-        _ = model(text_list, proms_list, resp_list)
-        losses = gather_attribute(model, "loss")
-        loss = sum(losses.values())
-        loss.backward()
-        optimizer.step()
-        if i % 20 == 0:
-            print(f"iter={i}, {losses}.")
-
-    model.eval()
-    out = model(text_list, proms_list, max_steps=400)
-    print("qnt:", qnt.shape, qnt)
-    for i, o in enumerate(out):
-        print("out:", i, o.shape, o)
-        wav, sr = decode_to_file(o, f"data/test/test.ar.{i}.wav", device=device)
+    kwargs = {
+        'n_tokens': 1024,
+        'd_model': 1024,
+        'n_heads': 16,
+        'n_layers': 12,
+    }
+    model = AR(**kwargs).to(device)
+    engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
+
+    def sample( name, steps=400 ):
+        engine.eval()
+        out = engine(text_list, proms_list, max_steps=steps)
+        for i, o in enumerate(out):
+            wav, sr = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
+
+    def train():
+        engine.train()
+        t = trange(60)
+        for i in t:
+            stats = {"step": i}
+            stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+            t.set_description(f"{stats}")
+
+    sample("init", 75)
+    train()
+    sample("final")

 if __name__ == "__main__":
     example_usage()
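The decode loop's contract is easy to state: each `super().forward` call yields one token per sequence, which is appended until every sequence has produced the stop token (stopped sequences keep accumulating and are pruned afterwards). A stripped-down sketch of that control flow, with a fake sampler in place of the model:

    import torch

    stop_token = 1024

    def fake_sampler(prefixes):  # stand-in for the model's super().forward(...) call
        return torch.tensor([stop_token if len(p) >= 5 else int(len(p)) for p in prefixes])

    resps_list = [torch.zeros(0, dtype=torch.int16) for _ in range(2)]
    stopped = torch.zeros(2).bool()

    for _ in range(1000):  # max_steps
        r = fake_sampler(resps_list)
        for i, ri in enumerate(r):  # append one new token per sequence
            resps_list[i] = torch.cat([resps_list[i], ri[None].to(torch.int16)])
        stopped |= r == stop_token
        if stopped.all().item():
            break

    print([len(r) for r in resps_list])  # both sequences ran until their stop token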

View File

@@ -176,17 +176,7 @@ class Base(nn.Module):
             self.retnet = RetNetDecoder(
                 self.retnet_config
             )
-        elif self.arch_type == "retnet/local":
-            self.retnet = RetNet(
-                layers=n_layers,
-                hidden_dim=d_model,
-                ffn_size=d_model * 4,
-                heads=n_heads,
-                dropout=p_dropout,
-                norm_type=self.norm_type,
-                n_levels=self.n_resp_levels,
-                double_v_dim=True
-            )

         self.classifier = nn.Linear(d_model, n_resp_tokens)

         self.accuracy_metric = MulticlassAccuracy(
@@ -272,7 +262,7 @@ class Base(nn.Module):
         text_list: [t] * b
         proms_list: [t' l] * b, l quantization levels.
         resps_list: [t'' l] * b, l quantization levels.
-        targ_list: [t''] * b, one quantization level only, when given, loss will be computed
+        targ_list: [t''] * b, one quantization level only; when given, loss will be computed
         quant_levels: specify which quant_levels to feed forward, used in NAR mode.
         shift_targ_list: whether to shift target list when computing loss. True if AR.
         return_all_resp: True if NAR.
@@ -298,24 +288,12 @@ class Base(nn.Module):
         elif self.arch_type == "retnet":
             x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
             state = self.retnet.get_incremental_state( state, 'prev_state' )
-        elif self.arch_type == "retnet/local":
-            # recurrent inferencing
-            if self.causal and state is not None:
-                last = x.shape[1]
-                x, state = self.retnet.forward_recurrent(
-                    x[:, last-1:last, :], # nasty way to grab the last embedding to forward
-                    state,
-                    last
-                )
-            else:
-                x = self.retnet( x, quant_levels )

         x = self.classifier(x) * m

         # Remove padding
         h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]

         # compute loss if the target is given
         if targ_list is not None:
             if any([l == 0 for l in map(len, targ_list)]):
@@ -337,20 +315,24 @@ class Base(nn.Module):
                 # the NAR doesn't need to compute the loss for it
                 if self.resp_loss_only:
                     text_prom_list[i][:] = self.ignore_index
                 # roll the text/prompt for loss computing
-                # the AR benefits from this
+                # the AR benefits from this, for some reason I'll figure out later
                 else:
                     text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
                     text_prom_list[i][-1] = self.ignore_index

-            # necessary to roll the target if recurrently/causally/autoregressively generating, or it won't be able to work
+            # for the AR, roll by one and mark the ending with a stop token
+            # this coerces the model into properly inferencing causally
+            # why we don't just append a stop token in the dataloader, who knows
             if shift_targ_list:
                 targ_list = [*targ_list]
                 for i in range(len(targ_list)):
                     targ_list[i] = targ_list[i].roll(-1, dims=0)
                     targ_list[i][-1] = self.stop_token

-            # generate the sequence
+            # create the new target sequence to compute the loss against
             y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )

             self.loss = dict(
@@ -368,6 +350,7 @@ class Base(nn.Module):
             del text_prom_list
             del y_list

+        # return the entire generated token string
         if return_all:
             logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
@@ -387,124 +370,90 @@ class Base(nn.Module):
         return ret, state

 def example_usage():
+    from ..config import cfg
+    cfg.trainer.backend = "local"
+
     from functools import partial
     from einops import repeat
-    from tqdm import trange
-    from ..utils import gather_attribute
     from ..emb.qnt import decode_to_file
+    from ..engines import Engine, Engines
+    from tqdm import tqdm, trange

     from .ar import AR
     from .nar import NAR

     device = "cpu"
+    x8 = partial(repeat, pattern="t -> t l", l=2)
     symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
     def tokenize(content, lang_marker="en"):
         split = content.split(" ")
         phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
         return torch.tensor([*map(symmap.get, phones)]).to()

     kwargs = {
         'n_tokens': 1024,
         'd_model': 1024,
         'n_heads': 16,
         'n_layers': 12,
     }
-    model_ar = AR(**kwargs).to(device)
-    model_nar = NAR(**kwargs).to(device)
+    models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
+    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })

     train = True

-    if train:
-        qnt = torch.load("data/qnt.pt").to(device)
-        text_list = [
-            tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
-        ]
-
-        x8 = partial(repeat, pattern="t -> t l", l=2)
-        proms_list = [
-            qnt[0][:2,:].t().to(device),
-            #x8(torch.tensor([1, 2, 3], device=device)),
-            # x8(torch.tensor([2, 3], device=device)),
-        ]
-        resp_list_ar = [
-            qnt[0,0].to(device),
-            # qnt[0,0].to(device),
-        ]
-        resp_list_nar = [
-            qnt[0][:2,:].t().to(device),
-            # qnt[0][:2,:].t().to(device),
-        ]
-
-        model_ar.train()
-        optimizer = torch.optim.AdamW(model_ar.parameters(), lr=1e-4)
-        for i in trange(60):
-            optimizer.zero_grad()
-            _ = model_ar(text_list, proms_list, resp_list_ar)
-            losses = gather_attribute(model_ar, "loss")
-            loss = sum(losses.values())
-            loss.backward()
-            optimizer.step()
-            if i % 20 == 0:
-                print(f"iter={i}, {losses}.")
-
-        model_nar.train()
-        optimizer = torch.optim.AdamW(model_nar.parameters(), lr=1e-4)
-        for i in trange(60):
-            optimizer.zero_grad()
-            _ = model_nar(text_list, proms_list, resps_list=resp_list_nar)
-            losses = gather_attribute(model_nar, "loss")
-            loss = sum(losses.values())
-            loss.backward()
-            optimizer.step()
-            if i % 20 == 0:
-                stats = {k: v.item() for k, v in losses.items()}
-                stats["loss"] = loss.item()
-                print(f"iter={i}, {stats}.")
-    else:
-        qnt = torch.load("data/test/test.qnt.pt")[0][:2,:].t().to(device)
-        text_list = [
-            #tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
-            tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
-        ]
-        proms_list = [
-            qnt.to(device),
-        ]
-        model_ar.load_state_dict(torch.load("data/test/ar.pth"))
-        model_nar.load_state_dict(torch.load("data/test/nar.pth"))
-
-    model_ar.eval()
-    resp_list = model_ar(text_list, proms_list, max_steps=300, sampling_temperature=1.0)
-    resps_list = [r.unsqueeze(-1) for r in resp_list]
-    print("qnt:", qnt.shape, qnt)
-    print("out:", resp_list[0].shape, resp_list[0])
-    wav, sr = decode_to_file(resp_list[0], "data/test/test.ar.init.wav", device=device)
-    print(wav, sr)
-
-    model_nar.eval()
-    codes = model_nar(
-        text_list,
-        proms_list,
-        resps_list=resps_list,
-        sampling_temperature=1.0,
-    )[0]
-
-    print("qnt:", qnt.shape, qnt)
-    print("codes:", codes.shape, codes)
-
-    wav, sr = decode_to_file(codes, "data/test/test.ar+nar.init.wav", device=device)
-    print(wav, sr)
+    qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)
+    text_list = [
+        tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
+        #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
+    ]
+
+    proms_list = [
+        qnt.to(device),
+    ]
+    resps_list = [
+        qnt.to(device),
+    ]
+
+    def sample( name, steps=400 ):
+        AR = None
+        NAR = None
+
+        engines.eval()
+        for name, engine in engines.items():
+            if name[:2] == "ar":
+                AR = engine
+            elif name[:3] == "nar":
+                NAR = engine
+
+        resps_list = AR(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
+        resps_list = [r.unsqueeze(-1) for r in resps_list]
+        codes = NAR( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+
+        decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
+        decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)
+
+    if train:
+        sample("init", 15)
+
+        engines.train()
+        t = trange(60)
+        for i in t:
+            """
+            stats = {"step": i}
+            for name, engine in engines.items():
+                stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+            """
+            stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list}, device="cpu")
+            t.set_description(f"{stats}")
+    else:
+        for name, engine in engines.items():
+            engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))
+
+    sample("final")

 if __name__ == "__main__":
     example_usage()
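The rewritten comments above describe the AR's teacher forcing: inputs stay in place while the target sequence is rolled left by one and terminated with the stop token, so position i learns to predict token i+1 and the final position learns to stop. A toy version of the shift:

    import torch

    stop_token, ignore_index = 1024, -100

    targ = torch.tensor([7, 8, 9, 10])  # ground-truth response tokens
    shifted = targ.roll(-1, dims=0)     # tensor([ 8,  9, 10,  7])
    shifted[-1] = stop_token            # tensor([ 8,  9, 10, 1024])

    # text/prompt positions are rolled the same way, then masked with ignore_index,
    # so only response positions contribute to the cross entropy
    print(shifted)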

View File

@@ -80,7 +80,6 @@ class NAR(Base):
                 quant_levels=quant_levels,
             )

-            # Yes, just nothing as we are training
             prev_list = []
         else:
             prev_list = resps_list
@@ -93,7 +92,7 @@ class NAR(Base):
             quant_levels = torch.full((len(text_list),), level, device=device)

-            resp_list, _ = super().forward(
+            resps_list, _ = super().forward(
                 text_list,
                 proms_list,
                 prev_list,
@@ -105,24 +104,22 @@ class NAR(Base):
             prev_list = [
                 torch.cat([rs, r.unsqueeze(-1)], dim=-1)
-                for rs, r in zip(prev_list, resp_list)
+                for rs, r in zip(prev_list, resps_list)
             ]

         return prev_list

 def example_usage():
+    cfg.trainer.backend = "local"
+
     from functools import partial
-    from pathlib import Path
     from einops import repeat

     from ..emb.qnt import decode_to_file
-    from ..utils import gather_attribute
-    from ..config import cfg
+    from ..engines import Engine
+    from tqdm import tqdm

     device = "cpu"
     x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
     symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
     def tokenize(content, lang_marker="en"):
@@ -130,15 +127,8 @@ def example_usage():
         phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
         return torch.tensor([*map(symmap.get, phones)]).to()

-    resps = torch.load("data/qnt.pt")[0][:2, :].to(device)
-    kwargs = {
-        'n_tokens': 1024,
-        'd_model': 1024,
-        'n_heads': 16,
-        'n_layers': 12,
-    }
-    model = NAR(**kwargs).to(device)
+    # to-do: unmangle this and the resp shit
+    qnt = torch.load("data/qnt.pt")[0].t()[:, :2].to(device)

     text_list = [
         #torch.tensor([1, 2, 3], device=device),
@@ -149,65 +139,36 @@ def example_usage():
         x8(torch.tensor([2, 3], device=device)),
     ]

-    resps_x1_list = [
-        resps[:1].t().to(device),
-    ]
-    resps_x8_list = [
-        resps.t().to(device),
-    ]
-
-    model.eval()
-    codes = model(
-        text_list,
-        proms_list,
-        resps_list=resps_x1_list,
-        sampling_temperature=0.2,
-    )[0]
-
-    decode_to_file(
-        codes,
-        Path("data/test/test.nar.init.wav"),
-        device
-    )
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
-
-    model.train()
-    for i in trange(50):
-        optimizer.zero_grad()
-        _ = model(text_list, proms_list, resps_list=resps_x8_list)
-        losses = gather_attribute(model, "loss")
-        loss = sum(losses.values())
-        loss.backward()
-        optimizer.step()
-        if i % 20 == 0:
-            stats = {k: v.item() for k, v in losses.items()}
-            stats["loss"] = loss.item()
-            print(f"iter={i}, {stats}.")
-
-    model.eval()
-    for i in trange(1, 2): # cfg.models.prom_levels):
-        resps_list = [
-            resps[:i].t().to(device),
-        ]
-
-        codes = model(
-            text_list,
-            proms_list,
-            resps_list=resps_list,
-            sampling_temperature=0.2,
-        )[0]
-
-        decode_to_file(
-            codes,
-            Path(f"data/test/test.nar.1-{i}.wav"),
-            device
-        )
+    resps_list = [
+        qnt.to(device),
+    ]
+
+    kwargs = {
+        'n_tokens': 1024,
+        'd_model': 1024,
+        'n_heads': 16,
+        'n_layers': 12,
+    }
+    model = NAR(**kwargs).to(device)
+    engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
+
+    def sample( name ):
+        engine.eval()
+        codes = engine( text_list, proms_list, resps_list=[r[..., 0].unsqueeze(-1) for r in resps_list], sampling_temperature=0.2 )
+        decode_to_file( codes[0], f"data/nar.{name}.wav", device )
+
+    def train():
+        engine.train()
+        t = trange(60)
+        for i in t:
+            stats = {"step": i}
+            stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+            t.set_description(f"{stats}")
+
+    sample("init")
+    train()
+    sample("final")

 if __name__ == "__main__":
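The `[r[..., 0].unsqueeze(-1) for r in resps_list]` idiom in `sample` hands the NAR only the first quantization level of the ground truth, as if it had come from the AR; a toy illustration:

    import torch

    full = torch.randint(0, 1024, (75, 2))    # [t, l]: ground truth with two quantization levels
    first_only = full[..., 0].unsqueeze(-1)   # [t, 1]: keep level 0, restore the level axis
    print(first_only.shape)                   # torch.Size([75, 1])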

View File

@@ -1,17 +1,13 @@
 # todo: clean this mess up
-# todo: yank deepspeed dependent code out into its own thing

 from .config import cfg
 from .data import create_train_val_dataloader
 from .emb import qnt
 from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
-from .utils import wrapper as ml
-from .models import get_models

 import auraloss
-import deepspeed
 import json
 import logging
 import random
@@ -21,10 +17,6 @@ import traceback

 from collections import defaultdict

-from deepspeed import comm as dist
-from deepspeed import DeepSpeedConfig
-from deepspeed.accelerator import get_accelerator
-
 from tqdm import tqdm

 mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
@@ -38,90 +30,21 @@ def left_crop(x, len):
     return x[..., 0:len]

 _logger = logging.getLogger(__name__)

-deepspeed._initialized_dist = False
-
-def load_engines():
-    if not deepspeed._initialized_dist:
-        deepspeed._initialized_dist = True
-        deepspeed.init_distributed()
-
-    models = get_models(cfg.models.get())
-    engines = dict()
-
-    for name in models:
-        model = models[name]
-
-        optimizer = None
-        lr_scheduler = None
-
-        Adam = ml.Adam
-        AdamW = ml.AdamW
-
-        if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
-            optimizer = AdamW(
-                model.parameters(),
-                lr=cfg.hyperparameters.learning_rate,
-                betas=(0.9, 0.96),
-                eps=1e-07,
-                weight_decay=0.01,
-            )
-
-        if cfg.trainer.load_state_dict:
-            load_dir = cfg.ckpt_dir / name / "fp32.pth"
-            model.load_state_dict(torch.load(load_dir))
-
-        ds_cfg=cfg.get_ds_cfg(model=model)
-        config_class=DeepSpeedConfig(ds_cfg)
-        engines[name] = trainer.Engine(
-            model=model,
-            config=ds_cfg,
-            config_class=config_class,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-        )
-
-    return trainer.load_engines(engines, cfg)
-
-def main():
-    setup_logging(cfg.log_dir)
-
-    #dist.init_distributed(dist_backend=get_accelerator().communication_backend_name())
-    if not deepspeed._initialized_dist:
-        deepspeed._initialized_dist = True
-        deepspeed.init_distributed()
-
-    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
-
-    def train_feeder(engines, batch, name):
-        stats = {}
-
-        model = engines[name]
-
-        if name.startswith("ar"):
-            _ = model(
-                text_list=batch["text"],
-                proms_list=batch["proms"],
-                resp_list=[r[..., 0] for r in batch["resps"]],
-            )
-        elif name.startswith("nar"):
-            _ = model(
-                text_list=batch["text"],
-                proms_list=batch["proms"],
-                resps_list=batch["resps"],
-            )
-        else:
-            raise NotImplementedError(name)
-
-        losses = model.gather_attribute("loss")
-
-        loss = torch.stack([*losses.values()]).sum()
-
-        stats = {}
-        stats |= {k: v.item() for k, v in losses.items()}
-        stats |= engines.gather_attribute("scalar")
-
-        return loss, stats
+def train_feeder(engine, batch):
+    engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
+
+    losses = engine.gather_attribute("loss")
+
+    loss = torch.stack([*losses.values()]).sum()
+
+    stats = {}
+    stats |= {k: v.item() for k, v in losses.items()}
+
+    return loss, stats

 @torch.inference_mode()
 def run_eval(engines, eval_name, dl):
     engines_stats = {
         'eval': eval_name
     }
@ -130,83 +53,17 @@ def main():
NAR = None NAR = None
names = [] names = []
for name in engines: for name, engine in engines.items():
model = engines[name]
names.append(name) names.append(name)
if name[:2] == "ar": if name[:2] == "ar":
AR = model AR = engine
elif name[:3] == "nar": elif name[:3] == "nar":
NAR = model NAR = engine
stats = defaultdict(list) stats = defaultdict(list)
stats['loss'] = [] stats['loss'] = []
for batch in tqdm(dl): def process( name, batch, resps_list ):
            batch: dict = to_device(batch, cfg.device)

            # if we're training both models, provide output for both
            if AR is not None and NAR is not None:
                name = "+".join(names)

                resp_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.temperature)
                resps_list = [ r.unsqueeze(-1) for r in resp_list ]
                resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.temperature)

                for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
                    if len(hyp) == 0:
                        continue

                    filename = f'{speaker}_{path.parts[-1]}'

                    # to-do: refine the output dir to be saner
                    ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
                    hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
                    prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")

                    hyp_path.parent.mkdir(parents=True, exist_ok=True)
                    ref_path.parent.mkdir(parents=True, exist_ok=True)
                    prom_path.parent.mkdir(parents=True, exist_ok=True)

                    ref_audio, sr = qnt.decode_to_file(ref, ref_path)
                    hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
                    prom_audio, sr = qnt.decode_to_file(prom, prom_path)

                    # trim both signals to the shorter of the two before comparing
                    min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
                    ref_audio = ref_audio[..., 0:min_length]
                    hyp_audio = hyp_audio[..., 0:min_length]

                    stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
            else:
                for name in engines:
                    model = engines[name]

                    if name.startswith("ar"):
                        resp_list = model(
                            text_list=batch["text"],
                            proms_list=batch["proms"],
                            max_steps=cfg.evaluation.steps,
                            sampling_temperature=cfg.evaluation.temperature,
                        )
                        resps_list = [r.unsqueeze(-1) for r in resp_list]
                    elif name.startswith("nar"):
                        resps_list = model(
                            text_list=batch["text"],
                            proms_list=batch["proms"],
                            resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
                            sampling_temperature=cfg.evaluation.temperature,
                        )
                    else:
                        raise NotImplementedError(name)

                    losses = model.gather_attribute("loss")

                    batch_stats = {}
                    batch_stats |= {k: v.item() for k, v in losses.items()}
                    batch_stats |= engines.gather_attribute("scalar")

                    for k, v in batch_stats.items():
                        stats[k].append(v)
        def process( name, batch, resps_list ):
            for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
                if len(hyp) == 0:
                    continue

@@ -231,7 +88,46 @@ def main():
                ref_audio = ref_audio[..., 0:min_length]
                hyp_audio = hyp_audio[..., 0:min_length]

                try:
                    stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
                except Exception as e:
                    print(str(e))
        for batch in tqdm(dl):
            batch: dict = to_device(batch, cfg.device)

            # if we're training both models, provide output for both
            if AR is not None and NAR is not None:
                name = "+".join(names)

                resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
                resps_list = [ r.unsqueeze(-1) for r in resps_list ]
                resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)

                process( name, batch, resps_list )
            else:
                for name in engines:
                    model = engines[name]

                    if name.startswith("ar"):
                        resps_list = model(
                            text_list=batch["text"],
                            proms_list=batch["proms"],
                            max_steps=cfg.evaluation.steps,
                            sampling_temperature=cfg.evaluation.ar_temperature,
                        )
                        resps_list = [r.unsqueeze(-1) for r in resps_list]
                    elif name.startswith("nar"):
                        resps_list = model(
                            text_list=batch["text"],
                            proms_list=batch["proms"],
                            resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
                            sampling_temperature=cfg.evaluation.nar_temperature,
                        )
                    else:
                        raise NotImplementedError(name)

                    process( name, batch, resps_list )
        stats = {k: sum(v) / len(v) for k, v in stats.items()}
        engines_stats.update(flatten_dict({ name: stats }))

@@ -242,6 +138,12 @@ def main():
        _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
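The `mel_stft_loss` used above is defined elsewhere in the file; a comparable metric can be built from the auraloss package, at the repo's 24 kHz sample rate (a sketch, not necessarily the repo's exact configuration):

import auraloss

# Sketch only: a mel-scaled STFT distance between two decoded waveforms,
# assuming the auraloss package; input shapes follow (batch, channels, time).
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000)
distance = mel_stft_loss(hyp_audio.unsqueeze(0), ref_audio.unsqueeze(0))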
def main():
    setup_logging(cfg.log_dir)

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
    def eval_fn(engines):
        try:
            run_eval(engines, "subtrain", subtrain_dl)

@@ -256,12 +158,10 @@ def main():
    qnt.unload_model()

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )
if __name__ == "__main__":
    main()

View File

@@ -1,252 +0,0 @@
"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
# to-do: replace this
# to-do: swap out deepspeed
from ..config import Config
from .distributed import fix_unset_envs
from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
import logging
import time
import torch
import torch.distributed
from deepspeed import DeepSpeedEngine
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
Stats = dict[str, float]
_logger = logging.getLogger(__name__)
class Engine(DeepSpeedEngine):
    def __init__(self, *args, **kwargs):
        fix_unset_envs()
        super().__init__(None, *args, **kwargs)
        self._frozen_params = set()

    def freeze(self):
        for p in self.module.parameters():
            if p.requires_grad:
                p.requires_grad_(False)
                self._frozen_params.add(p)

    def unfreeze(self):
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()
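A sketch of the intended use: hold one model fixed while others train, then restore it (the engine name and the training call are illustrative):

engine = engines["ar"]   # illustrative engine name
engine.freeze()          # flips requires_grad off and remembers which params it touched
train_for_a_while()      # hypothetical; the frozen weights stay fixed here
engine.unfreeze()        # restores exactly the params freeze() disabled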
    @property
    def global_step(self):
        return self.global_steps

    def gather_attribute(self, *args, **kwargs):
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        return dispatch_attribute(self.module, *args, **kwargs)

class TrainFeeder(Protocol):
    def __call__(
        self, *, engines: "Engines", batch: Any, name: str
    ) -> None | tuple[Tensor, Stats]:
        ...
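Any callable with this keyword-only signature satisfies the protocol; returning None makes Engines.step() skip that engine for the batch. A sketch (the skip predicate and the forward call are illustrative, not this repo's models):

# Sketch of a conforming feeder.
def selective_feeder(*, engines, batch, name):
    if not name.startswith("ar"):    # illustrative predicate
        return None
    loss = engines[name].module(batch)  # illustrative forward call
    return loss, {"loss": loss.item()}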
class Engines(dict[str, Engine]):
    def setup(self, cfg: Config):
        self._cfg = cfg
        self._global_step = 0

    @property
    def cfg(self) -> Config:
        return self._cfg

    @property
    def config(self):
        return self._cfg

    @property
    def global_step(self):
        return self._global_step

    def gather_attribute(self, *args, **kwargs):
        ret = {}
        for engine in self.values():
            ret |= engine.gather_attribute(*args, **kwargs)
        return ret

    def dispatch_attribute(self, *args, **kwargs):
        for engine in self.values():
            engine.dispatch_attribute(*args, **kwargs)

    def save_checkpoint(self, tag=None):
        if not tag:
            tag = self.cfg.trainer.save_tag
        tag = tag.lower()
        if tag[:2] == "it" or tag[:4] == "step":
            tag = self.global_step

        self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
        for name, engine in self.items():
            engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)

    def load_checkpoint(self, tag=None):
        if not tag:
            tag = self.cfg.trainer.load_tag

        for name, engine in self.items():
            load_dir = self.cfg.ckpt_dir / name
            engine.load_checkpoint(
                tag=tag,
                load_dir=load_dir,
                load_module_strict=self.cfg.trainer.strict_loading,
                load_optimizer_states=self.cfg.trainer.load_states,
                load_lr_scheduler_states=self.cfg.trainer.load_states,
                load_module_only=False,  # not self.cfg.trainer.load_states,
            )
            if self.cfg.trainer.restart_step_count:
                engine.global_steps = 0

        # update the LR, because for some god-awful reason it gets overwritten
        # when loading from a checkpoint, but only when not using a scheduler
        if self.cfg.hyperparameters.scheduler_type == "":
            self.set_lr(self.cfg.hyperparameters.learning_rate)

        self._update_global_step()

    def set_lr(self, lr):
        try:
            for engine in self.values():
                if hasattr(engine.optimizer, 'param_groups'):
                    print(engine.optimizer.param_groups)
                    for param_group in engine.optimizer.param_groups:
                        param_group['lr'] = lr
                else:
                    engine.optimizer.set_lr(lr)
        except Exception as e:
            print(str(e))

    def _update_global_step(self):
        for engine in self.values():
            self._global_step = max(self._global_step, engine.global_step)

    def eval(self):
        for engine in self.values():
            engine.eval()

    def train(self):
        for engine in self.values():
            engine.train()
    def step(self, feeder: TrainFeeder, batch):
        total_elapsed_time = 0
        stats: Any = dict()

        if self.cfg.trainer.gc_mode == 'step':
            do_gc()

        batch = to_device(batch, torch.cuda.current_device())

        for name, engine in self.items():
            torch.cuda.synchronize()

            if self.cfg.trainer.gc_mode == 'substep':
                do_gc()

            start_time = time.time()

            tries = 4
            n_ooms = torch.zeros([], device=self.cfg.device)

            if self.cfg.trainer.aggressive_optimizations:
                batch = to_device(batch, torch.cuda.current_device())
                # engine = engine.to(torch.cuda.current_device())

            while tries >= 0:
                try:
                    res = feeder( engines=self, batch=batch, name=name )
                    break
                except RuntimeError as e:
                    print("Forward", str(e))

                    if "out of memory" not in str(e):
                        self.save_checkpoint()
                        raise e

                    # shrink the batch size until it's happy
                    tries -= 1  # decrement, or this loop never terminates on repeated OOMs
                    for k in batch:
                        batch[k] = batch[k][:-1]

                    if tries <= 0:
                        # give up; trigger the OOM handling below
                        n_ooms += 1
                        break
                    else:
                        # also do a GC pass before retrying
                        do_gc()
                        continue

            all_reduce(n_ooms)
            if n_ooms.item() > 0:
                self.save_checkpoint()
                raise RuntimeError("Out of memory during forward pass!")

            if res is None:
                continue

            loss, engine_stats = res

            n_ooms = torch.zeros([], device=self.cfg.device)

            if self.cfg.trainer.aggressive_optimizations:
                batch = to_device(batch, 'cpu')

            try:
                engine.backward(loss)
            except RuntimeError as e:
                print("Backwards:", str(e))

                if "out of memory" not in str(e):
                    self.save_checkpoint()
                    raise e

                n_ooms += 1

            all_reduce(n_ooms)
            if n_ooms.item() > 0:
                self.save_checkpoint()
                raise RuntimeError("Out of memory during backwards pass!")

            engine.step()
            torch.cuda.synchronize()
            elapsed_time = time.time() - start_time
            total_elapsed_time += elapsed_time

            stats.update(
                flatten_dict(
                    {
                        name.split("-")[0]: dict(
                            loss=loss.item(),
                            lr=engine.get_lr()[0],
                            grad_norm=engine.get_global_grad_norm(),  # this norm is delayed, but is global and avoids extra computation
                            elapsed_time=elapsed_time,
                            engine_step=engine.global_step,
                            **engine_stats,
                        )
                    }
                ),
            )

            del loss
            # engine = engine.to('cpu')

        self._update_global_step()
        stats["batch_size"] = len(batch["text"])
        stats["elapsed_time"] = total_elapsed_time
        stats["wall_time"] = time.time()
        stats["global_step"] = self.global_step

        return stats
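Taken together, the outer loop this class serves reduces to a sketch like the following (the engine instances, dataloader, and feeder are stand-ins):

# Sketch of the training loop around Engines.step(); names are illustrative.
engines = Engines({"ar": ar_engine, "nar": nar_engine})
engines.setup(cfg)
engines.train()
for batch in dl:
    stats = engines.step(feeder=feeder, batch=batch)
    print(stats["global_step"], stats["elapsed_time"])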

View File

@@ -17,8 +17,9 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Protocol

from ..config import cfg
from .distributed import (
    fix_unset_envs,
    global_leader_only,
    global_rank,
    is_global_leader,
@@ -26,8 +27,11 @@ from .distributed import (
    local_leader_only,
)

from ..engines import Engine, Engines, TrainFeeder, default_feeder
from ..models import get_models
from .utils import to_device, do_gc
from ..utils import wrapper as ml

_logger = logging.getLogger(__name__)

_engines: Engines
@@ -45,14 +49,6 @@ def get_micro_step():
    except:
        return None

def get_cfg():
    try:
        return _engines.cfg
    except:
        raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
def get_cmd():
    try:
        return _command

@@ -62,19 +58,42 @@ def get_cmd():

get_iteration = get_global_step
def load_engines():
    models = get_models(cfg.models.get())
    engines = dict()

    for name in models:
        model = models[name]

        optimizer = None
        lr_scheduler = None

        if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
            optimizer = ml.AdamW(
                model.parameters(),
                lr=cfg.hyperparameters.learning_rate,
                betas=(0.9, 0.96),
                eps=1e-07,
                weight_decay=0.01,
            )

        if cfg.trainer.load_state_dict:
            load_path = cfg.ckpt_dir / name / "fp32.pth"
            model.load_state_dict(torch.load(load_path))

        engines[name] = Engine(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )

    engines = Engines(engines)
    engines.setup()

    if not cfg.trainer.load_state_dict:
        engines.load_checkpoint()

    return engines
class EvalFn(Protocol):
    def __call__(self, *, engines: Engines):

@@ -128,14 +147,14 @@ def seed(seed):
def train(
    train_dl: DataLoader,
    train_feeder: TrainFeeder = default_feeder,
    eval_fn: EvalFn = lambda x: ...,
    logger: Logger = logger,
):
    fix_unset_envs()

    engines = load_engines()

    """
    if is_local_leader():
@@ -169,7 +188,7 @@ def train(
            break

        #batch = to_device(batch, torch.cuda.current_device())
        stats = engines.step(batch=batch, feeder=train_feeder)

        iteration = stats['global_step']  # * cfg.hyperparameters.gradient_accumulation_steps
        stats['it'] = iteration