added summing of external embeddings (at this point I don't think any amount of cope bandaids will get DAC to train nicely; I think the RVQ levels from the NAR tend to add too much noise if they're not accurate)
This commit is contained in:
parent 793ccb16fb
commit b21f74a5c5
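For context, "summing external embeddings" means treating the codes residually: the embedding handed to the model for quantizer level q is the sum of every level's decoded codebook vector from 0 up to q, rather than level q's vector alone. A toy sketch of that idea with made-up codebooks and shapes (not the DAC API):

import torch

# toy RVQ setup: one codebook per level, [timesteps, levels] codes
T, dim, n_levels, codebook_size = 4, 8, 3, 16
codebooks = [torch.randn(codebook_size, dim) for _ in range(n_levels)]
codes = torch.randint(0, codebook_size, (T, n_levels))

def summed_embedding(codes: torch.Tensor, quant_level: int) -> torch.Tensor:
	# accumulate each level's codebook vector up to and including quant_level
	return sum(codebooks[i][codes[:, i]] for i in range(quant_level + 1))

emb = summed_embedding(codes, quant_level=2)  # [T, dim]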
@@ -267,11 +267,31 @@ def _replace_file_extension(path, suffix):
 
 # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
 # each audio backend does their "embeddings" a different way that isn't just a embedding weights
 @torch.inference_mode()
-def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cuda"):
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, sums=False, device="cuda"):
 	model = _load_model(device)
 
 	codes = codes.to(device=device, dtype=torch.int32)
 
+	# yucky kludge
+	if sums:
+		if codes.dim() == 1:
+			codes = rearrange(codes, "t -> t 1")
+
+		if cfg.audio_backend == "dac":
+			x = []
+			for i in range(quant_level+1):
+				emb = model.quantizer.quantizers[i]
+				code = rearrange(codes[:, quant_level], "t -> 1 t")
+
+				xi = emb.decode_code(code)
+				xi = emb.out_proj(xi)
+				x.append( xi[0].t() )
+
+			return sum(x).detach()
+
+		raise Exception(f'Currently only DAC is supported')
+
 	if codes.dim() == 2:
 		codes = codes[:, quant_level]
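For reference, a hypothetical call into the updated helper (toy values; assumes the helper is importable and cfg.audio_backend == "dac"):

import torch

codes = torch.randint(0, 1024, (100, 8), dtype=torch.int32)  # toy [timesteps, levels] DAC codes
emb = encode_as_embedding(codes, quant_level=3, sums=True)    # summed embedding over levels 0..3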
@@ -473,7 +473,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 100
+	steps = 150
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -194,7 +194,7 @@ class AudioEmbedding(nn.Module):
 			input = input[:-1]
 
 		# get external embedding
-		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
+		embedding = encode_as_embedding( input, quant_level, sums=self.sums ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
 		# resize if necessary (in case the external embeddings do not match our model dim)
 		embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
 
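The sums=self.sums argument assumes the embedding module carries a sums flag; wiring it up is not shown in this hunk. A hypothetical sketch of how such a flag might be stored on the module:

import torch.nn as nn

# hypothetical constructor wiring for the sums flag; not part of this diff
class AudioEmbedding(nn.Module):
	def __init__(self, n_tokens: int, token_dim: int, levels: int, sums: bool = True):
		super().__init__()
		# one embedding table per RVQ level, matching the self.embeddings[quant_level] usage above
		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(levels)])
		self.sums = sums  # whether encode_as_embedding should sum levels 0..quant_level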
@@ -252,12 +252,6 @@ def example_usage():
 	kwargs = {}
 	model = Model(**kwargs).to(device)
 	steps = 100
-	if cfg.model.arch_type == "mamba2":
-		steps = 100
-	elif cfg.model.arch_type == "llama":
-		steps = 500
-	elif cfg.model.interleave:
-		steps = 250
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""