somewhat got recurrent forward working (it's about as accurate as chunkwise forward, which is to say not accurate at all); added an option to use AMP instead of blanket-setting the weights' dtype
commit e40c0d34a0
parent 2bc2d08b09
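For context on the AMP option: rather than blanket-casting the model weights to float16/bfloat16, the trainer and inference paths below keep the weights in float32 and wrap the forward pass in `torch.autocast`. A minimal sketch of that pattern in plain PyTorch (the `model`, the dummy tensors, and the `GradScaler` usage are illustrative only, not code from this repo; the repo's engines handle backward/step themselves):

```python
import torch

model = torch.nn.Linear(16, 4).to("cuda")            # weights stay float32
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()                  # loss scaling for fp16 gradients

x = torch.randn(8, 16, device="cuda")
y = torch.randn(8, 4, device="cuda")

# matmul-heavy ops run in fp16 inside the context; reductions stay in fp32
with torch.autocast("cuda", dtype=torch.float16, enabled=True):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.zero_grad()
```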
@@ -165,7 +165,6 @@ class Model:
     arch_type: str = "transformer"
     training: bool = True
 
-
     @property
     def full_name(self):
         name = [ self.name ]
@@ -332,9 +331,9 @@ class DeepSpeed:
             "fp16": {
                 "enabled": True,
                 "auto_cast": True,
-            } if cfg.trainer.weight_dtype.lower() == "float16" else None,
+            } if cfg.trainer.weight_dtype.lower() == "float16" and not cfg.trainer.amp else None,
             "bf16": {
-                "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16"
+                "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16" and not cfg.trainer.amp
             },
             "compression_training": {
                 "weight_quantization": {
@@ -427,15 +426,13 @@ class Trainer:
 
     aggressive_optimizations: bool = False
     check_for_oom: bool = True
+    gc_mode: str | None = None
     load_disabled_engines: bool = False
 
-    gc_mode: str | None = None
-
     weight_dtype: str = "float16"
+    amp: bool = False
 
-    backend: str = "local" # "deepspeed" if not sys.platform.startswith("win") else "local"
+    backend: str = "local"
 
     deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
 
     @cached_property
@@ -450,10 +447,14 @@ class Trainer:
 @dataclass()
 class Inference:
     weight_dtype: str = "float32"
+    amp: bool = False
 
     normalize: bool = False # do NOT enable this unless you know exactly what you're doing
     use_vocos: bool = True
 
+    recurrent_chunk_size: int = 0
+    recurrent_forward: bool = False
+
     @cached_property
     def dtype(self):
         if self.weight_dtype == "float16":
@@ -473,6 +474,7 @@ class BitsAndBytes:
 @dataclass()
 class Config(_Config):
     device: str = "cuda"
+    mode: str = "training" # "inferencing"
 
     dataset: Dataset = field(default_factory=lambda: Dataset)
     models: Models = field(default_factory=lambda: Models)
@@ -55,7 +55,7 @@ class Engine():
         self._cfg = kwargs['_cfg']
         kwargs.pop("_cfg")
 
-        self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
+        self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
         self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
         self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
 
@@ -196,9 +196,11 @@ class Engine():
         return 0.0
 
     def traverse(self, *args, **kwargs):
-        self.forward(*args, **kwargs)
-        losses = self.gather_attribute("loss")
-        loss = torch.stack([*losses.values()]).sum()
+        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+            self.forward(*args, **kwargs)
+            losses = self.gather_attribute("loss")
+            print(self.module.loss)
+            loss = torch.stack([*losses.values()]).sum()
 
         stats = {}
         stats |= {k: v.item() for k, v in losses.items()}
@@ -87,9 +87,10 @@ class Engine(DeepSpeedEngine):
             print(str(e))
 
     def traverse(self, *args, **kwargs):
-        self.forward(*args, **kwargs)
-        losses = self.gather_attribute("loss")
-        loss = torch.stack([*losses.values()]).sum()
+        with torch.autocast(self.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+            self.forward(*args, **kwargs)
+            losses = self.gather_attribute("loss")
+            loss = torch.stack([*losses.values()]).sum()
 
         stats = {}
         stats |= {k: v.item() for k, v in losses.items()}
@@ -32,6 +32,8 @@ class TTS():
         except Exception as e:
             pass
 
+        cfg.mode = "inferencing"
+
         self.symmap = None
         if ar_ckpt and nar_ckpt:
             self.ar_ckpt = ar_ckpt
@@ -47,7 +49,7 @@ class TTS():
                 if "module" in state:
                     state = state['module']
                 self.ar.load_state_dict(state)
-                self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype)
+                self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
             elif name.startswith("nar"):
                 self.nar = model
                 state = torch.load(self.nar_ckpt)
@@ -56,7 +58,7 @@ class TTS():
                 if "module" in state:
                     state = state['module']
                 self.nar.load_state_dict(state)
-                self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype)
+                self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
         else:
             self.load_models()
 
@@ -72,9 +74,9 @@ class TTS():
         engines = load_engines()
         for name, engine in engines.items():
             if name[:2] == "ar":
-                self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype)
+                self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
             elif name[:3] == "nar":
-                self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype)
+                self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
 
     def encode_text( self, text, language="en" ):
         # already a tensor, return it
@@ -119,9 +121,10 @@ class TTS():
         prom = to_device(prom, self.device).to(torch.int16)
         phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
 
-        resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
-        resps_list = [r.unsqueeze(-1) for r in resps_list]
-        resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
+        with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+            resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
+            resps_list = [r.unsqueeze(-1) for r in resps_list]
+            resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
 
         wav, sr = qnt.decode_to_file(resps_list[0], out_path)
 
@@ -46,9 +46,15 @@ class AR(Base):
         return cfg.models.tasks
 
     @property
-    def resp_loss_only(self):
+    def resp_loss_only(self) -> bool:
         return False
 
+    @property
+    def recurrent_chunk_size(self) -> int:
+        if cfg.mode == "training":
+            return 0
+        return cfg.inference.recurrent_chunk_size
+
     def _prune(self, l: Tensor):
         indices = (l == self.stop_token).nonzero()
         if len(indices) == 0:
@@ -66,8 +72,6 @@ class AR(Base):
         resps_list: list[Tensor] | None = None,
         max_steps: int = 1000,
         sampling_temperature: float = 1.0,
-
-        naive: bool = True,
     ):
         if resps_list is not None:
             resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
@@ -83,109 +87,35 @@ class AR(Base):
         )
 
         device = text_list[0].device
-        resps_list: list[Tensor] = [
-            torch.zeros(0, device=device).to(torch.int16) for _ in text_list
-        ]
-        stopped = torch.zeros(len(text_list), device=device).bool()
+        batch_size = len(text_list)
 
-        chunk_size = self.causal_chunk_size # don't really know what to do about this desu
+        resps_list: list[Tensor] = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
+        stopped = torch.zeros(batch_size, device=device).bool()
 
-        state = None
-        start = 0
+        state = {} if cfg.inference.recurrent_forward else None
 
-        if naive:
-            for n in trange(max_steps // max(1, chunk_size)):
-                # get next in sequence
+        for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
+            # get next in sequence
 
-                r, state = super().forward(
-                    text_list,
-                    proms_list,
-                    self._unsqueeze_list(resps_list),
-                    sampling_temperature=sampling_temperature,
-                    state=state # if not naive else None,
-                )
+            r = super().forward(
+                text_list,
+                proms_list,
+                self._unsqueeze_list(resps_list),
+                sampling_temperature=sampling_temperature,
+                state=state
+            )
 
-                # append outputted token
-                if self.causal_chunk_size > 0:
-                    for i, ri in enumerate(r):
-                        resps_list[i] = torch.cat([resps_list[i], ri])
-                else:
-                    for i, ri in enumerate(r):
-                        resps_list[i] = torch.cat([resps_list[i], ri[None]])
+            # append tokens
+            for i, ri in enumerate(r):
+                if self.stop_token in ri:
+                    stopped[i] = True
+                resps_list[i] = torch.cat([resps_list[i], ri])
 
-                # stop token found
-                stopped |= r == self.stop_token
-                if stopped.all().item():
-                    break
-        # to-do: make it work
-        # it seems anything that isn't a one-at-a-time sequence does not work, despite generating STOP tokens.
-        else:
-            resps_list: list[Tensor] = [
-                torch.zeros(0, device=device).to(torch.int16) for _ in text_list
-            ]
-
-            test_list: list[Tensor] = [
-                torch.zeros(0, device=device).to(torch.int16) for _ in text_list
-            ]
-
-            batch_size = len(text_list)
-
-            x_list = self._samplewise_merge_tensors(
-                self.text_emb(text_list),
-                self.proms_emb(proms_list),
-                self.resps_emb(self._unsqueeze_list(resps_list)),
-                sep=self.sep,
-            )
-
-            x, m = list_to_tensor(x_list)
-            device = x.device
-
-            if state is None:
-                state = {}
-
-            # pre-fill KV cache
-            for n in trange(x.shape[1]):
-                xs = x[:, n:(n + 1), :]
-                r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
-                r = self.classifier(r) * m
-
-                logits = torch.stack([hi[-1] for hi in r])
-                r = Categorical(logits=logits / sampling_temperature).sample()
-
-                for i, ri in enumerate(r):
-                    test_list[i] = torch.cat([test_list[i], ri[None]])
-
-                # append outputted token
-                for i, ri in enumerate(r):
-                    resps_list[i] = torch.cat([resps_list[i], ri[None]])
-
-            start = x.shape[1]
-            for n in trange(max_steps // max(1, chunk_size)):
-                x_list = self._samplewise_merge_tensors(
-                    self.text_emb(text_list),
-                    self.proms_emb(proms_list),
-                    self.resps_emb(self._unsqueeze_list(resps_list)),
-                    sep=self.sep,
-                )
-
-                x, m = list_to_tensor(x_list)
-
-                xs = x[:, start+n:start+(n+1), :]
-                r, _ = self.retnet(xs, incremental_state=state, token_embeddings=xs, features_only=True)
-                r = self.classifier(r) * m
-
-                logits = torch.stack([hi[-1] for hi in r])
-                r = Categorical(logits=logits / sampling_temperature).sample()
-
-                # append outputted token
-                for i, ri in enumerate(r):
-                    resps_list[i] = torch.cat([resps_list[i], ri[None]])
-
-                # stop token found
-                stopped |= r == self.stop_token
-                if stopped.all().item():
-                    break
+            # stop token found
+            stopped |= r == self.stop_token
+            if stopped.all().item():
+                break
 
         pruned = [self._prune(r) for r in resps_list]
         return pruned
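The recurrent-forward path above amounts to incremental decoding: prefill the RetNet's incremental state by pushing the prompt through one position at a time, then generate one token per step against that cached state. A rough sketch of the control flow, with hypothetical `step` and `embed` callables standing in for the RetNet call and the response embedding (this is not the repo's actual API):

```python
import torch
from torch.distributions import Categorical

def decode_recurrently(step, embed, prompt_emb, max_steps, stop_token, temperature=1.0):
    # step(x, state) -> (logits, state): consumes one position, updates the recurrent cache
    # embed(tok)     -> embedding of the sampled token, used as the next input position
    state = {}                                    # incremental state, the KV-cache analogue
    # prefill: feed the prompt one position at a time so the cache covers it
    for n in range(prompt_emb.shape[1]):
        logits, state = step(prompt_emb[:, n:n + 1, :], state)

    tokens = []
    for _ in range(max_steps):
        tok = Categorical(logits=logits[:, -1] / temperature).sample()  # next token per batch item
        tokens.append(tok)
        if (tok == stop_token).all():
            break
        logits, state = step(embed(tok).unsqueeze(1), state)            # single-step decode
    return tokens
```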
@@ -126,6 +126,10 @@ class Base(nn.Module):
     def resp_loss_only(self):
         raise NotImplementedError
 
+    @property
+    def recurrent_chunk_size(self) -> int:
+        raise NotImplementedError
+
     def __init__(
         self,
         n_tokens: int = 1024,
@@ -150,8 +154,6 @@ class Base(nn.Module):
 
         self.sep = nn.Parameter(torch.randn(d_model))
 
-        self.causal_chunk_size = 0 # 64 if self.causal else 1
-
         if self.arch_type == "transformer":
             self.sin_emb = SinusoidalEmbedding(d_model)
             self.blocks = nn.ModuleList([TransformerBlock(
@@ -173,8 +175,8 @@ class Base(nn.Module):
                 dropout=p_dropout,
                 checkpoint_activations=True,
 
-                chunkwise_recurrent=self.causal and self.causal_chunk_size > 0,
-                recurrent_chunkwise_size=self.causal_chunk_size,
+                chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
+                recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
                 no_output_layer=True,
                 decoder_normalize_before=True,
             ))
@@ -257,7 +259,7 @@ class Base(nn.Module):
         return_all_resp: bool = False,
         sampling_temperature: float = 1.0,
 
-        state: list | None = None,
+        state: dict | None = None,
     ):
         """
         Args:
@@ -285,6 +287,19 @@ class Base(nn.Module):
         x, m = list_to_tensor(x_list)
         device = x.device
 
+        if state is not None:
+            # prefill
+            prefill_size = x.shape[1]
+
+            # run the initial prompt to fill the KV cache
+            if len(state) == 0:
+                for n in range(prefill_size):
+                    xi = x[:, n, :].unsqueeze(1)
+                    self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
+
+            # grab last token(s)
+            x = x[:, -1, :].unsqueeze(1)
+
         if self.arch_type == "transformer":
             x = self.sin_emb.add_pe(x)
             for block in self.blocks:
@@ -292,7 +307,6 @@ class Base(nn.Module):
         elif self.arch_type == "retnet":
             # to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
             x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
-            state = self.retnet.get_incremental_state( state, 'prev_state' )
 
         x = self.classifier(x) * m
 
@@ -355,21 +369,17 @@ class Base(nn.Module):
         # return the entire generated token string
         if return_all:
             logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
-            ret = [Categorical(logits=hi / sampling_temperature).sample() for hi in logits]
         # return the entire generated response
         elif return_all_resp:
             logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
-            ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
         # return the last chunkwise piece
-        elif self.causal_chunk_size > 0:
-            logits = [hi[-self.causal_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
-            ret = [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
+        elif self.causal and self.recurrent_chunk_size > 0:
+            logits = [hi[-self.recurrent_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
         # return just the last code
         else:
-            logits = torch.stack([hi[-1] for hi in h_list])
-            ret = Categorical(logits=logits / sampling_temperature).sample()
+            logits = [ hi[-1:] for hi in h_list ]
 
-        return ret, state
+        return [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
 
 def example_usage():
     from ..config import cfg
@@ -387,9 +397,6 @@ def example_usage():
     from .ar import AR
     from .nar import NAR
 
-    from ..models import get_models
-    from ..config import Model as ModelCfg
-
     device = "cuda"
     x8 = partial(repeat, pattern="t -> t l", l=2)
     symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
@@ -398,20 +405,13 @@ def example_usage():
         phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
         return torch.tensor([*map(symmap.get, phones)]).to()
 
-    models = get_models({
-        "ar": Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, size=1),
-        "nar": Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, size=1)}
-    )
-    """
-    model_cfg = ModelCfg()
     kwargs = {
-        'n_tokens': model_cfg.tokens,
-        'd_model': model_cfg.dim,
-        'n_heads': model_cfg.heads,
-        'n_layers': model_cfg.layers,
+        'n_tokens': 1024,
+        'd_model': 1024,
+        'n_heads': 16,
+        'n_layers': 24,
     }
     models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }
-    """
     engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
 
     train = True
@@ -44,9 +44,13 @@ class NAR(Base):
         return cfg.models.tasks
 
     @property
-    def resp_loss_only(self):
+    def resp_loss_only(self) -> bool:
         return True
 
+    @property
+    def recurrent_chunk_size(self) -> int:
+        return 0
+
     def forward(
         self,
         text_list: list[Tensor],
@@ -104,7 +108,7 @@ class NAR(Base):
 
             quant_levels = torch.full((len(text_list),), level, device=device)
 
-            resps_list, _ = super().forward(
+            resps_list = super().forward(
                 text_list,
                 proms_list,
                 prev_list,
@@ -8,6 +8,10 @@ from typing import Dict, Optional
 
 from torch import Tensor
 
+from torchscale.architecture.config import RetNetConfig
+from torchscale.architecture.retnet import RetNetDecoder
+
+"""
 class FairseqIncrementalState(object):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -24,7 +28,6 @@ class FairseqIncrementalState(object):
         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         key: str,
     ) -> Optional[Dict[str, Optional[Tensor]]]:
-        """Helper for getting incremental state for an nn.Module."""
         full_key = self._get_full_incremental_state_key(key)
         if incremental_state is None or full_key not in incremental_state:
             return None
@@ -36,7 +39,6 @@ class FairseqIncrementalState(object):
         key: str,
         value: Dict[str, Optional[Tensor]],
     ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
-        """Helper for setting incremental state for an nn.Module."""
         if incremental_state is not None:
             full_key = self._get_full_incremental_state_key(key)
             incremental_state[full_key] = value
@@ -66,3 +68,4 @@ class RetNetDecoder(Decoder):
         for key in incremental_state[module]:
             result = incremental_state[module][key].index_select(0, new_order)
             incremental_state[module][key] = result
+"""
@@ -23,16 +23,17 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
 _logger = logging.getLogger(__name__)
 
 def train_feeder(engine, batch):
-    engine(
-        text_list=batch["text"],
-        proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
-        resps_list=batch["resps"]
-    )
+    with torch.autocast(engine.device, dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+        engine(
+            text_list=batch["text"],
+            proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
+            resps_list=batch["resps"]
+        )
 
     losses = engine.gather_attribute("loss")
     stat = engine.gather_attribute("stats")
 
     loss = torch.stack([*losses.values()]).sum()
 
     stats = {}
     stats |= {k: v.item() for k, v in losses.items()}