ugh
parent 2808f881c8
commit 793ccb16fb
@@ -551,7 +551,11 @@ class Dataset(_Dataset):
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
-				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size )
+				self.sampler = BatchedOrderedSampler(
+					self.duration_buckets if not sampler_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
+					cfg.dataset.sample_max_duration_batch,
+					cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
+				)
 			else:
 				self.sampler = OrderedSampler( len(self) )
 			self.samplers = {}

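Note: the point of passing {} when sampler_path exists is that bucketing every path by duration is wasted work if the sampler's batches are about to be restored from disk anyway. A minimal sketch of that pattern follows; the load_state_dict interface, the {duration: [paths]} bucket layout, and the batching rule are assumptions for illustration, not the real BatchedOrderedSampler's API.

class BatchedOrderedSamplerSketch:
	def __init__(self, buckets, max_duration, batch_size):
		self.max_duration = max_duration
		self.batch_size = batch_size
		# building batches from buckets is the expensive step the caller
		# skips (by passing {}) when a saved state exists
		self.batches = self._make_batches(buckets)

	def _make_batches(self, buckets):
		# buckets assumed to map duration -> list of paths
		batches, current, total = [], [], 0.0
		for duration, paths in sorted(buckets.items()):
			for path in paths:
				if current and (total + duration > self.max_duration or len(current) >= self.batch_size):
					batches.append(current)
					current, total = [], 0.0
				current.append(path)
				total += duration
		if current:
			batches.append(current)
		return batches

	def load_state_dict(self, state):
		# a restored state overwrites self.batches, so the {} passed to
		# __init__ above is harmless
		self.batches = state["batches"]

sampler = BatchedOrderedSamplerSketch({}, max_duration=120.0, batch_size=16)
sampler.load_state_dict({"batches": [["a.wav", "b.wav"]]})
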
@@ -267,16 +267,15 @@ def _replace_file_extension(path, suffix):
 # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
 # each audio backend does their "embeddings" a different way that isn't just a embedding weights
 @torch.inference_mode()
-def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cuda"):
 	model = _load_model(device)
 
 	codes = codes.to(device=device, dtype=torch.int32)
 
-	if codes.dim() == 2:
-		codes = codes[:, quant_level]
-
-	codes = rearrange(codes, "t -> 1 t")
+	if codes.dim() == 1:
+		codes = rearrange(codes, "t -> 1 t")
+	else:
+		if codes.dim() == 2:
+			codes = codes[:, quant_level]
+		codes = rearrange(codes, "t -> 1 t")
 
 	# dac conveniently has its dim = 1024
 	if cfg.audio_backend == "dac":

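Note: the rewritten branch normalizes both accepted inputs to a (1, t) batch. A standalone restatement of just that shape handling, assuming 2-D inputs are laid out as (t, q) (which is what codes[:, quant_level] followed by "t -> 1 t" implies); runnable with torch and einops installed.

import torch
from einops import rearrange

def normalize_codes(codes: torch.Tensor, quant_level: int = 0) -> torch.Tensor:
	if codes.dim() == 1:
		# already a single quant level's codes: just add a batch dim
		codes = rearrange(codes, "t -> 1 t")
	else:
		if codes.dim() == 2:
			# (t, q): slice out the requested quant level first
			codes = codes[:, quant_level]
		codes = rearrange(codes, "t -> 1 t")
	return codes

assert normalize_codes(torch.zeros(75, dtype=torch.int32)).shape == (1, 75)
assert normalize_codes(torch.zeros(75, 8, dtype=torch.int32), quant_level=2).shape == (1, 75)
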
@@ -173,6 +173,9 @@ class AudioEmbedding(nn.Module):
 
 		# for AR, trim any stop tokens
 		has_stop_token = False
+
+		# this block apparently doesn't work
+		"""
 		if quant_level == 0:
 			stop_token = self.embeddings[0].weight.shape[0] - 1
 			stop_token_indices = (input == stop_token).nonzero()
@@ -180,6 +183,15 @@ class AudioEmbedding(nn.Module):
 
 		if has_stop_token:
 			input = input[:stop_token_indices.min().item()]
+		"""
+		has_stop_token = False
+
+		if quant_level == 0:
+			stop_token = self.embeddings[0].weight.shape[0] - 1
+			has_stop_token = input[-1] == stop_token
+
+		if has_stop_token:
+			input = input[:-1]
 
 		# get external embedding
 		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)

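Note: the working replacement only trims a stop token sitting at the very end of the sequence, where the commented-out nonzero() block would have truncated at the earliest occurrence anywhere in the input. A toy check of the new behavior, with a made-up vocab size; only the stop-token-is-the-last-embedding-row convention is taken from the diff.

import torch

vocab_size = 1025            # assumed for illustration: 1024 codes + 1 stop token
stop_token = vocab_size - 1  # mirrors self.embeddings[0].weight.shape[0] - 1

input = torch.tensor([5, 9, stop_token], dtype=torch.int64)

# only a *trailing* stop token triggers the trim
has_stop_token = input[-1] == stop_token
if has_stop_token:
	input = input[:-1]

assert input.tolist() == [5, 9]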