ugh
This commit is contained in:
parent 2808f881c8
commit 793ccb16fb
@@ -551,7 +551,11 @@ class Dataset(_Dataset):
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
-				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size )
+				self.sampler = BatchedOrderedSampler(
+					self.duration_buckets if not sampler_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
+					cfg.dataset.sample_max_duration_batch,
+					cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size
+				)
 			else:
 				self.sampler = OrderedSampler( len(self) )
 			self.samplers = {}
 
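The only functional change above is the first argument: when a saved sampler state already exists at sampler_path, an empty dict is passed instead of the duration buckets, since the state will be loaded anyway. For anyone unfamiliar with duration-ordered batching, here is a minimal sketch of the general idea, assuming a list of (index, duration) pairs; the names and logic are illustrative only and not the project's actual BatchedOrderedSampler.

# Illustrative sketch of duration-capped batching (not the real BatchedOrderedSampler).
# `pairs` is assumed to be [(dataset_index, duration_in_seconds), ...], already ordered.
def batch_by_duration(pairs, max_total_duration, max_batch_size):
	batches, current, total = [], [], 0.0
	for index, duration in pairs:
		# close the current batch once either cap would be exceeded
		if current and (total + duration > max_total_duration or len(current) >= max_batch_size):
			batches.append(current)
			current, total = [], 0.0
		current.append(index)
		total += duration
	if current:
		batches.append(current)
	return batches

# example: cap each batch at 30 seconds of audio or 4 samples, whichever comes first
print(batch_by_duration([(0, 4.0), (1, 6.5), (2, 9.0), (3, 12.0), (4, 20.0)], 30.0, 4))
# [[0, 1, 2], [3], [4]]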
@@ -267,16 +267,15 @@ def _replace_file_extension(path, suffix):
 # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime
 # each audio backend does their "embeddings" a different way that isn't just a embedding weights
 @torch.inference_mode()
-def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"):
+def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cuda"):
 	model = _load_model(device)
 
 	codes = codes.to(device=device, dtype=torch.int32)
 
-	if codes.dim() == 1:
-		codes = rearrange(codes, "t -> 1 t")
-	else:
+	if codes.dim() == 2:
 		codes = codes[:, quant_level]
-		codes = rearrange(codes, "t -> 1 t")
+
+	codes = rearrange(codes, "t -> 1 t")
 
 	# dac conveniently has its dim = 1024
 	if cfg.audio_backend == "dac":
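For concreteness, this is what the reworked shape handling in encode_as_embedding does on a toy tensor. It is a standalone sketch (torch plus einops only), and the (t, q) layout of codes is an assumption read off the diff rather than something it states.

import torch
from einops import rearrange

codes = torch.randint(0, 1024, (75, 8), dtype=torch.int32)  # assumed (t, q): t frames, q quantizer levels
quant_level = 2

if codes.dim() == 2:
	codes = codes[:, quant_level]  # select one quantizer level -> shape (t,)

codes = rearrange(codes, "t -> 1 t")  # add a batch dimension -> shape (1, t)
print(codes.shape)  # torch.Size([1, 75])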
@@ -173,6 +173,9 @@ class AudioEmbedding(nn.Module):
 
 		# for AR, trim any stop tokens
 		has_stop_token = False
 
+		# this block apparently doesn't work
+		"""
 		if quant_level == 0:
 			stop_token = self.embeddings[0].weight.shape[0] - 1
 			stop_token_indices = (input == stop_token).nonzero()
@@ -180,6 +183,15 @@ class AudioEmbedding(nn.Module):
 
 		if has_stop_token:
 			input = input[:stop_token_indices.min().item()]
+		"""
+		has_stop_token = False
+
+		if quant_level == 0:
+			stop_token = self.embeddings[0].weight.shape[0] - 1
+			has_stop_token = input[-1] == stop_token
+
+		if has_stop_token:
+			input = input[:-1]
 
 		# get external embedding
 		embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
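To spell out what the replacement block does, here is a small hedged example with toy values: only a stop token in the final position gets stripped, rather than cutting at the first stop token found anywhere as the quoted-out block attempted. The stop_token value below is a placeholder for embeddings[0].weight.shape[0] - 1.

import torch

stop_token = 1024  # placeholder; the diff derives it from embeddings[0].weight.shape[0] - 1

def trim_trailing_stop(input: torch.Tensor) -> torch.Tensor:
	# mirrors the new block: only a trailing stop token is trimmed
	if len(input) and input[-1] == stop_token:
		return input[:-1]
	return input

seq = torch.tensor([5, 9, 12, stop_token])
print(trim_trailing_stop(seq))       # tensor([ 5,  9, 12])
print(trim_trailing_stop(seq[:-1]))  # no trailing stop token, returned unchanged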