this embedding class definitely works, and migrating from the previous embedding weights seems to work.
This commit is contained in:
parent
a1f250ffac
commit
40ef34e1ca
|
@ -160,11 +160,12 @@ class Model:
|
|||
resp_levels: int = 1
|
||||
prom_levels: int = 8
|
||||
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
|
||||
arch_type: str = "transformer"
|
||||
arch_type: str = "retnet"
|
||||
training: bool = True
|
||||
interleave: bool = False
|
||||
frozen_params: list[str] = field(default_factory=lambda: [])
|
||||
p_ar_nar: float = 0.5
|
||||
version: int = 1
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
|
|
|
@ -279,7 +279,7 @@ class Dataset(_Dataset):
|
|||
|
||||
# shuffle it up a bit
|
||||
prom_length = 0
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
|
||||
|
||||
for _ in range(cfg.dataset.max_prompts):
|
||||
path = random.choice(choices)
|
||||
|
|
|
@ -57,6 +57,12 @@ class AR(Base):
|
|||
def monolithic(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.version
|
||||
return cfg.models.ar.version
|
||||
|
||||
def _prune(self, l: Tensor):
|
||||
indices = (l == self.stop_token).nonzero()
|
||||
if len(indices) == 0:
|
||||
|
|
|
@ -54,6 +54,12 @@ class AR_NAR(Base):
|
|||
def monolithic(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.version
|
||||
return cfg.models.ar_nar.version
|
||||
|
||||
def _prune(self, l: Tensor):
|
||||
indices = (l == self.stop_token).nonzero()
|
||||
if len(indices) == 0:
|
||||
|
@ -208,7 +214,7 @@ def example_usage():
|
|||
]
|
||||
proms_list = [
|
||||
#x8(torch.tensor([1, 2, 3], device=device)),
|
||||
qnt.to(device),
|
||||
qnt[:75*3, :].to(device),
|
||||
]
|
||||
resps_list = [
|
||||
qnt.to(device),
|
||||
|
@ -233,11 +239,15 @@ def example_usage():
|
|||
"""
|
||||
|
||||
model = AR_NAR(**kwargs).to(device)
|
||||
#steps = 500
|
||||
#optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||
steps = 500
|
||||
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||
optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
|
||||
engine = Engine(model=model, optimizer=optimizer)
|
||||
|
||||
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
print([ name for name, _ in model.named_parameters()])
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample( name, steps=600 ):
|
||||
|
|
|
@ -151,64 +151,45 @@ class MultiEmbedding(nn.Embedding):
|
|||
else:
|
||||
w = self.weight
|
||||
|
||||
padded_x_list = []
|
||||
padded_x_list = []
|
||||
|
||||
for i, xi in enumerate(x_list):
|
||||
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
|
||||
wi = w.shape[0] - xi.shape[1]
|
||||
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
|
||||
padded_x_list.append(xi.to(w))
|
||||
for i, xi in enumerate(x_list):
|
||||
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
|
||||
wi = w.shape[0] - xi.shape[1]
|
||||
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
|
||||
padded_x_list.append(xi.to(w))
|
||||
|
||||
x = torch.cat(padded_x_list) # n l k
|
||||
x = einsum("l k d, n l k -> n d", w, x)
|
||||
x = torch.cat(padded_x_list) # n l k
|
||||
x = einsum("l k d, n l k -> n d", w, x)
|
||||
|
||||
x_list = x.split([*map(len, x_list)])
|
||||
x_list = x.split([*map(len, x_list)])
|
||||
|
||||
return x_list
|
||||
|
||||
"""
|
||||
w_ar, w_nar = self.weight[:1], self.weight[1:]
|
||||
p_ar_list, p_nar_list = [], []
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
class AudioEmbedding(nn.Module):
|
||||
def __init__(self, n_levels, n_tokens, token_dim):
|
||||
super().__init__()
|
||||
self.n_levels = n_levels
|
||||
# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
||||
|
||||
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
|
||||
res_list = []
|
||||
|
||||
for i, xi in enumerate(x_list):
|
||||
if quant_levels is None or quant_levels[i] == 0:
|
||||
w padded_x_list, = w_ar, p_ar_list
|
||||
# prom
|
||||
if quant_levels is None and xi.shape[-1] > 1:
|
||||
x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
|
||||
# AR resp
|
||||
elif quant_levels is None or quant_levels[i] == 0:
|
||||
x = self.embeddings[0]( xi[:, 0] )
|
||||
# NAR resp
|
||||
else:
|
||||
w, padded_x_list = w_nar, p_nar_list
|
||||
x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
|
||||
res_list.append(x)
|
||||
|
||||
# pad resp/prom tensor to fit weight
|
||||
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
|
||||
xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
|
||||
padded_x_list.append(xi.to(w))
|
||||
|
||||
# batch list => batch tensor
|
||||
x_ar_list = einsum("l k d, n l k -> n d", w_ar, torch.cat(p_ar_list)) if len(p_ar_list) > 0 else []
|
||||
x_nar_list = einsum("l k d, n l k -> n d", w_nar, torch.cat(p_nar_list)) if len(p_nar_list) > 0 else []
|
||||
|
||||
x_list = x.split([*map(len, x_list)])
|
||||
"""
|
||||
|
||||
"""
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
class PromEmbedding(nn.Module):
|
||||
def __init__(self, n_levels, n_tokens, token_dim):
|
||||
super().__init__()
|
||||
self.n_levels = n_levels
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
||||
|
||||
def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
|
||||
return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
|
||||
|
||||
# Embedding that selects which embedding based on a quant_level tensor for a given batch
|
||||
class RespEmbedding(nn.Module):
|
||||
def __init__(self, n_levels, n_tokens, token_dim):
|
||||
super().__init__()
|
||||
self.n_levels = n_levels
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
||||
|
||||
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
|
||||
return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
|
||||
"""
|
||||
return res_list
|
||||
|
||||
class Base(nn.Module):
|
||||
@property
|
||||
|
@ -252,8 +233,8 @@ class Base(nn.Module):
|
|||
return False
|
||||
|
||||
@property
|
||||
def n_embeddings(self) -> int:
|
||||
return 2 if self.monolithic else 1
|
||||
def version(self) -> int:
|
||||
return 1
|
||||
|
||||
@property
|
||||
def stop_token(self):
|
||||
|
@ -298,8 +279,12 @@ class Base(nn.Module):
|
|||
|
||||
self.text_emb = Embedding(n_tokens, d_model)
|
||||
|
||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic)
|
||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
|
||||
if self.version == 1: # legacy
|
||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
|
||||
else:
|
||||
self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||
self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
|
|
|
@ -39,6 +39,12 @@ class NAR(Base):
|
|||
def n_tasks(self) -> int:
|
||||
return cfg.models.tasks
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.version
|
||||
return cfg.models.nar.version
|
||||
|
||||
@property
|
||||
def recurrent_chunk_size(self) -> int:
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user