added summing of external embeddings (at this point I don't think any amount of cope bandaids will get DAC to train nicely; I think the RVQ levels the NAR predicts tend to add too much noise when they're not accurate)

mrq 2024-06-29 23:42:30 -05:00
parent 793ccb16fb
commit b21f74a5c5
4 changed files with 23 additions and 9 deletions
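
For context on the commit message: DAC's residual vector quantizer reconstructs its latent by summing the dequantized, projected vector from every codebook level, so any level the NAR predicts inaccurately contributes its error directly to that sum. A toy illustration of that failure mode (shapes and values are made up, not DAC's real dimensions):

import torch

# toy RVQ reconstruction: the decoder input is the sum of per-level vectors
n_levels, frames, dim = 9, 4, 8
levels = [torch.randn(frames, dim) for _ in range(n_levels)]
latent = sum(levels)

# corrupt a single "NAR-predicted" level; its error lands verbatim in the summed latent
noisy = list(levels)
noisy[5] = noisy[5] + torch.randn(frames, dim)
rel_error = (sum(noisy) - latent).norm() / latent.norm()
print(f"relative latent error from one bad level: {rel_error:.3f}")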


@@ -267,11 +267,31 @@ def _replace_file_extension(path, suffix):
 # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
 # each audio backend does their "embeddings" a different way that isn't just a embedding weights
 @torch.inference_mode()
-def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cuda"):
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
     model = _load_model(device)
     codes = codes.to(device=device, dtype=torch.int32)
+
+    # yucky kludge
+    if sums:
+        if codes.dim() == 1:
+            codes = rearrange(codes, "t -> t 1")
+
+        if cfg.audio_backend == "dac":
+            x = []
+            for i in range(quant_level+1):
+                emb = model.quantizer.quantizers[i]
+                code = rearrange(codes[:, i], "t -> 1 t")
+                xi = emb.decode_code(code)
+                xi = emb.out_proj(xi)
+                x.append( xi[0].t() )
+            return sum(x).detach()
+
+        raise Exception('Currently only DAC is supported')
+
     if codes.dim() == 2:
         codes = codes[:, quant_level]
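
If I'm reading the kludge right, sums=True decodes each requested level's codes through its own quantizer, projects them with out_proj, and returns the per-frame sum, which is closer to the latent DAC would feed its decoder than any single level's embedding. A rough usage sketch with the encode_as_embedding from the hunk above in scope (the shapes, codebook size, and backend setup are assumptions, not taken from the diff):

import torch

# codes: [frames, n_rvq_levels] integer DAC codes; the values here are placeholders
codes = torch.randint(0, 1024, (256, 9), dtype=torch.int32)

emb_single = encode_as_embedding(codes, quant_level=3)             # level 3 on its own
emb_summed = encode_as_embedding(codes, quant_level=3, sums=True)  # levels 0..3 decoded and summed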


@@ -473,7 +473,7 @@ def example_usage():
     """
     model = AR_NAR(**kwargs).to(device)
-    steps = 100
+    steps = 150
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""


@@ -194,7 +194,7 @@ class AudioEmbedding(nn.Module):
         input = input[:-1]
     # get external embedding
-    embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
+    embedding = encode_as_embedding( input, quant_level, sums=self.sums ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
     # resize if necessary (in case the external embeddings do not match our model dim)
     embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
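
ml.resize_weight itself isn't shown in this diff; for readers outside the codebase, a minimal stand-in that captures what the comment describes (make the external embedding's last dimension match the model width, without random re-initialization) could be a simple truncate-or-zero-pad. This is an assumption about the helper's intent, not its actual implementation:

import torch
import torch.nn.functional as F

def resize_last_dim(x: torch.Tensor, target_dim: int) -> torch.Tensor:
    # truncate or zero-pad the last dimension so an external embedding matches the model dim
    cur = x.shape[-1]
    if cur == target_dim:
        return x
    if cur > target_dim:
        return x[..., :target_dim]
    return F.pad(x, (0, target_dim - cur))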


@@ -252,12 +252,6 @@ def example_usage():
     kwargs = {}
     model = Model(**kwargs).to(device)
     steps = 100
-    if cfg.model.arch_type == "mamba2":
-        steps = 100
-    elif cfg.model.arch_type == "llama":
-        steps = 500
-    elif cfg.model.interleave:
-        steps = 250
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""