make validation samplers ignore sampler type
This commit is contained in:
parent
32d4271ca8
commit
9a6040383e
|
@ -188,6 +188,7 @@ class Dataset(_Dataset):
|
|||
self.training = training
|
||||
self.dataset_type = "training" if self.training else "validation"
|
||||
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
||||
self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path"
|
||||
|
||||
# to-do: do not do validation if there's nothing in the validation
|
||||
# this just makes it be happy
|
||||
|
@ -214,6 +215,9 @@ class Dataset(_Dataset):
|
|||
spkr = cfg.get_spkr( data_dir / "dummy" )
|
||||
spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
|
||||
|
||||
if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
|
||||
continue
|
||||
|
||||
if spkr_group not in self.spkrs_by_spkr_group:
|
||||
self.spkrs_by_spkr_group[spkr_group] = []
|
||||
|
||||
|
@ -223,7 +227,7 @@ class Dataset(_Dataset):
|
|||
|
||||
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
||||
|
||||
if cfg.dataset.sample_type == "path":
|
||||
if self.sampler_type == "path":
|
||||
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||
|
||||
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
|
||||
|
@ -379,24 +383,24 @@ class Dataset(_Dataset):
|
|||
return prom
|
||||
|
||||
def __getitem__(self, index):
|
||||
if cfg.dataset.sample_type == "group":
|
||||
if self.sampler_type == "group":
|
||||
spkr_group = self.spkr_groups[index]
|
||||
spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
spkr_name = self.spkr_samplers[spkr_group].sample()
|
||||
spkr_id = self.spkr_symmap[spkr_name]
|
||||
path = self.samplers[spkr_name].sample()
|
||||
elif cfg.dataset.sample_type == "speaker":
|
||||
elif self.sampler_type == "speaker":
|
||||
spkr_name = self.spkrs[index]
|
||||
spkr_id = self.spkr_symmap[spkr_name]
|
||||
path = self.samplers[spkr_name].sample()
|
||||
spkr_group = self.get_speaker_group(path)
|
||||
spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
else:
|
||||
path = self.paths[index]
|
||||
spkr_name = self.get_speaker(path)
|
||||
spkr_id = self.spkr_symmap[spkr_name]
|
||||
spkr_group = self.get_speaker_group(path)
|
||||
spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
|
@ -641,9 +645,9 @@ class Dataset(_Dataset):
|
|||
self.training = value
|
||||
|
||||
def __len__(self):
|
||||
if cfg.dataset.sample_type == "group":
|
||||
if self.sampler_type == "group":
|
||||
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
|
||||
if cfg.dataset.sample_type == "speaker":
|
||||
if self.sampler_type == "speaker":
|
||||
return min(len(self.spkrs), self._head or len(self.spkrs))
|
||||
return min(len(self.paths), self._head or len(self.paths))
|
||||
|
||||
|
@ -702,7 +706,7 @@ def create_train_val_dataloader():
|
|||
train_dataset, val_dataset = create_datasets()
|
||||
|
||||
subtrain_dataset = copy.deepcopy(train_dataset)
|
||||
if cfg.dataset.sample_type == "path":
|
||||
if subtrain_dataset.sampler_type == "path":
|
||||
subtrain_dataset.head_(cfg.evaluation.size)
|
||||
|
||||
train_dl = _create_dataloader(train_dataset, training=True)
|
||||
|
|
|
@ -221,6 +221,8 @@ class Base(nn.Module):
|
|||
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
|
||||
|
||||
self.text_emb = Embedding(n_tokens, d_model)
|
||||
self.langs_emb = None
|
||||
self.tasks_emb = None
|
||||
|
||||
if self.version == 1: # legacy
|
||||
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
|
||||
|
@ -232,9 +234,10 @@ class Base(nn.Module):
|
|||
# [1025] + [1024] * 8
|
||||
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
|
||||
|
||||
|
||||
if self.version >= 3:
|
||||
self.langs_emb = Embedding(self.n_langs, d_model)
|
||||
self.tasks_emb = Embedding(self.n_tasks, d_model)
|
||||
self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
|
||||
self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
|
||||
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user