mrq 2024-06-29 22:14:35 -05:00
parent 2808f881c8
commit 793ccb16fb
3 changed files with 21 additions and 6 deletions

View File

@@ -551,7 +551,11 @@ class Dataset(_Dataset):
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
-				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size )
+				self.sampler = BatchedOrderedSampler(
+					self.duration_buckets if not sampler_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
+					cfg.dataset.sample_max_duration_batch,
+					cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
+				)
 			else:
 				self.sampler = OrderedSampler( len(self) )
 			self.samplers = {}
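The change above splits the sampler's constructor call across lines and skips passing the duration buckets when a saved sampler state (sampler_path) already exists, since the state will be loaded over them anyway. For context, here is a minimal sketch of the duration-bucketed batching idea behind BatchedOrderedSampler; the helper name, bucket layout, and arguments are assumptions, not the repo's actual implementation.

# Minimal sketch (assumed names, not the repo's BatchedOrderedSampler): group
# (index, duration) entries into batches whose summed duration stays under a
# cap, while also respecting a maximum batch size.
def batch_by_duration(duration_buckets, max_duration, max_batch_size):
	batches, current, total = [], [], 0.0
	for bucket in sorted(duration_buckets):
		for index, duration in duration_buckets[bucket]:
			too_long = max_duration > 0 and total + duration > max_duration
			too_big = len(current) >= max_batch_size
			if current and (too_long or too_big):
				batches.append(current)
				current, total = [], 0.0
			current.append(index)
			total += duration
	if current:
		batches.append(current)
	return batches

# usage: buckets map a rounded duration to (dataset index, exact duration) pairs
buckets = {2: [(0, 1.8), (1, 2.1)], 5: [(2, 4.9), (3, 5.2)]}
print(batch_by_duration(buckets, max_duration=8.0, max_batch_size=16))  # [[0, 1], [2], [3]]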

View File

@@ -267,16 +267,15 @@ def _replace_file_extension(path, suffix):
 # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
 # each audio backend does their "embeddings" a different way that isn't just a embedding weights
 @torch.inference_mode()
-def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cuda"):
 	model = _load_model(device)
 	codes = codes.to(device=device, dtype=torch.int32)
-	if codes.dim() == 1:
-		codes = rearrange(codes, "t -> 1 t")
-	else:
+	if codes.dim() == 2:
 		codes = codes[:, quant_level]
-		codes = rearrange(codes, "t -> 1 t")
+	codes = rearrange(codes, "t -> 1 t")
 	# dac conveniently has its dim = 1024
 	if cfg.audio_backend == "dac":

View File

@@ -173,6 +173,9 @@ class AudioEmbedding(nn.Module):
 		# for AR, trim any stop tokens
 		has_stop_token = False
+		# this block apparently doesn't work
+		"""
 		if quant_level == 0:
 			stop_token = self.embeddings[0].weight.shape[0] - 1
 			stop_token_indices = (input == stop_token).nonzero()
@@ -180,6 +183,15 @@ class AudioEmbedding(nn.Module):
 			if has_stop_token:
 				input = input[:stop_token_indices.min().item()]
+		"""
+		has_stop_token = False
+		if quant_level == 0:
+			stop_token = self.embeddings[0].weight.shape[0] - 1
+			has_stop_token = input[-1] == stop_token
+		if has_stop_token:
+			input = input[:-1]
 		# get external embedding
 		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
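The replacement block only checks whether the final token of a level-0 AR sequence is the stop token and trims it, instead of scanning the whole sequence with nonzero() as the commented-out block did. A minimal standalone sketch of that check follows; here stop_token is passed in explicitly as an assumption, rather than derived from the embedding table as in the diff.

import torch

# Sketch of the replacement trim (standalone; stop_token is an argument here as
# an assumption instead of self.embeddings[0].weight.shape[0] - 1): only the
# trailing token is checked, rather than scanning with nonzero().
def trim_stop_token(input: torch.Tensor, stop_token: int) -> torch.Tensor:
	if input.numel() and input[-1] == stop_token:
		return input[:-1]
	return input

# usage: a level-0 AR sequence ending in the stop token loses its last element
codes = torch.tensor([12, 87, 3, 1024])
print(trim_stop_token(codes, stop_token=1024))  # tensor([12, 87,  3])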