A nightmare of spaghetti that might break compat: a mechanism to increase the RVQ bins of an existing model without retraining, keeping sampled proms/resps at the max RVQ level and trimming off excess levels according to what each model receives, plus some other things I already forgot (I really hope no one else has weights being baked right now)

This commit is contained in:
mrq 2023-08-19 15:06:33 -05:00
parent f7f6d3bf6d
commit 2d1a9f10c0
11 changed files with 132 additions and 66 deletions
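For context, the core trick described above, sketched in isolation: grow a checkpoint's per-level embedding table to more RVQ levels by keeping the trained levels and seeding any new ones from the last trained level, mirroring the `load_engines` change at the bottom of this commit. The function name and shapes here are illustrative, not this repo's API.

```python
import torch

# Illustrative sketch only: `extend_levels` is not a function in this repo.
# The weight is assumed to be shaped (levels, tokens, d_model), as with the
# proms_emb / resps_emb state-dict tensors touched further down.
def extend_levels(old_weight: torch.Tensor, new_levels: int) -> torch.Tensor:
    o_levels, n_tokens, d_model = old_weight.shape
    new_weight = old_weight.new_zeros(new_levels, n_tokens, d_model)
    new_weight[:o_levels] = old_weight                # keep the trained levels untouched
    new_weight[o_levels:] = old_weight[o_levels - 1]  # seed new levels from the last trained one ("probably bad")
    return new_weight
```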

View File

@ -1,7 +1,14 @@
dataset:
training: []
training: [
# "./training/valle/data/LibriTTS/994/",
]
validation: []
validation: [
# "./training/valle/data/Validation/1188/",
]
noise: [
# "./training/valle/data/Other/noise/",
]
speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
@ -12,7 +19,7 @@ dataset:
workers: 4
cache: True
phones_range: [4, 256]
phones_range: [4, 512]
duration_range: [1.0, 16.0]
random_utterance: 1.0
@ -20,9 +27,11 @@ dataset:
prompt_duration: 3.0
sample_type: speaker
tasks_list: ["tts"] # do NOT change this until you're ready to train for SpeechX tasks # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"]
tasks_list: ["tts"] # ["tts", "ns", "sr", "tse", "cse", "nse"]
models:
_max_levels: 8
_models:
- name: "ar"
size: "full"
@ -38,13 +47,14 @@ models:
tasks: 8
arch_type: "retnet"
hyperparameters:
batch_size: 16
gradient_accumulation_steps: 2
gradient_accumulation_steps: 4
gradient_clipping: 100
optimizer: Adamw
learning_rate: 1.0e-5
optimizer: AdamW
learning_rate: 1.0e-4
scheduler_type: ""
#scheduler_type: OneCycle
@ -66,13 +76,13 @@ hyperparameters:
# decay_mom_rate: 0.0
evaluation:
batch_size: 32
batch_size: 16
frequency: 500
size: 32
size: 16
steps: 300
ar_temperature: 1.0
nar_temperature: 0.2
ar_temperature: 0.95
nar_temperature: 0.25
trainer:
iterations: 1_000_000
@ -106,4 +116,7 @@ inference:
normalize: False # do NOT change this unless you know exactly what you are doing.
bitsandbytes:
enabled: false
enabled: False
injects: True
linear: True
embedding: True

View File

@ -196,6 +196,8 @@ class Model:
@dataclass()
class Models:
_max_levels: int = 0
_models: list[Model] = field(default_factory=lambda: [
Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),
@ -232,6 +234,10 @@ class Models:
for model in self._models:
tasks = max(tasks, model.tasks)
return tasks
@property
def max_levels(self):
return self._max_levels if self._max_levels > 0 else self.prom_levels
@dataclass()
class Hyperparameters:
@ -261,7 +267,8 @@ class DeepSpeed:
use_compression_training: bool = False
compression_bits: int = 8
def get_ds_cfg(self, model):
@cached_property
def ds_cfg(self):
scheduler_params = {}
for k in cfg.hyperparameters.scheduler_params:
scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
@ -277,7 +284,7 @@ class DeepSpeed:
"params": {
"lr": cfg.hyperparameters.learning_rate,
}
},
} if not cfg.hyperparameters.optimizer.endswith("-torch") else None,
"scheduler": {
"type": cfg.hyperparameters.scheduler_type,
"params": scheduler_params,
@ -351,8 +358,8 @@ class DeepSpeed:
for k in null_keys:
del ds_cfg[k]
if os.path.exists("./config/ds_config.json"):
ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
if os.path.exists("./data/ds_config.json"):
ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
return ds_cfg
@ -404,8 +411,8 @@ class BitsAndBytes:
enabled: bool = False
injects: bool = False
linear: bool = False
embedding: bool = False
linear: bool = True
embedding: bool = True
@dataclass()
class Config(_Config):
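A quick illustration of the new `max_levels` property above: a nonzero `_max_levels` acts as an explicit ceiling, otherwise it falls back to `prom_levels`. The import path, and the assumption that `Models.prom_levels` reports the highest `prom_levels` across `_models` (mirroring how `tasks` is derived), are mine rather than guaranteed by this diff.

```python
# Assumed import path; adjust to wherever config.py lives in your checkout.
from vall_e.config import Model, Models

# An explicit ceiling wins.
assert Models(_max_levels=8).max_levels == 8

# No ceiling set: fall back to prom_levels derived from the model list
# (assuming it is the max across the defined models).
custom = Models(_max_levels=0, _models=[
    Model(name="ar",  resp_levels=1, prom_levels=4, tasks=1),
    Model(name="nar", resp_levels=3, prom_levels=4, tasks=1),
])
assert custom.max_levels == 4
```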

View File

@ -63,11 +63,8 @@ def _get_quant_path(path):
def _get_phone_path(path):
return _replace_file_extension(path, ".phn.txt")
def _load_quants(path) -> Tensor:
path = _get_quant_path(path)
return torch.load(path)[0][:cfg.models.prom_levels, :].t().to(torch.int16)
return torch.load(path)[0][:, :].t().to(torch.int16)
@cache
def _get_phones(path, lang_marker="en"):
@ -215,12 +212,12 @@ class Dataset(_Dataset):
def _get_task_symmap(self):
return get_task_symmap()
def get_task_token( self, token ):
def get_task_token( self, token, levels=cfg.models.max_levels ):
if not hasattr(self, "task_symmap"):
self.task_symmap = self._get_task_symmap()
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
def sample_noise(self):
def sample_noise(self):
paths = []
for data_dir in cfg.dataset.noise:
paths.extend(data_dir.rglob("*.qnt.pt"))
@ -228,7 +225,7 @@ class Dataset(_Dataset):
if False and cfg.dataset.use_hdf5:
key = f'/noise/{_get_hdf5_path(path)}'
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
qnt = _load_quants(path)
return qnt
@ -260,7 +257,7 @@ class Dataset(_Dataset):
path = random.choice(choices)
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
qnt = _load_quants(path)
@ -293,7 +290,7 @@ class Dataset(_Dataset):
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
resps = _load_quants(path)
@ -316,7 +313,7 @@ class Dataset(_Dataset):
# extend the noise to fill the target audio
noise = repeat_extend_audio(noise, resps.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu")
proms = merge_audio( resps, noise, scale=[1, noise_scale], device="cpu" )
# set the target to just be the noise if <sr>
if task == "sr":
resps = noise
@ -358,7 +355,7 @@ class Dataset(_Dataset):
if cfg.dataset.use_hdf5:
texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :cfg.models.prom_levels]).to(torch.int16) for path in sampled ]
qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :]).to(torch.int16) for path in sampled ]
else:
texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ]
qnts = [ _load_quants(path) for path in sampled ]
@ -394,15 +391,15 @@ class Dataset(_Dataset):
# it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
# but it's noise, it's supposed to be random
def noise_proms( proms ):
def noise_proms( p ):
# ignore if we turned it off
if proms is None:
if p is None:
return None
# extend the noise to fill the target audio
n = repeat_extend_audio(noise, proms.shape[0])
n = repeat_extend_audio(noise, p.shape[0])
# merge the noise over the utterance
return merge_audio(proms, n, scale=[1, noise_scale], device="cpu")
return merge_audio(p, n, scale=[1, noise_scale], device="cpu")
# apply noise to all pieces
pre_prom = noise_proms( pre_prom )
@ -426,6 +423,8 @@ class Dataset(_Dataset):
[ edit_prom ] +
([ post_prom ] if post_prom is not None else [])
)
else:
raise RuntimeError(f'Undefined task: {task}')
"""
# emulate SVC
@ -450,6 +449,10 @@ class Dataset(_Dataset):
text = torch.tensor([1, 2]).to(self.text_dtype)
"""
# trim to fit to requested prom/resps levels
proms = proms[:, :cfg.models.prom_levels]
resps = resps[:, :cfg.models.prom_levels]
return dict(
index=index,
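The data-loader half of the change above: quantized audio now keeps every stored RVQ level when it is read (note the `[:, :]` slices replacing `[:, :cfg.models.prom_levels]`), and is only cut down to the configured ceiling once the sample is assembled. A standalone sketch of that load-then-trim pattern, with a random tensor standing in for a `.qnt.pt` file:

```python
import torch

# Stand-in for quantized audio loaded at full depth: (frames, levels).
codes = torch.randint(0, 1024, (500, 8), dtype=torch.int16)

# Only at the end of __getitem__ are proms/resps sliced to the global ceiling
# (cfg.models.prom_levels in the real code; 4 is just an example value).
prom_levels = 4
proms = codes[:, :prom_levels]
resps = codes[:, :prom_levels]
assert proms.shape == (500, prom_levels)
```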

View File

@ -21,16 +21,16 @@ except Exception as e:
cfg.inference.use_vocos = False
@cache
def _load_encodec_model(device="cuda"):
def _load_encodec_model(device="cuda", levels=cfg.models.max_levels):
# Instantiate a pretrained EnCodec model
assert cfg.sample_rate == 24_000
# too lazy to un-if ladder this shit
if cfg.models.prom_levels == 2:
if levels == 2:
bandwidth_id = 1.5
elif cfg.models.prom_levels == 4:
elif levels == 4:
bandwidth_id = 3.0
elif cfg.models.prom_levels == 8:
elif levels == 8:
bandwidth_id = 6.0
model = EncodecModel.encodec_model_24khz().to(device)
@ -43,18 +43,18 @@ def _load_encodec_model(device="cuda"):
return model
@cache
def _load_vocos_model(device="cuda"):
def _load_vocos_model(device="cuda", levels=cfg.models.max_levels):
assert cfg.sample_rate == 24_000
model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
model = model.to(device)
# too lazy to un-if ladder this shit
if cfg.models.prom_levels == 2:
if levels == 2:
bandwidth_id = 0
elif cfg.models.prom_levels == 4:
elif levels == 4:
bandwidth_id = 1
elif cfg.models.prom_levels == 8:
elif levels == 8:
bandwidth_id = 2
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
@ -64,11 +64,11 @@ def _load_vocos_model(device="cuda"):
return model
@cache
def _load_model(device="cuda", vocos=cfg.inference.use_vocos):
def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.models.max_levels):
if vocos:
model = _load_vocos_model(device)
model = _load_vocos_model(device, levels=levels)
else:
model = _load_encodec_model(device)
model = _load_encodec_model(device, levels=levels)
return model
@ -78,7 +78,7 @@ def unload_model():
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels):
"""
Args:
codes: (b q t)
@ -94,7 +94,7 @@ def decode(codes: Tensor, device="cuda"):
codes = rearrange(codes, "t q -> 1 q t")
assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
model = _load_model(device)
model = _load_model(device, levels=levels)
# upcast so it won't whine
if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
@ -115,8 +115,8 @@ def decode(codes: Tensor, device="cuda"):
return wav, model.sample_rate
# huh
def decode_to_wave(resps: Tensor, device="cuda"):
return decode(resps, device=device)
def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.models.max_levels):
return decode(resps, device=device, levels=levels)
def decode_to_file(resps: Tensor, path: Path, device="cuda"):
wavs, sr = decode(resps, device=device)
@ -129,14 +129,14 @@ def _replace_file_extension(path, suffix):
@torch.inference_mode()
def encode(wav: Tensor, sr: int = 24_000, device="cuda"):
def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.models.max_levels):
"""
Args:
wav: (t)
sr: int
"""
model = _load_encodec_model(device)
model = _load_encodec_model(device, levels=levels)
wav = wav.unsqueeze(0)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.to(device)
@ -203,16 +203,16 @@ def repeat_extend_audio( qnt, target ):
# merges two quantized audios together
# I don't know if this works
def merge_audio( *args, device="cpu", scale=[] ):
def merge_audio( *args, device="cpu", scale=[], levels=cfg.models.max_levels ):
qnts = [*args]
decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
if len(scale) == len(decoded):
for i in range(len(scale)):
decoded[i] = decoded[i] * scale[i]
combined = sum(decoded) / len(decoded)
return encode(combined, 24_000, device="cpu")[0].t()
return encode(combined, 24_000, device="cpu", levels=levels)[0].t()
def main():
parser = argparse.ArgumentParser()
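On the "un-if ladder" note in the two loaders above: the level count is only being mapped onto a codec bandwidth, so a lookup table would be equivalent. The pairs below are copied from the diff; rejecting level counts outside {2, 4, 8} is my addition (the current ladders simply leave `bandwidth_id` unset in that case).

```python
# Bandwidth picked from the RVQ level count, as in _load_encodec_model / _load_vocos_model.
ENCODEC_BANDWIDTH = {2: 1.5, 4: 3.0, 8: 6.0}   # bandwidth values used by the EnCodec loader
VOCOS_BANDWIDTH_ID = {2: 0, 4: 1, 8: 2}        # index assigned to model.bandwidth_id for Vocos

def bandwidth_for(levels: int, table: dict) -> float:
    if levels not in table:
        raise ValueError(f"unsupported RVQ level count: {levels}")
    return table[levels]

print(bandwidth_for(8, ENCODEC_BANDWIDTH))  # 6.0
```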

View File

@ -39,6 +39,7 @@ import os
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from functools import cached_property
from .base import TrainFeeder
@ -50,6 +51,10 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
# A very naive engine implementation using barebones PyTorch
class Engine():
def __init__(self, *args, **kwargs):
if '_cfg' in kwargs:
self._cfg = kwargs['_cfg']
kwargs.pop("_cfg")
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
@ -137,6 +142,10 @@ class Engine():
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
@cached_property
def device(self):
return next(self.module.parameters()).device
def forward(self, *args, **kwargs):
return self.module.forward(*args, **kwargs)

View File

@ -31,7 +31,11 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs):
kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
if '_cfg' in kwargs:
self._cfg = kwargs['_cfg']
kwargs.pop("_cfg")
kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
super().__init__(None, *args, **kwargs)

View File

@ -8,10 +8,6 @@ from torch import Tensor
from tqdm import trange
class AR(Base):
@property
def n_resp_levels(self) -> int:
return cfg.models.ar.resp_levels
@property
def causal(self):
return True
@ -32,6 +28,14 @@ class AR(Base):
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
@property
def n_resp_levels(self) -> int:
return cfg.models.ar.resp_levels
@property
def n_max_levels(self) -> int:
return cfg.models.max_levels
@property
def n_tasks(self) -> int:
return cfg.models.tasks

View File

@ -113,6 +113,10 @@ class Base(nn.Module):
@property
def n_prom_levels(self) -> int:
raise NotImplementedError
@property
def n_max_levels(self) -> int:
raise NotImplementedError
@property
def n_tasks(self) -> int:

View File

@ -7,10 +7,6 @@ from torch import Tensor
from tqdm import trange
class NAR(Base):
@property
def n_resp_levels(self) -> int:
return cfg.models.nar.resp_levels
@property
def causal(self):
return False
@ -31,6 +27,14 @@ class NAR(Base):
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
@property
def n_resp_levels(self) -> int:
return cfg.models.nar.resp_levels
@property
def n_max_levels(self) -> int:
return cfg.models.max_levels
@property
def n_tasks(self) -> int:
return cfg.models.tasks

View File

@ -23,7 +23,11 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
_logger = logging.getLogger(__name__)
def train_feeder(engine, batch):
engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
engine(
text_list=batch["text"],
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
resps_list=batch["resps"]
)
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
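And the per-model half of the trim: `train_feeder` now slices each sampled prompt down to the `prom_levels` of the specific model behind the engine (carried in as `_cfg` by `load_engines` below), so each model only sees as many prompt levels as it was configured for. A self-contained sketch with a dummy config object standing in for the `Model` dataclass:

```python
import torch

class _DummyCfg:            # stands in for the Model dataclass attached as engine._cfg
    prom_levels = 2

# Prompts arrive from the dataloader at full depth (see the data.py trim above).
batch_proms = [torch.randint(0, 1024, (300, 8), dtype=torch.int16)]
trimmed = [p[:, :_DummyCfg.prom_levels] for p in batch_proms]
assert trimmed[0].shape == (300, _DummyCfg.prom_levels)
```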

View File

@ -88,19 +88,33 @@ def load_engines():
# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
n_prom_levels, n_prom_tokens, d_model = state['proms_emb.weight'].shape
o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
# copy weights from the dict into the old portion
model.proms_emb.weight.data[:n_prom_levels, :n_prom_tokens, :] = state['proms_emb.weight'].data[:n_prom_levels, :n_prom_tokens, :]
model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
# copy the full tensors back
state['proms_emb.weight'] = model.proms_emb.weight
# extend the resps_emb if we ever touch the n_resp_levels or n_resp_tokens (from adding tasks)
if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape
# copy weights from the dict into the old portion
model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
# reuse additional levels, probably bad
for n in range(o_resp_tokens, n_resp_tokens):
model.resps_emb.weight.data[n] = model.resps_emb.weight.data[o_resp_tokens-1]
# copy the full tensors back
state['resps_emb.weight'] = model.resps_emb.weight
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
engines[name] = Engine(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
_cfg=model._cfg,
)
engines = Engines(engines)