somewhat got recurrent forward working (it's about as accurate as chunkwise forward, which is to say not accurate at all); added an option to use AMP instead of blanket-setting the weights' dtype

This commit is contained in:
mrq 2023-09-01 20:58:29 -05:00
parent 2bc2d08b09
commit e40c0d34a0
9 changed files with 107 additions and 161 deletions
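
For context, the AMP path added here works roughly like this: when cfg.trainer.amp is set, the module's weights are left in float32 and the forward/loss computation is wrapped in torch.autocast with the configured dtype, instead of blanket-casting the weights. A minimal, self-contained sketch of that pattern, using illustrative stand-in names rather than the project's actual config objects:

import torch

# illustrative stand-ins for cfg.trainer.amp / cfg.trainer.dtype (not the real config)
use_amp = True
amp_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(16, 16).to(device)
# with AMP the weights stay float32 and autocast casts per-op;
# otherwise the weights themselves are cast to the training dtype
model = model.to(torch.float32 if use_amp else amp_dtype)

x = torch.randn(4, 16, device=device)
with torch.autocast(device, dtype=amp_dtype, enabled=use_amp):
    loss = model(x).square().mean()
# backward runs outside the autocast region, as usual with AMP
loss.backward()

Full float16 training would normally pair this with a gradient scaler (or let DeepSpeed handle loss scaling); the sketch omits that for brevity.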

View File

@ -165,7 +165,6 @@ class Model:
arch_type: str = "transformer"
training: bool = True
@property
def full_name(self):
name = [ self.name ]
@ -332,9 +331,9 @@ class DeepSpeed:
"fp16": {
"enabled": True,
"auto_cast": True,
} if cfg.trainer.weight_dtype.lower() == "float16" else None,
} if cfg.trainer.weight_dtype.lower() == "float16" and not cfg.trainer.amp else None,
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" and not cfg.trainer.amp
},
"compression_training": {
"weight_quantization": {
@ -427,15 +426,13 @@ class Trainer:
aggressive_optimizations: bool = False
check_for_oom: bool = True
gc_mode: str | None = None
load_disabled_engines: bool = False
gc_mode: str | None = None
weight_dtype: str = "float16"
amp: bool = False
backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"
backend: str = "local"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
@cached_property
@ -450,10 +447,14 @@ class Trainer:
@dataclass()
class Inference:
weight_dtype: str = "float32"
amp: bool = False
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
use_vocos: bool = True
recurrent_chunk_size: int = 0
recurrent_forward: bool = False
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
@ -473,6 +474,7 @@ class BitsAndBytes:
@dataclass()
class Config(_Config):
device: str = "cuda"
mode: str = "training" # "inferencing"
dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models)

View File

@ -55,7 +55,7 @@ class Engine():
self._cfg = kwargs['_cfg']
kwargs.pop("_cfg")
self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
@ -196,9 +196,11 @@ class Engine():
return 0.0
def traverse(self, *args, **kwargs):
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
print(self.module.loss)
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}

View File

@ -87,9 +87,10 @@ class Engine(DeepSpeedEngine):
print(str(e))
def traverse(self, *args, **kwargs):
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
with torch.autocast(self.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}

View File

@ -31,6 +31,8 @@ class TTS():
cfg.format()
except Exception as e:
pass
cfg.mode = "inferencing"
self.symmap = None
if ar_ckpt and nar_ckpt:
@ -47,7 +49,7 @@ class TTS():
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype)
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
elif name.startswith("nar"):
self.nar = model
state = torch.load(self.nar_ckpt)
@ -56,7 +58,7 @@ class TTS():
if "module" in state:
state = state['module']
self.nar.load_state_dict(state)
self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype)
self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
else:
self.load_models()
@ -72,9 +74,9 @@ class TTS():
engines = load_engines()
for name, engine in engines.items():
if name[:2] == "ar":
self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype)
self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
elif name[:3] == "nar":
self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype)
self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
def encode_text( self, text, language="en" ):
# already a tensor, return it
@ -119,9 +121,10 @@ class TTS():
prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
wav, sr = qnt.decode_to_file(resps_list[0], out_path)

View File

@ -46,9 +46,15 @@ class AR(Base):
return cfg.models.tasks
@property
def resp_loss_only(self):
def resp_loss_only(self) -> bool:
return False
@property
def recurrent_chunk_size(self) -> int:
if cfg.mode == "training":
return 0
return cfg.inference.recurrent_chunk_size
def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero()
if len(indices) == 0:
@ -66,8 +72,6 @@ class AR(Base):
resps_list: list[Tensor] | None = None,
max_steps: int = 1000,
sampling_temperature: float = 1.0,
naive: bool = True,
):
if resps_list is not None:
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
@ -83,109 +87,35 @@ class AR(Base):
)
device = text_list[0].device
resps_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
]
stopped = torch.zeros(len(text_list), device=device).bool()
batch_size = len(text_list)
chunk_size = self.causal_chunk_size # don't really know what to do about this desu
resps_list: list[Tensor] = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
stopped = torch.zeros(batch_size, device=device).bool()
state = None
start = 0
state = {} if cfg.inference.recurrent_forward else None
if naive:
for n in trange(max_steps // max(1, chunk_size)):
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
# get next in sequence
r, state = super().forward(
text_list,
proms_list,
self._unsqueeze_list(resps_list),
sampling_temperature=sampling_temperature,
state=state # if not naive else None,
)
# append outputted token
if self.causal_chunk_size > 0:
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri])
else:
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
# to-do: make it work
# it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
else:
resps_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
]
test_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
]
batch_size = len(text_list)
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(self._unsqueeze_list(resps_list)),
sep=self.sep,
r = super().forward(
text_list,
proms_list,
self._unsqueeze_list(resps_list),
sampling_temperature=sampling_temperature,
state=state
)
x, m = list_to_tensor(x_list)
device = x.device
if state is None:
state = {}
# pre-fill KV cache
for n in trange(x.shape[1]):
xs = x[:, n:(n + 1), :]
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
r = self.classifier(r) * m
logits = torch.stack([hi[-1] for hi in r])
r = Categorical(logits=logits / sampling_temperature).sample()
for i, ri in enumerate(r):
test_list[i] = torch.cat([test_list[i], ri[None]])
# append outputted token
# append tokens
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])
if self.stop_token in ri:
stopped[i] = True
resps_list[i] = torch.cat([resps_list[i], ri])
start = x.shape[1]
for n in trange(max_steps // max(1, chunk_size)):
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(self._unsqueeze_list(resps_list)),
sep=self.sep,
)
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
x, m = list_to_tensor(x_list)
xs = x[:, start+n:start+(n+1), :]
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
r = self.classifier(r) * m
logits = torch.stack([hi[-1] for hi in r])
r = Categorical(logits=logits / sampling_temperature).sample()
# append outputted token
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
pruned = [self._prune(r) for r in resps_list]
return pruned
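
The rewritten AR.forward above reduces to this flow: when cfg.inference.recurrent_forward is enabled, an empty state dict is threaded through Base.forward, the prompt is first pushed through the RetNet one token at a time to pre-fill the recurrent state, and every later step feeds only the newest token while the state carries the rest. A toy, self-contained sketch of that control flow, with a stand-in step() function in place of the actual self.retnet(x, incremental_state=state, ...) call (the real state layout belongs to torchscale and is not reproduced here):

import torch
from torch.distributions import Categorical

vocab, dim = 8, 4
classifier = torch.randn(dim, vocab)  # fixed toy "classifier head"

def step(token_emb: torch.Tensor, state: dict) -> torch.Tensor:
    # toy stand-in for self.retnet(x, incremental_state=state, ...):
    # keep a running summary of everything seen so far, return logits
    state["summary"] = state.get("summary", torch.zeros_like(token_emb)) + token_emb
    return state["summary"] @ classifier

emb = torch.nn.Embedding(vocab, dim)
prompt = torch.tensor([1, 2, 3])
state: dict = {}

# pre-fill: run the prompt one token at a time so the recurrent state
# covers the whole prefix before any sampling happens
for t in prompt:
    logits = step(emb(t), state)

# decode: only the newest token is fed each step
generated = []
for _ in range(5):
    token = Categorical(logits=logits).sample()
    generated.append(int(token))
    logits = step(emb(token), state)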

View File

@ -126,6 +126,10 @@ class Base(nn.Module):
def resp_loss_only(self):
raise NotImplementedError
@property
def recurrent_chunk_size(self) -> int:
raise NotImplementedError
def __init__(
self,
n_tokens: int = 1024,
@ -150,8 +154,6 @@ class Base(nn.Module):
self.sep = nn.Parameter(torch.randn(d_model))
self.causal_chunk_size = 0 # 64 if self.causal else 1
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
self.blocks = nn.ModuleList([TransformerBlock(
@ -173,8 +175,8 @@ class Base(nn.Module):
dropout=p_dropout,
checkpoint_activations=True,
chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
recurrent_chunkwise_size=self.causal_chunk_size,
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
no_output_layer=True,
decoder_normalize_before=True,
))
@ -257,7 +259,7 @@ class Base(nn.Module):
return_all_resp: bool = False,
sampling_temperature: float = 1.0,
state: list | None = None,
state: dict | None = None,
):
"""
Args:
@ -285,6 +287,19 @@ class Base(nn.Module):
x, m = list_to_tensor(x_list)
device = x.device
if state is not None:
# prefill
prefill_size = x.shape[1]
# run the initial prompt to fill the KV cache
if len(state) == 0:
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
# grab last token(s)
x = x[:, -1, :].unsqueeze(1)
if self.arch_type == "transformer":
x = self.sin_emb.add_pe(x)
for block in self.blocks:
@ -292,7 +307,6 @@ class Base(nn.Module):
elif self.arch_type == "retnet":
# to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
state = self.retnet.get_incremental_state( state, 'prev_state' )
x = self.classifier(x) * m
@ -351,25 +365,21 @@ class Base(nn.Module):
acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
)
# return the entire generated token string
if return_all:
logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
ret = [Categorical(logits=hi / sampling_temperature).sample() for hi in logits]
# return the entire generated response
elif return_all_resp:
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
# return the last chunkwise piece
elif self.causal_chunk_size > 0:
logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
elif self.causal and self.recurrent_chunk_size > 0:
logits = [hi[-self.recurrent_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
# return just the last code
else:
logits = torch.stack([hi[-1] for hi in h_list])
ret = Categorical(logits=logits / sampling_temperature).sample()
logits = [ hi[-1:] for hi in h_list ]
return ret, state
return [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
def example_usage():
from ..config import cfg
@ -387,9 +397,6 @@ def example_usage():
from .ar import AR
from .nar import NAR
from ..models import get_models
from ..config import Model as ModelCfg
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
@ -398,20 +405,13 @@ def example_usage():
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to()
models = get_models({
"ar": Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, size=1),
"nar": Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, size=1)}
)
"""
model_cfg = ModelCfg()
kwargs = {
'n_tokens': model_cfg.tokens,
'd_model': model_cfg.dim,
'n_heads': model_cfg.heads,
'n_layers': model_cfg.layers,
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 24,
}
models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
"""
engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
train = True
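
One more note on the sampling paths at the end of Base.forward above: every branch now slices some tail of the per-position logits (everything, the whole response, the last recurrent chunk, or just the final position) and temperature-samples it with a Categorical. A small illustrative sketch of that slicing, with made-up tensor sizes:

import torch
from torch.distributions import Categorical

def sample_tail(position_logits: torch.Tensor, n_last: int, temperature: float = 1.0) -> torch.Tensor:
    # sample tokens for only the last n_last positions of a logits sequence,
    # mirroring the tail-slicing in Base.forward's return paths
    tail = position_logits[-n_last:]                        # (n_last, vocab)
    return Categorical(logits=tail / temperature).sample()  # (n_last,)

logits = torch.randn(12, 1024)       # 12 positions, vocab of 1024
print(sample_tail(logits, 1))        # "just the last code"
print(sample_tail(logits, 4, 0.95))  # e.g. the last recurrent chunk of 4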

View File

@ -44,9 +44,13 @@ class NAR(Base):
return cfg.models.tasks
@property
def resp_loss_only(self):
def resp_loss_only(self) -> bool:
return True
@property
def recurrent_chunk_size(self) -> int:
return 0
def forward(
self,
text_list: list[Tensor],
@ -104,7 +108,7 @@ class NAR(Base):
quant_levels = torch.full((len(text_list),), level, device=device)
resps_list, _ = super().forward(
resps_list = super().forward(
text_list,
proms_list,
prev_list,

View File

@ -8,6 +8,10 @@ from typing import Dict, Optional
from torch import Tensor
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
"""
class FairseqIncrementalState(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -24,7 +28,6 @@ class FairseqIncrementalState(object):
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
"""Helper for getting incremental state for an nn.Module."""
full_key = self._get_full_incremental_state_key(key)
if incremental_state is None or full_key not in incremental_state:
return None
@ -36,7 +39,6 @@ class FairseqIncrementalState(object):
key: str,
value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = self._get_full_incremental_state_key(key)
incremental_state[full_key] = value
@ -65,4 +67,5 @@ class RetNetDecoder(Decoder):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
incremental_state[module][key] = result
"""

View File

@ -23,16 +23,17 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
_logger = logging.getLogger(__name__)
def train_feeder(engine, batch):
engine(
text_list=batch["text"],
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
resps_list=batch["resps"]
)
with torch.autocast(engine.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
engine(
text_list=batch["text"],
proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
resps_list=batch["resps"]
)
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
loss = torch.stack([*losses.values()]).sum()
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}