forked from mrq/DL-Art-School
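Fix two scripts: terminate each written "path<TAB>label" record with a newline, and load audio with load_audio (22050 Hz) instead of load_wav_to_torch while passing a zeroed alt_clips tensor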
This commit is contained in:
parent
df45a9dec2
commit
36ed28913a
|
@ -71,7 +71,7 @@ if __name__ == "__main__":
|
|||
for k, lbl in enumerate(lbls):
|
||||
lbl = labels[torch.argmax(lbl, dim=0)]
|
||||
src_path = data[path_key][k]
|
||||
output_file.write(f'{src_path}\t{lbl}')
|
||||
output_file.write(f'{src_path}\t{lbl}\n')
|
||||
if output_base_dir is not None:
|
||||
dest = os.path.join(output_base_dir, lbl)
|
||||
os.makedirs(dest, exist_ok=True)
|
||||
|
|
|
@ -14,6 +14,7 @@ from torchvision.transforms import ToTensor
|
|||
import utils
|
||||
import utils.options as option
|
||||
import utils.util as util
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from models.tacotron2.taco_utils import load_wav_to_torch
|
||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||
from data import create_dataset, create_dataloader
|
||||
|
@ -48,10 +49,8 @@ def forward_pass(model, data, output_dir, spacing, audio_mode):
|
|||
def load_image(path, audio_mode):
|
||||
# Load test image
|
||||
if audio_mode:
|
||||
im, sr = load_wav_to_torch(path)
|
||||
assert sr == 22050
|
||||
im = im.unsqueeze(0)
|
||||
im = im[:, :(im.shape[1]//4096)*4096]
|
||||
im = load_audio(path, 22050)
|
||||
im = im[:, :(im.shape[1]//4096)*4096].unsqueeze(0)
|
||||
else:
|
||||
im = ToTensor()(Image.open(path)) * 2 - 1
|
||||
_, h, w = im.shape
|
||||
|
@ -113,7 +112,7 @@ if __name__ == "__main__":
|
|||
if audio_mode:
|
||||
data = {
|
||||
'clip': im.to('cuda'),
|
||||
'alt_clips': refs.to('cuda'),
|
||||
'alt_clips': torch.zeros_like(refs[:,0].to('cuda')),
|
||||
'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'),
|
||||
'GT_path': opt['image']
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user