41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import os
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from data.util import is_wav_file, get_image_paths
|
|
from models.audio_resnet import resnet34
|
|
from models.tacotron2.taco_utils import load_wav_to_torch
|
|
from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict
|
|
|
|
def prepare_clip(clip, window):
    """Turn one loaded waveform into a fixed-length, normalized model input.

    Args:
        clip: 1-D tensor of samples, or a 2-D tensor whose first column is
            taken as the signal (assumes a (samples, channels) layout --
            TODO confirm against load_wav_to_torch).
        window: Number of samples every returned clip must contain.

    Returns:
        Tensor of shape (1, window) holding the first `window` samples
        divided by 32768.0, zero-padded on the right when the source is
        shorter than `window`.
    """
    if clip.dim() > 1:
        clip = clip[:, 0]
    clip = clip[:window]
    shortfall = window - clip.shape[0]
    if shortfall > 0:
        # torch.stack requires equal lengths, so pad short clips with silence
        # instead of failing later with an opaque size-mismatch error.
        clip = F.pad(clip, (0, shortfall))
    # Scale by the int16 full-scale value (presumes 16-bit PCM sources).
    return clip.unsqueeze(0) / 32768.0


def main():
    """Print the pairwise embedding MSE between all wav clips in a folder.

    Loads every wav file found under `root_path`, embeds the clips with a
    resnet34 whose weights are extracted from a BYOL checkpoint, then for
    each clip prints its MSE distance to every other clip.

    Raises:
        SystemExit: If no wav files are found.
        ValueError: If a clip's sample rate is not the expected one.
    """
    window = 48000
    expected_sr = 24000
    root_path = 'D:\\tmp\\clips'
    checkpoint = '../experiments/train_byol_audio_clips/models/66000_generator.pth'

    paths = get_image_paths('img', root_path, qualifier=is_wav_file)[0]
    if len(paths) == 0:
        raise SystemExit(f'No wav files found under {root_path}')

    clips = []
    for path in paths:
        clip, sr = load_wav_to_torch(os.path.join(root_path, path))
        # Validate before processing; a plain assert would vanish under -O.
        if sr != expected_sr:
            raise ValueError(
                f'{path}: expected a {expected_sr} Hz clip, got {sr} Hz')
        clips.append(prepare_clip(clip, window))
    clips = torch.stack(clips, dim=0)

    resnet = resnet34()
    # The script runs on CPU, so remap tensors saved from a GPU run.
    sd = torch.load(checkpoint, map_location='cpu')
    sd = extract_byol_model_from_state_dict(sd)
    resnet.load_state_dict(sd)
    # Inference only: eval() keeps normalization/dropout layers from making
    # one clip's embedding depend on the rest of the batch, and no_grad()
    # skips building an autograd graph that is never used.
    resnet.eval()
    with torch.no_grad():
        embedding = resnet(clips, return_pool=True)

    for i, path in enumerate(paths):
        print(f'Using a baseline of {path}..')
        for j, cpath in enumerate(paths):
            if i == j:
                continue
            l2 = F.mse_loss(embedding[j], embedding[i])
            print(f'Compared to {cpath}: {l2}')


if __name__ == '__main__':
    main()
|
|
|