added summing of external embeddings (at this point I don't think any amount of cope bandaids will get DAC to train nicely; I think the RVQ levels the NAR handles tend to add too much noise if they're not accurate)
This commit is contained in:
parent
793ccb16fb
commit
b21f74a5c5
@@ -267,11 +267,31 @@ def _replace_file_extension(path, suffix):
 # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
 # each audio backend does their "embeddings" a different way that isn't just embedding weights
 @torch.inference_mode()
-def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cuda"):
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
     model = _load_model(device)
 
     codes = codes.to(device=device, dtype=torch.int32)
 
+    # yucky kludge
+    if sums:
+        if codes.dim() == 1:
+            codes = rearrange(codes, "t -> t 1")
+
+        if cfg.audio_backend == "dac":
+            x = []
+            for i in range(quant_level+1):
+                emb = model.quantizer.quantizers[i]
+                code = rearrange(codes[:, i], "t -> 1 t")
+
+                xi = emb.decode_code(code)
+                xi = emb.out_proj(xi)
+                x.append( xi[0].t() )
+
+            return sum(x).detach()
+
+        raise Exception(f'Currently only DAC is supported')
+
+
     if codes.dim() == 2:
         codes = codes[:, quant_level]
 
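For context on what the sums path is meant to produce: DAC is a residual VQ codec, so the frame latent it decodes from is (roughly) the sum of each quantizer level's codebook vector after that level's out_proj, and the new flag rebuilds that partial sum for levels 0..quant_level instead of embedding a single level. A minimal usage sketch follows; the import path, shapes, and codebook size are illustrative assumptions, not part of this commit:

import torch
# from vall_e.emb.qnt import encode_as_embedding  # assumed import path

codes = torch.randint(0, 1024, (50, 4))  # hypothetical: 50 frames, 4 RVQ levels

# previous behaviour: embed only the codes of the requested quant level
single = encode_as_embedding(codes, quant_level=3, sums=False)

# new behaviour: decode + project levels 0..3 and sum them per frame,
# approximating the codec's own partial reconstruction of the latent
summed = encode_as_embedding(codes, quant_level=3, sums=True)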
@@ -473,7 +473,7 @@ def example_usage():
     """
 
     model = AR_NAR(**kwargs).to(device)
-    steps = 100
+    steps = 150
 
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -194,7 +194,7 @@ class AudioEmbedding(nn.Module):
             input = input[:-1]
 
         # get external embedding
-        embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
+        embedding = encode_as_embedding( input, quant_level, sums=self.sums ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
         # resize if necessary (in case the external embeddings do not match our model dim)
         embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
 
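Why sums=self.sums gets threaded through here: when the embedding's sums flag is on, the learned path presumably builds the level-q input by summing the per-level embedding lookups for levels 0 through q (that behaviour lives elsewhere in AudioEmbedding and isn't shown in this diff), so the external DAC embedding should be assembled the same way to stay interchangeable with it. A hedged sketch of that learned counterpart, reusing self.embeddings from the lines above:

# assumed learned-path equivalent of the summed external embedding,
# where xi is a (t, levels) tensor of codes for one utterance
embedding = sum( self.embeddings[l]( xi[:, l] ) for l in range(quant_level + 1) )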
@@ -252,12 +252,6 @@ def example_usage():
     kwargs = {}
     model = Model(**kwargs).to(device)
     steps = 100
-    if cfg.model.arch_type == "mamba2":
-        steps = 100
-    elif cfg.model.arch_type == "llama":
-        steps = 500
-    elif cfg.model.interleave:
-        steps = 250
 
     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""