this embedding class definitely works, and migrating from the previous embedding weights seems to work.

This commit is contained in:
mrq 2023-09-11 14:13:42 -05:00
parent a1f250ffac
commit 40ef34e1ca
6 changed files with 64 additions and 56 deletions

View File

@ -160,11 +160,12 @@ class Model:
resp_levels: int = 1
prom_levels: int = 8
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
arch_type: str = "transformer"
arch_type: str = "retnet"
training: bool = True
interleave: bool = False
frozen_params: list[str] = field(default_factory=lambda: [])
p_ar_nar: float = 0.5
version: int = 1
@property
def full_name(self):

View File

@ -279,7 +279,7 @@ class Dataset(_Dataset):
# shuffle it up a bit
prom_length = 0
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices)

View File

@ -57,6 +57,12 @@ class AR(Base):
def monolithic(self) -> bool:
return False
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.ar.version
def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero()
if len(indices) == 0:

View File

@ -54,6 +54,12 @@ class AR_NAR(Base):
def monolithic(self) -> bool:
return True
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.ar_nar.version
def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero()
if len(indices) == 0:
@ -208,7 +214,7 @@ def example_usage():
]
proms_list = [
#x8(torch.tensor([1, 2, 3], device=device)),
qnt.to(device),
qnt[:75*3, :].to(device),
]
resps_list = [
qnt.to(device),
@ -233,11 +239,15 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
#steps = 500
#optimizer = ml.Prodigy(model.parameters(), lr=1.0)
steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
engine = Engine(model=model, optimizer=optimizer)
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
print([ name for name, _ in model.named_parameters()])
@torch.inference_mode()
def sample( name, steps=600 ):

View File

@ -151,64 +151,45 @@ class MultiEmbedding(nn.Embedding):
else:
w = self.weight
padded_x_list = []
padded_x_list = []
for i, xi in enumerate(x_list):
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
padded_x_list.append(xi.to(w))
for i, xi in enumerate(x_list):
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
padded_x_list.append(xi.to(w))
x = torch.cat(padded_x_list) # n l k
x = einsum("l k d, n l k -> n d", w, x)
x = torch.cat(padded_x_list) # n l k
x = einsum("l k d, n l k -> n d", w, x)
x_list = x.split([*map(len, x_list)])
x_list = x.split([*map(len, x_list)])
return x_list
"""
w_ar, w_nar = self.weight[:1], self.weight[1:]
p_ar_list, p_nar_list = [], []
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
res_list = []
for i, xi in enumerate(x_list):
if quant_levels is None or quant_levels[i] == 0:
w padded_x_list, = w_ar, p_ar_list
# prom
if quant_levels is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_levels is None or quant_levels[i] == 0:
x = self.embeddings[0]( xi[:, 0] )
# NAR resp
else:
w, padded_x_list = w_nar, p_nar_list
x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
res_list.append(x)
# pad resp/prom tensor to fit weight
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
padded_x_list.append(xi.to(w))
# batch list => batch tensor
x_ar_list = einsum("l k d, n l k -> n d", w_ar, torch.cat(p_ar_list)) if len(p_ar_list) > 0 else []
x_nar_list = einsum("l k d, n l k -> n d", w_nar, torch.cat(p_nar_list)) if len(p_nar_list) > 0 else []
x_list = x.split([*map(len, x_list)])
"""
"""
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class PromEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
# Embedding that selects which embedding based on a quant_level tensor for a given batch
class RespEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
"""
return res_list
class Base(nn.Module):
@property
@ -252,8 +233,8 @@ class Base(nn.Module):
return False
@property
def n_embeddings(self) -> int:
return 2 if self.monolithic else 1
def version(self) -> int:
return 1
@property
def stop_token(self):
@ -298,8 +279,12 @@ class Base(nn.Module):
self.text_emb = Embedding(n_tokens, d_model)
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
if self.version == 1: # legacy
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
self.sep = nn.Parameter(torch.randn(d_model))

View File

@ -39,6 +39,12 @@ class NAR(Base):
def n_tasks(self) -> int:
return cfg.models.tasks
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.nar.version
@property
def recurrent_chunk_size(self) -> int:
return 0