forked from mrq/DL-Art-School

misc

commit 48cb6a5abd (parent 147478a148), changes under codes/
@@ -176,15 +176,15 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
 if __name__ == '__main__':
     params = {
         'mode': 'unsupervised_audio',
-        'path': ['\\\\192.168.5.3\\rtx3080_audio\\split\\cleaned\\books0'],
-        'cache_path': 'E:\\audio\\remote-cache3.pth',
+        'path': ['Y:\\split\\yt-music'],
+        'cache_path': 'Y:\\split\\yt-music\\cache-windows.pth',
         'sampling_rate': 22050,
-        'pad_to_samples': 40960,
+        'pad_to_samples': 22050,
         'phase': 'train',
         'n_workers': 1,
         'batch_size': 16,
         'extra_samples': 4,
-        'resample_clip': True,
+        'resample_clip': False,
     }
     from data import create_dataset, create_dataloader
 
@@ -195,5 +195,5 @@ if __name__ == '__main__':
         for b_ in range(b['clip'].shape[0]):
             #pass
             torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
-            torchaudio.save(f'{i}_resampled_clip_{b_}.wav', b['resampled_clip'][b_], ds.sampling_rate)
+            #torchaudio.save(f'{i}_resampled_clip_{b_}.wav', b['resampled_clip'][b_], ds.sampling_rate)
         i += 1
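Note that 'pad_to_samples': 22050 now matches the 22050 Hz sampling rate, i.e. clips are padded or trimmed to exactly one second, and with 'resample_clip': False the batch no longer carries a 'resampled_clip' key, which is why the second torchaudio.save call is commented out in the hunk above. For reference, a minimal sketch of how the params dict feeds the smoke test, assuming DLAS's create_dataset(opt) / create_dataloader(dataset, opt) factory signatures (the tqdm wrapper is an assumption, not shown in this diff):

    import torchaudio
    from tqdm import tqdm
    from data import create_dataset, create_dataloader

    ds = create_dataset(params)         # builds an UnsupervisedAudioDataset from the dict above
    dl = create_dataloader(ds, params)  # batch_size / n_workers come from the same dict
    i = 0
    for b in tqdm(dl):
        for b_ in range(b['clip'].shape[0]):
            # write each one-second clip out for manual inspection
            torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
        i += 1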
@@ -86,7 +86,6 @@ class CLVP(nn.Module):
                  speech_enc_depth=6,
                  speech_mask_percentage=0,
                  latent_multiplier=4,
-                 is_distributed=False,
                  ):
         super().__init__()
         latent_dim = latent_multiplier*model_dim
@@ -102,8 +101,6 @@ class CLVP(nn.Module):
         self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True)
         self.to_text_latent = nn.Linear(latent_dim, latent_dim, bias=False)
 
-        self.distributed = is_distributed
-
         if mel_codes is None:
             self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
         else:
@@ -143,16 +140,6 @@ class CLVP(nn.Module):
 
         text_latents = self.to_text_latent(enc_text)
         speech_latents = self.to_speech_latent(enc_speech)
-        if self.distributed:
-            ws = get_world_size()
-            text_gather_cells = [torch.zeros_like(text_latents) for _ in range(ws)]
-            speech_gather_cells = [torch.zeros_like(speech_latents) for _ in range(ws)]
-            distributed.all_gather(text_gather_cells, text_latents)
-            text_gather_cells[distributed.get_rank()] = text_latents  # Propagate gradients in this way.
-            text_latents = torch.cat(text_gather_cells, dim=0)
-            distributed.all_gather(speech_gather_cells, speech_latents)
-            speech_gather_cells[distributed.get_rank()] = speech_latents
-            speech_latents = torch.cat(speech_gather_cells, dim=0)
 
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
         temp = self.temperature.exp()
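The deleted branch was the standard all_gather trick for contrastive training under DistributedDataParallel: each rank collects the latents from every other rank so it can contrast its batch against the full set of negatives, and because torch.distributed.all_gather returns detached buffers, the local rank's slot is overwritten with the original tensor so gradients still flow for the samples that rank owns (the "Propagate gradients in this way" comment). A standalone sketch of the pattern, with a hypothetical helper name and usage lines that are not part of the repo:

    import torch
    import torch.distributed as distributed

    def gather_with_grad(latents):
        # Requires an initialized process group (e.g. init_process_group('nccl')).
        ws = distributed.get_world_size()
        cells = [torch.zeros_like(latents) for _ in range(ws)]
        distributed.all_gather(cells, latents)    # fills cells with detached copies
        cells[distributed.get_rank()] = latents   # re-attach the local gradient path
        return torch.cat(cells, dim=0)

    # hypothetical usage mirroring the removed code:
    # text_latents = gather_with_grad(text_latents)
    # speech_latents = gather_with_grad(speech_latents)

With the branch removed, is_distributed and self.distributed go away with it, and each rank now computes the CLIP-style similarity (normalized latents scaled by temp) over its local batch only.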
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_clip_text_to_voice.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cvvp_codes.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)