Fix two scripts

This commit is contained in:
James Betker 2021-10-30 17:00:06 -06:00
parent df45a9dec2
commit 36ed28913a
2 changed files with 5 additions and 6 deletions

View File

@ -71,7 +71,7 @@ if __name__ == "__main__":
for k, lbl in enumerate(lbls):
lbl = labels[torch.argmax(lbl, dim=0)]
src_path = data[path_key][k]
output_file.write(f'{src_path}\t{lbl}')
output_file.write(f'{src_path}\t{lbl}\n')
if output_base_dir is not None:
dest = os.path.join(output_base_dir, lbl)
os.makedirs(dest, exist_ok=True)

View File

@ -14,6 +14,7 @@ from torchvision.transforms import ToTensor
import utils
import utils.options as option
import utils.util as util
from data.audio.unsupervised_audio_dataset import load_audio
from models.tacotron2.taco_utils import load_wav_to_torch
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
@ -48,10 +49,8 @@ def forward_pass(model, data, output_dir, spacing, audio_mode):
def load_image(path, audio_mode):
# Load test image
if audio_mode:
im, sr = load_wav_to_torch(path)
assert sr == 22050
im = im.unsqueeze(0)
im = im[:, :(im.shape[1]//4096)*4096]
im = load_audio(path, 22050)
im = im[:, :(im.shape[1]//4096)*4096].unsqueeze(0)
else:
im = ToTensor()(Image.open(path)) * 2 - 1
_, h, w = im.shape
@ -113,7 +112,7 @@ if __name__ == "__main__":
if audio_mode:
data = {
'clip': im.to('cuda'),
'alt_clips': refs.to('cuda'),
'alt_clips': torch.zeros_like(refs[:,0].to('cuda')),
'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'),
'GT_path': opt['image']
}