this embedding class definitely works, and migrating from the previous embedding weights seems to work.
This commit is contained in:
parent a1f250ffac
commit 40ef34e1ca
@@ -160,11 +160,12 @@ class Model:
     resp_levels: int = 1
     prom_levels: int = 8
     tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
-    arch_type: str = "transformer"
+    arch_type: str = "retnet"
     training: bool = True
     interleave: bool = False
     frozen_params: list[str] = field(default_factory=lambda: [])
     p_ar_nar: float = 0.5
+    version: int = 1

     @property
     def full_name(self):
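The new version field is what Base.__init__ keys on later in this commit: version == 1 keeps the legacy MultiEmbedding classes, any other value switches to the new AudioEmbedding. A minimal sketch of that intent, reusing only the field names from the hunk above (the example value 2 is an assumption, not something set in this commit):

from dataclasses import dataclass, field

# Toy reproduction of just the fields touched above; not the full Model dataclass.
@dataclass
class Model:
    resp_levels: int = 1
    prom_levels: int = 8
    arch_type: str = "retnet"
    frozen_params: list[str] = field(default_factory=lambda: [])
    version: int = 1  # 1 => legacy MultiEmbedding path, anything else => AudioEmbedding

legacy = Model()            # existing configs keep the old embeddings by default
updated = Model(version=2)  # hypothetical opt-in to the new AudioEmbedding path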
@@ -279,7 +279,7 @@ class Dataset(_Dataset):

         # shuffle it up a bit
         prom_length = 0
-        trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-16, 16)
+        trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)

         for _ in range(cfg.dataset.max_prompts):
             path = random.choice(choices)
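This widens the random jitter on the trimmed prompt length from ±16 to ±75 code frames. Assuming the usual 75-frames-per-second EnCodec rate (which is where the * 75 factor comes from), that is roughly ±1 second instead of ±0.2 seconds around prompt_duration. A quick sketch of the arithmetic, with a hypothetical 3-second prompt_duration:

import random

prompt_duration = 3.0  # seconds; example value, not taken from the diff

old_trim = int(prompt_duration * 75) + random.randint(-16, 16)  # 225 ± 16 frames (~±0.2 s)
new_trim = int(prompt_duration * 75) + random.randint(-75, 75)  # 225 ± 75 frames (~±1.0 s)
print(old_trim, new_trim)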
@@ -57,6 +57,12 @@ class AR(Base):
     def monolithic(self) -> bool:
         return False

+    @property
+    def version(self) -> int:
+        if hasattr(self, "config") and self.config:
+            return self.config.version
+        return cfg.models.ar.version
+
     def _prune(self, l: Tensor):
         indices = (l == self.stop_token).nonzero()
         if len(indices) == 0:
@@ -54,6 +54,12 @@ class AR_NAR(Base):
     def monolithic(self) -> bool:
         return True

+    @property
+    def version(self) -> int:
+        if hasattr(self, "config") and self.config:
+            return self.config.version
+        return cfg.models.ar_nar.version
+
     def _prune(self, l: Tensor):
         indices = (l == self.stop_token).nonzero()
         if len(indices) == 0:
@@ -208,7 +214,7 @@ def example_usage():
     ]
     proms_list = [
         #x8(torch.tensor([1, 2, 3], device=device)),
-        qnt.to(device),
+        qnt[:75*3, :].to(device),
     ]
     resps_list = [
         qnt.to(device),
@@ -233,12 +239,16 @@ def example_usage():
     """

     model = AR_NAR(**kwargs).to(device)
+    #steps = 500
+    #optimizer = ml.Prodigy(model.parameters(), lr=1.0)
     steps = 500
-    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+    optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
     engine = Engine(model=model, optimizer=optimizer)

     print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

+    print([ name for name, _ in model.named_parameters()])
+
     @torch.inference_mode()
     def sample( name, steps=600 ):
         engine.eval()
@@ -166,49 +166,30 @@ class MultiEmbedding(nn.Embedding):

         return x_list

-        """
-        w_ar, w_nar = self.weight[:1], self.weight[1:]
-        p_ar_list, p_nar_list = [], []
-
-        for i, xi in enumerate(x_list):
-            if quant_levels is None or quant_levels[i] == 0:
-                w, padded_x_list = w_ar, p_ar_list
-            else:
-                w, padded_x_list = w_nar, p_nar_list
-
-            # pad resp/prom tensor to fit weight
-            xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
-            xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
-            padded_x_list.append(xi.to(w))
-
-        # batch list => batch tensor
-        x_ar_list = einsum("l k d, n l k -> n d", w_ar, torch.cat(p_ar_list)) if len(p_ar_list) > 0 else []
-        x_nar_list = einsum("l k d, n l k -> n d", w_nar, torch.cat(p_nar_list)) if len(p_nar_list) > 0 else []
-
-        x_list = x.split([*map(len, x_list)])
-        """
-
-"""
 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
-class PromEmbedding(nn.Module):
-    def __init__(self, n_levels, n_tokens, token_dim):
-        super().__init__()
-        self.n_levels = n_levels
-        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
-
-    def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
-        return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
-
-# Embedding that selects which embedding based on a quant_level tensor for a given batch
-class RespEmbedding(nn.Module):
+class AudioEmbedding(nn.Module):
     def __init__(self, n_levels, n_tokens, token_dim):
         super().__init__()
         self.n_levels = n_levels
+        # would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
         self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])

     def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
-        return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
-"""
+        res_list = []
+
+        for i, xi in enumerate(x_list):
+            # prom
+            if quant_levels is None and xi.shape[-1] > 1:
+                x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
+            # AR resp
+            elif quant_levels is None or quant_levels[i] == 0:
+                x = self.embeddings[0]( xi[:, 0] )
+            # NAR resp
+            else:
+                x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
+
+            res_list.append(x)
+
+        return res_list

 class Base(nn.Module):
     @property
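For reference, a self-contained sketch of how the three branches of AudioEmbedding.forward behave. The class body is re-declared (lightly condensed) from the hunk above so the snippet runs on its own; the sizes used here (8 RVQ levels, 1024 codes plus a stop token, 16-dim embeddings, 150/100-frame sequences) are toy assumptions, not values from this commit.

import torch
from torch import nn, Tensor

class AudioEmbedding(nn.Module):
    def __init__(self, n_levels, n_tokens, token_dim):
        super().__init__()
        self.n_levels = n_levels
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])

    def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
        res_list = []
        for i, xi in enumerate(x_list):
            if quant_levels is None and xi.shape[-1] > 1:
                # prom: sum the embeddings of every RVQ level
                x = sum(self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]))
            elif quant_levels is None or quant_levels[i] == 0:
                # AR resp: only the coarsest level, through embeddings[0]
                x = self.embeddings[0](xi[:, 0])
            else:
                # NAR resp: the levels fed so far, through embeddings[1:]
                x = sum(self.embeddings[k + 1](xi[:, k]) for k in range(xi.shape[-1]))
            res_list.append(x)
        return res_list

emb = AudioEmbedding(n_levels=8, n_tokens=1025, token_dim=16)

prom = torch.randint(0, 1024, (150, 8))      # acoustic prompt, all 8 RVQ levels
ar_resp = torch.randint(0, 1025, (100, 1))   # AR response, level 0 (may include the stop token)
nar_resp = torch.randint(0, 1024, (100, 2))  # NAR input at quant level 2: levels 0..1

print(emb([prom])[0].shape)                                      # torch.Size([150, 16])
print(emb([ar_resp], quant_levels=torch.tensor([0]))[0].shape)   # torch.Size([100, 16])
print(emb([nar_resp], quant_levels=torch.tensor([2]))[0].shape)  # torch.Size([100, 16])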
@@ -252,8 +233,8 @@ class Base(nn.Module):
         return False

     @property
-    def n_embeddings(self) -> int:
-        return 2 if self.monolithic else 1
+    def version(self) -> int:
+        return 1

     @property
     def stop_token(self):
@@ -298,8 +279,12 @@ class Base(nn.Module):

         self.text_emb = Embedding(n_tokens, d_model)

-        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic)
-        self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
+        if self.version == 1: # legacy
+            self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
+            self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
+        else:
+            self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
+            self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

         self.sep = nn.Parameter(torch.randn(d_model))
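The commit message says migrating from the previous embedding weights seems to work, but the conversion itself is not part of this diff. Purely as a hypothetical sketch (the helper name and copy strategy are assumptions, not the repo's actual migration code): the legacy MultiEmbedding holds a single (levels, tokens, dim) weight, per the removed commented-out code above, so one plausible migration is a level-by-level copy into AudioEmbedding's per-level nn.Embedding modules.

import torch
from torch import nn

def migrate_multi_to_audio(old_weight: torch.Tensor, embeddings: nn.ModuleList) -> None:
    # Hypothetical helper: copy level k of a legacy (levels, tokens, dim) weight
    # into the k-th nn.Embedding of an AudioEmbedding (assumes token counts match).
    levels, tokens, dim = old_weight.shape
    with torch.no_grad():
        for k in range(min(levels, len(embeddings))):
            embeddings[k].weight.copy_(old_weight[k])  # (tokens, dim)

# usage sketch, assuming old_multi is a legacy MultiEmbedding and new_audio an AudioEmbedding:
# migrate_multi_to_audio(old_multi.weight.data, new_audio.embeddings)

How well the old monolithic resps layout (see the removed weight[:1] / weight[1:] snippet above) lines up with embeddings[0] / embeddings[1:] here is presumably why the message only says it "seems to work".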
@@ -39,6 +39,12 @@ class NAR(Base):
     def n_tasks(self) -> int:
         return cfg.models.tasks

+    @property
+    def version(self) -> int:
+        if hasattr(self, "config") and self.config:
+            return self.config.version
+        return cfg.models.nar.version
+
     @property
     def recurrent_chunk_size(self) -> int:
         return 0