Fix two scripts
This commit is contained in:
parent
df45a9dec2
commit
36ed28913a
|
@ -71,7 +71,7 @@ if __name__ == "__main__":
|
||||||
for k, lbl in enumerate(lbls):
|
for k, lbl in enumerate(lbls):
|
||||||
lbl = labels[torch.argmax(lbl, dim=0)]
|
lbl = labels[torch.argmax(lbl, dim=0)]
|
||||||
src_path = data[path_key][k]
|
src_path = data[path_key][k]
|
||||||
output_file.write(f'{src_path}\t{lbl}')
|
output_file.write(f'{src_path}\t{lbl}\n')
|
||||||
if output_base_dir is not None:
|
if output_base_dir is not None:
|
||||||
dest = os.path.join(output_base_dir, lbl)
|
dest = os.path.join(output_base_dir, lbl)
|
||||||
os.makedirs(dest, exist_ok=True)
|
os.makedirs(dest, exist_ok=True)
|
||||||
|
|
|
@ -14,6 +14,7 @@ from torchvision.transforms import ToTensor
|
||||||
import utils
|
import utils
|
||||||
import utils.options as option
|
import utils.options as option
|
||||||
import utils.util as util
|
import utils.util as util
|
||||||
|
from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from models.tacotron2.taco_utils import load_wav_to_torch
|
from models.tacotron2.taco_utils import load_wav_to_torch
|
||||||
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
from trainer.ExtensibleTrainer import ExtensibleTrainer
|
||||||
from data import create_dataset, create_dataloader
|
from data import create_dataset, create_dataloader
|
||||||
|
@ -48,10 +49,8 @@ def forward_pass(model, data, output_dir, spacing, audio_mode):
|
||||||
def load_image(path, audio_mode):
|
def load_image(path, audio_mode):
|
||||||
# Load test image
|
# Load test image
|
||||||
if audio_mode:
|
if audio_mode:
|
||||||
im, sr = load_wav_to_torch(path)
|
im = load_audio(path, 22050)
|
||||||
assert sr == 22050
|
im = im[:, :(im.shape[1]//4096)*4096].unsqueeze(0)
|
||||||
im = im.unsqueeze(0)
|
|
||||||
im = im[:, :(im.shape[1]//4096)*4096]
|
|
||||||
else:
|
else:
|
||||||
im = ToTensor()(Image.open(path)) * 2 - 1
|
im = ToTensor()(Image.open(path)) * 2 - 1
|
||||||
_, h, w = im.shape
|
_, h, w = im.shape
|
||||||
|
@ -113,7 +112,7 @@ if __name__ == "__main__":
|
||||||
if audio_mode:
|
if audio_mode:
|
||||||
data = {
|
data = {
|
||||||
'clip': im.to('cuda'),
|
'clip': im.to('cuda'),
|
||||||
'alt_clips': refs.to('cuda'),
|
'alt_clips': torch.zeros_like(refs[:,0].to('cuda')),
|
||||||
'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'),
|
'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'),
|
||||||
'GT_path': opt['image']
|
'GT_path': opt['image']
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user