James Betker 2022-04-16 20:28:04 -06:00
parent 147478a148
commit 48cb6a5abd
3 changed files with 6 additions and 19 deletions

View File

@@ -176,15 +176,15 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
 if __name__ == '__main__':
     params = {
         'mode': 'unsupervised_audio',
-        'path': ['\\\\192.168.5.3\\rtx3080_audio\\split\\cleaned\\books0'],
-        'cache_path': 'E:\\audio\\remote-cache3.pth',
+        'path': ['Y:\\split\\yt-music'],
+        'cache_path': 'Y:\\split\\yt-music\\cache-windows.pth',
         'sampling_rate': 22050,
-        'pad_to_samples': 40960,
+        'pad_to_samples': 22050,
         'phase': 'train',
         'n_workers': 1,
         'batch_size': 16,
         'extra_samples': 4,
-        'resample_clip': True,
+        'resample_clip': False,
     }
     from data import create_dataset, create_dataloader
@@ -195,5 +195,5 @@ if __name__ == '__main__':
         for b_ in range(b['clip'].shape[0]):
             #pass
             torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
-            torchaudio.save(f'{i}_resampled_clip_{b_}.wav', b['resampled_clip'][b_], ds.sampling_rate)
+            #torchaudio.save(f'{i}_resampled_clip_{b_}.wav', b['resampled_clip'][b_], ds.sampling_rate)
         i += 1
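
Note on the parameter changes above: the new pad_to_samples value of 22050 equals exactly one second of audio at the 22050 Hz sampling rate, and with resample_clip now False there is presumably no 'resampled_clip' key in the batch, which would explain why the second torchaudio.save call in the hunk above is commented out. As a minimal, hedged sketch (not the dataset's actual implementation), fixed-length padding of this kind typically looks like:

import torch
import torch.nn.functional as F

def pad_or_truncate(clip: torch.Tensor, n_samples: int = 22050) -> torch.Tensor:
    # Illustrative only: force a clip to exactly n_samples along its last axis;
    # 22050 samples is one second at a 22050 Hz sampling rate.
    if clip.shape[-1] >= n_samples:
        return clip[..., :n_samples]
    return F.pad(clip, (0, n_samples - clip.shape[-1]))  # right-pad with zeros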

View File

@@ -86,7 +86,6 @@ class CLVP(nn.Module):
             speech_enc_depth=6,
             speech_mask_percentage=0,
             latent_multiplier=4,
-            is_distributed=False,
     ):
         super().__init__()
         latent_dim = latent_multiplier*model_dim
@@ -102,8 +101,6 @@ class CLVP(nn.Module):
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
         self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)

-        self.distributed = is_distributed
-
         if mel_codes is None:
             self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
         else:
@@ -143,16 +140,6 @@ class CLVP(nn.Module):
         text_latents = self.to_text_latent(enc_text)
         speech_latents = self.to_speech_latent(enc_speech)

-        if self.distributed:
-            ws = get_world_size()
-            text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)]
-            speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)]
-            distributed.all_gather(text_gather_cells, text_latents)
-            text_gather_cells[distributed.get_rank()] = text_latents  # Propagate gradients in this way.
-            text_latents = torch.cat(text_gather_cells, dim=0)
-            distributed.all_gather(speech_gather_cells, speech_latents)
-            speech_gather_cells[distributed.get_rank()] = speech_latents
-            speech_latents = torch.cat(speech_gather_cells, dim=0)
-
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))

         temp = self.temperature.exp()
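
For context on the block deleted above: torch.distributed.all_gather returns tensors that are detached from the autograd graph, so the code wrote the local rank's own (graph-attached) latents back into the gathered list before concatenating, letting gradients flow for this rank's slice of the cross-rank contrastive batch. A self-contained sketch of that pattern, assuming an initialized default process group:

import torch
import torch.distributed as distributed

def gather_latents_with_local_grad(latents: torch.Tensor) -> torch.Tensor:
    # all_gather fills the list with detached copies from every rank...
    ws = distributed.get_world_size()
    cells = [torch.zeros_like(latents) for _ in range(ws)]
    distributed.all_gather(cells, latents)
    # ...so overwrite this rank's slot with the original, autograd-tracked tensor.
    cells[distributed.get_rank()] = latents
    return torch.cat(cells, dim=0)  # one large batch for the contrastive loss

With the branch removed, each rank now computes the contrastive loss over its local batch only.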

View File

@@ -327,7 +327,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cvvp_codes.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
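
Since the only change to this file is the argparse default, it affects bare invocations of the trainer only; an explicit -opt flag still takes precedence. A quick runnable demonstration of that argparse behavior (illustrative, independent of the trainer itself):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.',
                    default='../options/train_cvvp_codes.yml')
# No flag given: the new CVVP config is picked up by default.
assert parser.parse_args([]).opt == '../options/train_cvvp_codes.yml'
# Explicit flag: the previous CLIP text-to-voice config is still reachable.
args = parser.parse_args(['-opt', '../options/train_clip_text_to_voice.yml'])
assert args.opt == '../options/train_clip_text_to_voice.yml'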