this embedding class definitely works, and migrating from the previous embedding weights seems to work.

This commit is contained in:
mrq 2023-09-11 14:13:42 -05:00
parent a1f250ffac
commit 40ef34e1ca
6 changed files with 64 additions and 56 deletions

View File

@ -160,11 +160,12 @@ class Model:
resp_levels: int = 1 resp_levels: int = 1
prom_levels: int = 8 prom_levels: int = 8
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
arch_type: str = "transformer" arch_type: str = "retnet"
training: bool = True training: bool = True
interleave: bool = False interleave: bool = False
frozen_params: list[str] = field(default_factory=lambda: []) frozen_params: list[str] = field(default_factory=lambda: [])
p_ar_nar: float = 0.5 p_ar_nar: float = 0.5
version: int = 1
@property @property
def full_name(self): def full_name(self):

View File

@ -279,7 +279,7 @@ class Dataset(_Dataset):
# shuffle it up a bit # shuffle it up a bit
prom_length = 0 prom_length = 0
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16) trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
for _ in range(cfg.dataset.max_prompts): for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices) path = random.choice(choices)

View File

@ -57,6 +57,12 @@ class AR(Base):
def monolithic(self) -> bool: def monolithic(self) -> bool:
return False return False
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.ar.version
def _prune(self, l: Tensor): def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero() indices = (l == self.stop_token).nonzero()
if len(indices) == 0: if len(indices) == 0:

View File

@ -54,6 +54,12 @@ class AR_NAR(Base):
def monolithic(self) -> bool: def monolithic(self) -> bool:
return True return True
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.ar_nar.version
def _prune(self, l: Tensor): def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero() indices = (l == self.stop_token).nonzero()
if len(indices) == 0: if len(indices) == 0:
@ -208,7 +214,7 @@ def example_usage():
] ]
proms_list = [ proms_list = [
#x8(torch.tensor([1, 2, 3], device=device)), #x8(torch.tensor([1, 2, 3], device=device)),
qnt.to(device), qnt[:75*3, :].to(device),
] ]
resps_list = [ resps_list = [
qnt.to(device), qnt.to(device),
@ -233,11 +239,15 @@ def example_usage():
""" """
model = AR_NAR(**kwargs).to(device) model = AR_NAR(**kwargs).to(device)
#steps = 500
#optimizer = ml.Prodigy(model.parameters(), lr=1.0)
steps = 500 steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0) optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
engine = Engine(model=model, optimizer=optimizer) engine = Engine(model=model, optimizer=optimizer)
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
print([ name for name, _ in model.named_parameters()])
@torch.inference_mode() @torch.inference_mode()
def sample( name, steps=600 ): def sample( name, steps=600 ):

View File

@ -151,64 +151,45 @@ class MultiEmbedding(nn.Embedding):
else: else:
w = self.weight w = self.weight
padded_x_list = [] padded_x_list = []
for i, xi in enumerate(x_list): for i, xi in enumerate(x_list):
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
wi = w.shape[0] - xi.shape[1] wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k xi = F.pad(xi, (0, 0, 0, wi)) # t l k
padded_x_list.append(xi.to(w)) padded_x_list.append(xi.to(w))
x = torch.cat(padded_x_list) # n l k x = torch.cat(padded_x_list) # n l k
x = einsum("l k d, n l k -> n d", w, x) x = einsum("l k d, n l k -> n d", w, x)
x_list = x.split([*map(len, x_list)]) x_list = x.split([*map(len, x_list)])
return x_list return x_list
""" # Embedding that sums each RVQ-bin level within a given input acoustic prompt
w_ar, w_nar = self.weight[:1], self.weight[1:] class AudioEmbedding(nn.Module):
p_ar_list, p_nar_list = [], [] def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
res_list = []
for i, xi in enumerate(x_list): for i, xi in enumerate(x_list):
if quant_levels is None or quant_levels[i] == 0: # prom
w padded_x_list, = w_ar, p_ar_list if quant_levels is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_levels is None or quant_levels[i] == 0:
x = self.embeddings[0]( xi[:, 0] )
# NAR resp
else: else:
w, padded_x_list = w_nar, p_nar_list x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
res_list.append(x)
# pad resp/prom tensor to fit weight return res_list
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
padded_x_list.append(xi.to(w))
# batch list => batch tensor
x_ar_list = einsum("l k d, n l k -> n d", w_ar, torch.cat(p_ar_list)) if len(p_ar_list) > 0 else []
x_nar_list = einsum("l k d, n l k -> n d", w_nar, torch.cat(p_nar_list)) if len(p_nar_list) > 0 else []
x_list = x.split([*map(len, x_list)])
"""
"""
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class PromEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
# Embedding that selects which embedding based on a quant_level tensor for a given batch
class RespEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
"""
class Base(nn.Module): class Base(nn.Module):
@property @property
@ -252,8 +233,8 @@ class Base(nn.Module):
return False return False
@property @property
def n_embeddings(self) -> int: def version(self) -> int:
return 2 if self.monolithic else 1 return 1
@property @property
def stop_token(self): def stop_token(self):
@ -298,8 +279,12 @@ class Base(nn.Module):
self.text_emb = Embedding(n_tokens, d_model) self.text_emb = Embedding(n_tokens, d_model)
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic) if self.version == 1: # legacy
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic) self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
self.sep = nn.Parameter(torch.randn(d_model)) self.sep = nn.Parameter(torch.randn(d_model))

View File

@ -39,6 +39,12 @@ class NAR(Base):
def n_tasks(self) -> int: def n_tasks(self) -> int:
return cfg.models.tasks return cfg.models.tasks
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.nar.version
@property @property
def recurrent_chunk_size(self) -> int: def recurrent_chunk_size(self) -> int:
return 0 return 0