diff --git a/vall_e/data.py b/vall_e/data.py
index d400b3d..c7d1651 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -188,6 +188,7 @@ class Dataset(_Dataset):
 		self.training = training
 		self.dataset_type = "training" if self.training else "validation"
 		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
+		self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
 
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
@@ -214,6 +215,9 @@ class Dataset(_Dataset):
 			spkr = cfg.get_spkr( data_dir / "dummy" )
 			spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
 
+			if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
+				continue
+
 			if spkr_group not in self.spkrs_by_spkr_group:
 				self.spkrs_by_spkr_group[spkr_group] = []
 
@@ -223,7 +227,7 @@ class Dataset(_Dataset):
 
 		self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
-		if cfg.dataset.sample_type == "path":
+		if self.sampler_type == "path":
 			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
 
 		self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
@@ -379,24 +383,24 @@ class Dataset(_Dataset):
 		return prom
 
 	def __getitem__(self, index):
-		if cfg.dataset.sample_type == "group":
+		if self.sampler_type == "group":
 			spkr_group = self.spkr_groups[index]
-			spkr_group_id = self.spkr_group_symmap[spkr_group]
+			#spkr_group_id = self.spkr_group_symmap[spkr_group]
 			spkr_name = self.spkr_samplers[spkr_group].sample()
 			spkr_id = self.spkr_symmap[spkr_name]
 			path = self.samplers[spkr_name].sample()
-		elif cfg.dataset.sample_type == "speaker":
+		elif self.sampler_type == "speaker":
 			spkr_name = self.spkrs[index]
 			spkr_id = self.spkr_symmap[spkr_name]
 			path = self.samplers[spkr_name].sample()
 			spkr_group = self.get_speaker_group(path)
-			spkr_group_id = self.spkr_group_symmap[spkr_group]
+			#spkr_group_id = self.spkr_group_symmap[spkr_group]
 		else:
 			path = self.paths[index]
 			spkr_name = self.get_speaker(path)
 			spkr_id = self.spkr_symmap[spkr_name]
 			spkr_group = self.get_speaker_group(path)
-			spkr_group_id = self.spkr_group_symmap[spkr_group]
+			#spkr_group_id = self.spkr_group_symmap[spkr_group]
 
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
@@ -641,9 +645,9 @@ class Dataset(_Dataset):
 		self.training = value
 
 	def __len__(self):
-		if cfg.dataset.sample_type == "group":
+		if self.sampler_type == "group":
 			return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
-		if cfg.dataset.sample_type == "speaker":
+		if self.sampler_type == "speaker":
 			return min(len(self.spkrs), self._head or len(self.spkrs))
 		return min(len(self.paths), self._head or len(self.paths))
 
@@ -702,7 +706,7 @@ def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
 
 	subtrain_dataset = copy.deepcopy(train_dataset)
-	if cfg.dataset.sample_type == "path":
+	if subtrain_dataset.sampler_type == "path":
 		subtrain_dataset.head_(cfg.evaluation.size)
 
 	train_dl = _create_dataloader(train_dataset, training=True)
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e3c6ea1..34ba3f9 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -221,6 +221,8 @@ class Base(nn.Module):
 		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
 
 		self.text_emb = Embedding(n_tokens, d_model)
+		self.langs_emb = None
+		self.tasks_emb = None
 
 		if self.version == 1: # legacy
 			n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
@@ -232,9 +234,10 @@ class Base(nn.Module):
 
 		# [1025] + [1024] * 8
 		self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
+
 		if self.version >= 3:
-			self.langs_emb = Embedding(self.n_langs, d_model)
-			self.tasks_emb = Embedding(self.n_tasks, d_model)
+			self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
+			self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
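A minimal standalone sketch of the guard pattern the base.py hunks introduce: the language/task embeddings default to None and are only instantiated for version >= 3 models with a positive token count, so downstream code can just test the attribute for None. The class name, argument names, and default values below are illustrative only, not the repo's actual Base module (which uses its own Embedding wrapper rather than nn.Embedding).

	from typing import Optional

	import torch
	from torch import nn

	class EmbeddingGuardSketch(nn.Module):
		"""Toy module mirroring the conditional embedding setup; not vall_e code."""
		def __init__(self, d_model=1024, version=3, n_langs=2, n_tasks=8):
			super().__init__()
			# The attributes always exist, even on legacy versions, so later code
			# never needs hasattr() or version checks.
			self.langs_emb = None
			self.tasks_emb = None
			if version >= 3:
				# Only build an embedding table when there is something to embed.
				self.langs_emb = nn.Embedding(n_langs, d_model) if n_langs > 0 else None
				self.tasks_emb = nn.Embedding(n_tasks, d_model) if n_tasks > 0 else None

		def forward(self, lang_id: torch.Tensor) -> Optional[torch.Tensor]:
			# Callers branch on None rather than on model version.
			return self.langs_emb(lang_id) if self.langs_emb is not None else None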