diff --git a/vall_e/data.py b/vall_e/data.py index 1e2d773..d7571f7 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -551,7 +551,11 @@ class Dataset(_Dataset): if self.sampler_type == "path": if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0: - self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size ) + self.sampler = BatchedOrderedSampler( + self.duration_buckets if not sampler_path.exists() else {}, # pass nothing if we're just going to load from a state anyways + cfg.dataset.sample_max_duration_batch, + cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size + ) else: self.sampler = OrderedSampler( len(self) ) self.samplers = {} diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 2c6e8c4..4072037 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -267,16 +267,15 @@ def _replace_file_extension(path, suffix): # > b-but why not just initialize the embedding weights to these instead of fetching them at r-runtime # each audio backend does their "embeddings" a different way that isn't just a embedding weights @torch.inference_mode() -def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cpu"): +def encode_as_embedding(codes: Tensor, quant_level: int = 0, device="cuda"): model = _load_model(device) codes = codes.to(device=device, dtype=torch.int32) - if codes.dim() == 1: - codes = rearrange(codes, "t -> 1 t") - else: + if codes.dim() == 2: codes = codes[:, quant_level] - codes = rearrange(codes, "t -> 1 t") + + codes = rearrange(codes, "t -> 1 t") # dac conveniently has its dim = 1024 if cfg.audio_backend == "dac": diff --git a/vall_e/models/base.py b/vall_e/models/base.py index b063148..4674262 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -173,6 +173,9 @@ class AudioEmbedding(nn.Module): # for AR, trim any stop tokens has_stop_token = False + + # this block apparently doesn't work + """ if quant_level == 0: stop_token = self.embeddings[0].weight.shape[0] - 1 stop_token_indices = (input == stop_token).nonzero() @@ -180,6 +183,15 @@ class AudioEmbedding(nn.Module): if has_stop_token: input = input[:stop_token_indices.min().item()] + """ + has_stop_token = False + + if quant_level == 0: + stop_token = self.embeddings[0].weight.shape[0] - 1 + has_stop_token = input[-1] == stop_token + + if has_stop_token: + input = input[:-1] # get external embedding embedding = encode_as_embedding( input, quant_level ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)