misc
This commit is contained in:
parent
147478a148
commit
48cb6a5abd
|
@ -176,15 +176,15 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
if __name__ == '__main__':
|
||||
params = {
|
||||
'mode': 'unsupervised_audio',
|
||||
'path': ['\\\\192.168.5.3\\rtx3080_audio\\split\\cleaned\\books0'],
|
||||
'cache_path': 'E:\\audio\\remote-cache3.pth',
|
||||
'path': ['Y:\\split\\yt-music'],
|
||||
'cache_path': 'Y:\\split\\yt-music\\cache-windows.pth',
|
||||
'sampling_rate': 22050,
|
||||
'pad_to_samples': 40960,
|
||||
'pad_to_samples': 22050,
|
||||
'phase': 'train',
|
||||
'n_workers': 1,
|
||||
'batch_size': 16,
|
||||
'extra_samples': 4,
|
||||
'resample_clip': True,
|
||||
'resample_clip': False,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
|
||||
|
@ -195,5 +195,5 @@ if __name__ == '__main__':
|
|||
for b_ in range(b['clip'].shape[0]):
|
||||
#pass
|
||||
torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
|
||||
torchaudio.save(f'{i}_resampled_clip_{b_}.wav', b['resampled_clip'][b_], ds.sampling_rate)
|
||||
#torchaudio.save(f'{i}_resampled_clip_{b_}.wav', b['resampled_clip'][b_], ds.sampling_rate)
|
||||
i += 1
|
||||
|
|
|
@ -86,7 +86,6 @@ class CLVP(nn.Module):
|
|||
speech_enc_depth=6,
|
||||
speech_mask_percentage=0,
|
||||
latent_multiplier=4,
|
||||
is_distributed=False,
|
||||
):
|
||||
super().__init__()
|
||||
latent_dim = latent_multiplier*model_dim
|
||||
|
@ -102,8 +101,6 @@ class CLVP(nn.Module):
|
|||
self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
|
||||
self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
|
||||
self.distributed = is_distributed
|
||||
|
||||
if mel_codes is None:
|
||||
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
|
||||
else:
|
||||
|
@ -143,16 +140,6 @@ class CLVP(nn.Module):
|
|||
|
||||
text_latents = self.to_text_latent(enc_text)
|
||||
speech_latents = self.to_speech_latent(enc_speech)
|
||||
if self.distributed:
|
||||
ws = get_world_size()
|
||||
text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)]
|
||||
speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)]
|
||||
distributed.all_gather(text_gather_cells, text_latents)
|
||||
text_gather_cells[distributed.get_rank()] = text_latents # Propagate gradients in this way.
|
||||
text_latents = torch.cat(text_gather_cells, dim=0)
|
||||
distributed.all_gather(speech_gather_cells, speech_latents)
|
||||
speech_gather_cells[distributed.get_rank()] = speech_latents
|
||||
speech_latents = torch.cat(speech_gather_cells, dim=0)
|
||||
|
||||
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||
temp = self.temperature.exp()
|
||||
|
|
|
@ -327,7 +327,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cvvp_codes.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user