somewhat got recurrent forward working (it's about as accurate as chunkwise forward, which is to say not accurate at all); added an option to use AMP instead of blanket-setting the weights' dtype
This commit is contained in:
parent 2bc2d08b09
commit e40c0d34a0
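In short: with the new cfg.trainer.amp / cfg.inference.amp flags enabled, the weights stay in float32 and the forward pass runs under torch.autocast, rather than blanket-casting the whole model to the configured dtype. A minimal sketch of the difference in generic PyTorch (illustrative only, not this repo's exact wiring):

    # contrast blanket dtype-casting with AMP autocast (assumes a CUDA device)
    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4).to("cuda")                 # weights stay float32 when AMP is used
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x, y = torch.randn(8, 16, device="cuda"), torch.randn(8, 4, device="cuda")

    # old behaviour: model = model.to(torch.float16)    # blanket-set the weights' dtype
    # new behaviour: keep float32 weights and autocast the forward/loss
    with torch.autocast("cuda", dtype=torch.float16, enabled=True):
        loss = nn.functional.mse_loss(model(x), y)
    loss.backward()                                     # gradients land on the float32 weights
    opt.step()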
@@ -165,7 +165,6 @@ class Model:
arch_type: str = "transformer"
training: bool = True

@property
def full_name(self):
name = [ self.name ]
@@ -332,9 +331,9 @@ class DeepSpeed:
"fp16": {
"enabled": True,
"auto_cast": True,
- } if cfg.trainer.weight_dtype.lower() == "float16" else None,
+ } if cfg.trainer.weight_dtype.lower() == "float16" and not cfg.trainer.amp else None,
"bf16": {
- "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
+ "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" and not cfg.trainer.amp
},
"compression_training": {
"weight_quantization": {
@@ -427,15 +426,13 @@ class Trainer:

aggressive_optimizations: bool = False
check_for_oom: bool = True

gc_mode: str | None = None
load_disabled_engines: bool = False

gc_mode: str | None = None

weight_dtype: str = "float16"
+ amp: bool = False

- backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"

+ backend: str = "local"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)

@cached_property
@@ -450,10 +447,14 @@ class Trainer:
@dataclass()
class Inference:
weight_dtype: str = "float32"
+ amp: bool = False

normalize: bool = False # do NOT enable this unless you know exactly what you're doing
use_vocos: bool = True

+ recurrent_chunk_size: int = 0
+ recurrent_forward: bool = False

@cached_property
def dtype(self):
if self.weight_dtype == "float16":
@@ -473,6 +474,7 @@ class BitsAndBytes:
@dataclass()
class Config(_Config):
device: str = "cuda"
mode: str = "training" # "inferencing"

dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models)
@@ -55,7 +55,7 @@ class Engine():
self._cfg = kwargs['_cfg']
kwargs.pop("_cfg")

- self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
+ self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None

@@ -196,9 +196,11 @@ class Engine():
return 0.0

def traverse(self, *args, **kwargs):
- self.forward(*args, **kwargs)
- losses = self.gather_attribute("loss")
- loss = torch.stack([*losses.values()]).sum()
+ with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+ self.forward(*args, **kwargs)
+ losses = self.gather_attribute("loss")
+ print(self.module.loss)
+ loss = torch.stack([*losses.values()]).sum()

stats = {}
stats |= {k: v.item() for k, v in losses.items()}
@@ -87,9 +87,10 @@ class Engine(DeepSpeedEngine):
print(str(e))

def traverse(self, *args, **kwargs):
- self.forward(*args, **kwargs)
- losses = self.gather_attribute("loss")
- loss = torch.stack([*losses.values()]).sum()
+ with torch.autocast(self.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+ self.forward(*args, **kwargs)
+ losses = self.gather_attribute("loss")
+ loss = torch.stack([*losses.values()]).sum()

stats = {}
stats |= {k: v.item() for k, v in losses.items()}
@@ -31,6 +31,8 @@ class TTS():
cfg.format()
except Exception as e:
pass

+ cfg.mode = "inferencing"

self.symmap = None
if ar_ckpt and nar_ckpt:
@@ -47,7 +49,7 @@ class TTS():
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
- self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype)
+ self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
elif name.startswith("nar"):
self.nar = model
state = torch.load(self.nar_ckpt)
@@ -56,7 +58,7 @@ class TTS():
if "module" in state:
state = state['module']
self.nar.load_state_dict(state)
- self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype)
+ self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
else:
self.load_models()

@@ -72,9 +74,9 @@ class TTS():
engines = load_engines()
for name, engine in engines.items():
if name[:2] == "ar":
- self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype)
+ self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
elif name[:3] == "nar":
- self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype)
+ self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)

def encode_text( self, text, language="en" ):
# already a tensor, return it
@@ -119,9 +121,10 @@ class TTS():
prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

- resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
- resps_list = [r.unsqueeze(-1) for r in resps_list]
- resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
+ with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+ resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
+ resps_list = [r.unsqueeze(-1) for r in resps_list]
+ resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)

wav, sr = qnt.decode_to_file(resps_list[0], out_path)

@@ -46,9 +46,15 @@ class AR(Base):
return cfg.models.tasks

@property
- def resp_loss_only(self):
+ def resp_loss_only(self) -> bool:
return False

+ @property
+ def recurrent_chunk_size(self) -> int:
+ if cfg.mode == "training":
+ return 0
+ return cfg.inference.recurrent_chunk_size

def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero()
if len(indices) == 0:
@@ -66,8 +72,6 @@ class AR(Base):
resps_list: list[Tensor] | None = None,
max_steps: int = 1000,
sampling_temperature: float = 1.0,

- naive: bool = True,
):
if resps_list is not None:
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
@@ -83,109 +87,35 @@ class AR(Base):
)

device = text_list[0].device
resps_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
]
stopped = torch.zeros(len(text_list), device=device).bool()
batch_size = len(text_list)

chunk_size = self.causal_chunk_size # don't really know what to do about this desu
resps_list: list[Tensor] = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
stopped = torch.zeros(batch_size, device=device).bool()

state = None
start = 0
state = {} if cfg.inference.recurrent_forward else None

if naive:
for n in trange(max_steps // max(1, chunk_size)):
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
# get next in sequence

r, state = super().forward(
text_list,
proms_list,
self._unsqueeze_list(resps_list),
sampling_temperature=sampling_temperature,
state=state # if not naive else None,
)

# append outputted token
if self.causal_chunk_size > 0:
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri])
else:
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])

# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
# to-do: make it work
# it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
else:
resps_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
]

test_list: list[Tensor] = [
torch.zeros(0, device=device).to(torch.int16) for _ in text_list
]

batch_size = len(text_list)

x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(self._unsqueeze_list(resps_list)),
sep=self.sep,
r = super().forward(
text_list,
proms_list,
self._unsqueeze_list(resps_list),
sampling_temperature=sampling_temperature,
state=state
)

x, m = list_to_tensor(x_list)
device = x.device

if state is None:
state = {}

# pre-fill KV cache
for n in trange(x.shape[1]):
xs = x[:, n:(n + 1), :]
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
r = self.classifier(r) * m

logits = torch.stack([hi[-1] for hi in r])
r = Categorical(logits=logits / sampling_temperature).sample()

for i, ri in enumerate(r):
test_list[i] = torch.cat([test_list[i], ri[None]])

# append outputted token
# append tokens
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])
if self.stop_token in ri:
stopped[i] = True
resps_list[i] = torch.cat([resps_list[i], ri])

start = x.shape[1]
for n in trange(max_steps // max(1, chunk_size)):
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(self._unsqueeze_list(resps_list)),
sep=self.sep,
)
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break

x, m = list_to_tensor(x_list)

xs = x[:, start+n:start+(n+1), :]
r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
r = self.classifier(r) * m

logits = torch.stack([hi[-1] for hi in r])
r = Categorical(logits=logits / sampling_temperature).sample()

# append outputted token
for i, ri in enumerate(r):
resps_list[i] = torch.cat([resps_list[i], ri[None]])

# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break

pruned = [self._prune(r) for r in resps_list]
return pruned
@@ -126,6 +126,10 @@ class Base(nn.Module):
def resp_loss_only(self):
raise NotImplementedError

+ @property
+ def recurrent_chunk_size(self) -> int:
+ raise NotImplementedError

def __init__(
self,
n_tokens: int = 1024,
@@ -150,8 +154,6 @@ class Base(nn.Module):

self.sep = nn.Parameter(torch.randn(d_model))

- self.causal_chunk_size = 0 # 64 if self.causal else 1

if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
self.blocks = nn.ModuleList([TransformerBlock(
@@ -173,8 +175,8 @@ class Base(nn.Module):
dropout=p_dropout,
checkpoint_activations=True,

- chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
- recurrent_chunkwise_size=self.causal_chunk_size,
+ chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
+ recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
no_output_layer=True,
decoder_normalize_before=True,
))
@@ -257,7 +259,7 @@ class Base(nn.Module):
return_all_resp: bool = False,
sampling_temperature: float = 1.0,

- state: list | None = None,
+ state: dict | None = None,
):
"""
Args:
@@ -285,6 +287,19 @@ class Base(nn.Module):
x, m = list_to_tensor(x_list)
device = x.device

+ if state is not None:
+ # prefill
+ prefill_size = x.shape[1]

+ # run the initial prompt to fill the KV cache
+ if len(state) == 0:
+ for n in range(prefill_size):
+ xi = x[:, n, :].unsqueeze(1)
+ self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)

+ # grab last token(s)
+ x = x[:, -1, :].unsqueeze(1)

if self.arch_type == "transformer":
x = self.sin_emb.add_pe(x)
for block in self.blocks:
@@ -292,7 +307,6 @@
elif self.arch_type == "retnet":
- # to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
state = self.retnet.get_incremental_state( state, 'prev_state' )

x = self.classifier(x) * m

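For context, the hunks above are the recurrent-forward path: when cfg.inference.recurrent_forward is set, an incremental-state dict is threaded through the RetNet, the prompt is first fed one position at a time to build that state ("prefill"), and afterwards only the last position is run per step. A toy sketch of the prefill-then-step pattern, with a hypothetical step() standing in for the RetNet call:

    # prefill-then-step recurrent decoding; step() is a made-up stand-in, not torchscale's API
    import torch

    d_model, n_tokens = 64, 1024
    emb = torch.nn.Embedding(n_tokens, d_model)
    head = torch.nn.Linear(d_model, n_tokens)

    def step(x_t, state):
        # fold one position's embedding into a running state, return logits for that position
        state = x_t if state is None else 0.5 * state + 0.5 * x_t
        return head(state), state

    prompt = torch.randint(0, n_tokens, (1, 12))
    logits, state = None, None
    for n in range(prompt.shape[1]):        # prefill: run the prompt one token at a time
        logits, state = step(emb(prompt[:, n:n+1]), state)

    resps = []
    for _ in range(32):                     # then decode a single token per step
        tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        resps.append(tok)
        logits, state = step(emb(tok), state)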
@@ -351,25 +365,21 @@ class Base(nn.Module):
acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
)

# return the entire generated token string
if return_all:
logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
- ret = [Categorical(logits=hi / sampling_temperature).sample() for hi in logits]
# return the entire generated response
elif return_all_resp:
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
- ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
# return the last chunkwise piece
- elif self.causal_chunk_size > 0:
- logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
- ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
+ elif self.causal and self.recurrent_chunk_size > 0:
+ logits = [hi[-self.recurrent_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
# return just the last code
else:
- logits = torch.stack([hi[-1] for hi in h_list])
- ret = Categorical(logits=logits / sampling_temperature).sample()
+ logits = [ hi[-1:] for hi in h_list ]

- return ret, state
+ return [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]

def example_usage():
from ..config import cfg
@@ -387,9 +397,6 @@ def example_usage():
from .ar import AR
from .nar import NAR

- from ..models import get_models
- from ..config import Model as ModelCfg

device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=2)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
@@ -398,20 +405,13 @@ def example_usage():
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to()

- models = get_models({
- "ar": Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, size=1),
- "nar": Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, size=1)}
- )
- """
- model_cfg = ModelCfg()
kwargs = {
- 'n_tokens': model_cfg.tokens,
- 'd_model': model_cfg.dim,
- 'n_heads': model_cfg.heads,
- 'n_layers': model_cfg.layers,
+ 'n_tokens': 1024,
+ 'd_model': 1024,
+ 'n_heads': 16,
+ 'n_layers': 24,
}
models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
- """
engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })

train = True
@@ -44,9 +44,13 @@ class NAR(Base):
return cfg.models.tasks

@property
- def resp_loss_only(self):
+ def resp_loss_only(self) -> bool:
return True

+ @property
+ def recurrent_chunk_size(self) -> int:
+ return 0

def forward(
self,
text_list: list[Tensor],
@@ -104,7 +108,7 @@ class NAR(Base):

quant_levels = torch.full((len(text_list),), level, device=device)

- resps_list, _ = super().forward(
+ resps_list = super().forward(
text_list,
proms_list,
prev_list,
@@ -8,6 +8,10 @@ from typing import Dict, Optional

from torch import Tensor

+ from torchscale.architecture.config import RetNetConfig
+ from torchscale.architecture.retnet import RetNetDecoder

"""
class FairseqIncrementalState(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -24,7 +28,6 @@ class FairseqIncrementalState(object):
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
"""Helper for getting incremental state for an nn.Module."""
full_key = self._get_full_incremental_state_key(key)
if incremental_state is None or full_key not in incremental_state:
return None
@@ -36,7 +39,6 @@ class FairseqIncrementalState(object):
key: str,
value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = self._get_full_incremental_state_key(key)
incremental_state[full_key] = value
@@ -65,4 +67,5 @@ class RetNetDecoder(Decoder):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
incremental_state[module][key] = result
"""
@@ -23,16 +23,17 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
_logger = logging.getLogger(__name__)

def train_feeder(engine, batch):
- engine(
- text_list=batch["text"],
- proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
- resps_list=batch["resps"]
- )
+ with torch.autocast(engine.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+ engine(
+ text_list=batch["text"],
+ proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
+ resps_list=batch["resps"]
+ )

- losses = engine.gather_attribute("loss")
- stat = engine.gather_attribute("stats")
+ losses = engine.gather_attribute("loss")
+ stat = engine.gather_attribute("stats")

- loss = torch.stack([*losses.values()]).sum()
+ loss = torch.stack([*losses.values()]).sum()

stats = {}
stats |= {k: v.item() for k, v in losses.items()}
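One caveat on the float16 path (a general note about PyTorch AMP, not something this commit adds): without loss scaling, small fp16 gradients can underflow, so a plain-PyTorch AMP step would usually pair the autocast above with a GradScaler. A generic sketch:

    # generic AMP training step with gradient scaling; illustrative, not this repo's wiring
    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4).to("cuda")
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    x, y = torch.randn(8, 16, device="cuda"), torch.randn(8, 4, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):
        loss = nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()   # scale the loss so fp16 gradients don't underflow
    scaler.step(opt)                # unscales grads and skips the step if any are inf/NaN
    scaler.update()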