make validation samplers ignore sampler type

master
mrq 2023-10-22 09:01:47 +07:00
parent 32d4271ca8
commit 9a6040383e
2 changed files with 18 additions and 11 deletions

@ -188,6 +188,7 @@ class Dataset(_Dataset):
self.training = training
self.dataset_type = "training" if self.training else "validation"
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
# to-do: skip validation entirely when the validation dataset is empty
# for now, this just keeps it happy
@ -214,6 +215,9 @@ class Dataset(_Dataset):
spkr = cfg.get_spkr( data_dir / "dummy" )
spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
continue
if spkr_group not in self.spkrs_by_spkr_group:
self.spkrs_by_spkr_group[spkr_group] = []
@ -223,7 +227,7 @@ class Dataset(_Dataset):
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
if cfg.dataset.sample_type == "path":
if self.sampler_type == "path":
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
@ -379,24 +383,24 @@ class Dataset(_Dataset):
return prom
def __getitem__(self, index):
if cfg.dataset.sample_type == "group":
if self.sampler_type == "group":
spkr_group = self.spkr_groups[index]
spkr_group_id = self.spkr_group_symmap[spkr_group]
#spkr_group_id = self.spkr_group_symmap[spkr_group]
spkr_name = self.spkr_samplers[spkr_group].sample()
spkr_id = self.spkr_symmap[spkr_name]
path = self.samplers[spkr_name].sample()
elif cfg.dataset.sample_type == "speaker":
elif self.sampler_type == "speaker":
spkr_name = self.spkrs[index]
spkr_id = self.spkr_symmap[spkr_name]
path = self.samplers[spkr_name].sample()
spkr_group = self.get_speaker_group(path)
spkr_group_id = self.spkr_group_symmap[spkr_group]
#spkr_group_id = self.spkr_group_symmap[spkr_group]
else:
path = self.paths[index]
spkr_name = self.get_speaker(path)
spkr_id = self.spkr_symmap[spkr_name]
spkr_group = self.get_speaker_group(path)
spkr_group_id = self.spkr_group_symmap[spkr_group]
#spkr_group_id = self.spkr_group_symmap[spkr_group]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
@ -641,9 +645,9 @@ class Dataset(_Dataset):
self.training = value
def __len__(self):
    """Return how many indices this dataset exposes for its sampler type.

    The count follows ``self.sampler_type`` (the per-dataset setting) rather
    than the global ``cfg.dataset.sample_type``:

    * ``"group"``   -> number of entries in ``self.spkr_groups``
    * ``"speaker"`` -> number of entries in ``self.spkrs``
    * anything else -> number of entries in ``self.paths``

    When ``self._head`` is truthy the result is capped at that value; a
    falsy ``_head`` (``None`` or ``0``) leaves the count uncapped.
    """
    if self.sampler_type == "group":
        total = len(self.spkr_groups)
    elif self.sampler_type == "speaker":
        total = len(self.spkrs)
    else:
        total = len(self.paths)

    # `_head or total` falls back to the full size when no head limit is set.
    return min(total, self._head or total)
@ -702,7 +706,7 @@ def create_train_val_dataloader():
train_dataset, val_dataset = create_datasets()
subtrain_dataset = copy.deepcopy(train_dataset)
if cfg.dataset.sample_type == "path":
if subtrain_dataset.sampler_type == "path":
subtrain_dataset.head_(cfg.evaluation.size)
train_dl = _create_dataloader(train_dataset, training=True)

@ -221,6 +221,8 @@ class Base(nn.Module):
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
self.text_emb = Embedding(n_tokens, d_model)
self.langs_emb = None
self.tasks_emb = None
if self.version == 1: # legacy
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
@ -232,9 +234,10 @@ class Base(nn.Module):
# [1025] + [1024] * 8
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
if self.version >= 3:
self.langs_emb = Embedding(self.n_langs, d_model)
self.tasks_emb = Embedding(self.n_tasks, d_model)
self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
self.sep = nn.Parameter(torch.randn(d_model))