PIP-ified (credit to https://git.ecker.tech/eschmidbauer)

parent fe24641763
commit a4afad8837

.gitignore (vendored): 72 changes
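Note: the diffs below follow one pattern throughout. The sources are packaged under a `dlas` module, so repository-root imports of `data.*`, `models.*`, and `utils.*` are rewritten as `dlas.data.*`, `dlas.models.*`, and `dlas.utils.*`, and long statements are re-wrapped in autopep8/isort style. A representative before/after pair, taken directly from the import changes shown below:

# Before: repository-root imports
from utils.util import opt_get
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset, load_audio

# After: package-qualified imports under the new `dlas` package
from dlas.utils.util import opt_get
from dlas.data.audio.unsupervised_audio_dataset import (
    UnsupervisedAudioDataset, load_audio)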
@@ -1,30 +1,3 @@
dlas/experiments/*
dlas/codes/*.txt
dlas/codes/wandb/*
dlas/codes/pretrained_models/*
dlas/codes/scripts/audio/pretrained_models/*

results/*
tb_logger/*
datasets/*
options/*
data/*
.vscode

*.html
*.png
*.jpg
*.gif
*.pth
*.pytorch
*.zip
*.cu
*.pt
*.pth
*.pdf
*.tsv

# template

# Byte-compiled / optimized / DLL files
__pycache__/

@@ -36,6 +9,7 @@ __pycache__/
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/

@@ -47,12 +21,9 @@ lib64/
parts/
sdist/
var/
wheels/
pretrained/*
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template

@@ -72,9 +43,8 @@ htmlcov/
.cache
nosetests.xml
coverage.xml
*.cover
*,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo

@@ -83,14 +53,6 @@ coverage.xml
# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

@@ -98,36 +60,8 @@ docs/_build/
# PyBuilder
target/

# Jupyter Notebook
#Ipython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

@@ -1 +1 @@
recursive-include codes/*
recursive-include dlas/*

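With the modules moved under `dlas/` and the manifest now recursively including `dlas/*`, the repository is meant to be installed as a regular pip package. The packaging metadata itself (setup.py or pyproject.toml) is not part of this excerpt, so the following consumer-side sketch is an assumption: it presumes an editable install from the repository root (for example `pip install -e .`), after which the package-qualified imports used throughout the diffs resolve normally.

# Hypothetical consumer code, assuming the package is installed in the active environment.
from dlas.data import create_dataloader, create_dataset
from dlas.utils.util import opt_get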
@@ -3,7 +3,7 @@ import torch
import torch.utils.data
from munch import munchify

from utils.util import opt_get
from dlas.utils.util import opt_get


def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None, shuffle=True):

@@ -33,77 +33,90 @@ def create_dataset(dataset_opt, return_collate=False):

# datasets for image restoration
if mode == 'fullimage':
from data.images.full_image_dataset import FullImageDataset as D
from dlas.data.images.full_image_dataset import FullImageDataset as D
elif mode == 'single_image_extensible':
from data.images.single_image_dataset import SingleImageDataset as D
from dlas.data.images.single_image_dataset import \
    SingleImageDataset as D
elif mode == 'multi_frame_extensible':
from data.images.multi_frame_dataset import MultiFrameDataset as D
from dlas.data.images.multi_frame_dataset import MultiFrameDataset as D
elif mode == 'combined':
from data.combined_dataset import CombinedDataset as D
from dlas.data.combined_dataset import CombinedDataset as D
elif mode == 'multiscale':
from data.images.multiscale_dataset import MultiScaleDataset as D
from dlas.data.images.multiscale_dataset import MultiScaleDataset as D
elif mode == 'paired_frame':
from data.images.paired_frame_dataset import PairedFrameDataset as D
from dlas.data.images.paired_frame_dataset import \
    PairedFrameDataset as D
elif mode == 'stylegan2':
from data.images.stylegan2_dataset import Stylegan2Dataset as D
from dlas.data.images.stylegan2_dataset import Stylegan2Dataset as D
elif mode == 'imagefolder':
from data.images.image_folder_dataset import ImageFolderDataset as D
from dlas.data.images.image_folder_dataset import \
    ImageFolderDataset as D
elif mode == 'torch_dataset':
from data.torch_dataset import TorchDataset as D
elif mode == 'byol_dataset':
from data.images.byol_attachment import ByolDatasetWrapper as D
from dlas.data.images.byol_attachment import ByolDatasetWrapper as D
elif mode == 'byol_structured_dataset':
from data.images.byol_attachment import StructuredCropDatasetWrapper as D
from dlas.data.images.byol_attachment import \
    StructuredCropDatasetWrapper as D
elif mode == 'random_aug_wrapper':
from data.images.byol_attachment import DatasetRandomAugWrapper as D
from dlas.data.images.byol_attachment import \
    DatasetRandomAugWrapper as D
elif mode == 'random_dataset':
from data.images.random_dataset import RandomDataset as D
from dlas.data.images.random_dataset import RandomDataset as D
elif mode == 'zipfile':
from data.images.zip_file_dataset import ZipFileDataset as D
from dlas.data.images.zip_file_dataset import ZipFileDataset as D
elif mode == 'nv_tacotron':
from data.audio.nv_tacotron_dataset import TextWavLoader as D
from data.audio.nv_tacotron_dataset import TextMelCollate as C
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.nv_tacotron_dataset import TextMelCollate as C
from dlas.data.audio.nv_tacotron_dataset import TextWavLoader as D
from dlas.models.audio.tts.tacotron2 import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
if opt_get(dataset_opt, ['needs_collate'], True):
collate = C()
elif mode == 'paired_voice_audio':
from data.audio.paired_voice_audio_dataset import TextWavLoader as D
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.paired_voice_audio_dataset import \
    TextWavLoader as D
from dlas.models.audio.tts.tacotron2 import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
elif mode == 'fast_paired_voice_audio':
from data.audio.fast_paired_dataset import FastPairedVoiceDataset as D
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.fast_paired_dataset import \
    FastPairedVoiceDataset as D
from dlas.models.audio.tts.tacotron2 import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
elif mode == 'fast_paired_voice_audio_with_phonemes':
from data.audio.fast_paired_dataset_with_phonemes import FastPairedVoiceDataset as D
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.fast_paired_dataset_with_phonemes import \
    FastPairedVoiceDataset as D
from dlas.models.audio.tts.tacotron2 import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
elif mode == 'gpt_tts':
from data.audio.gpt_tts_dataset import GptTtsDataset as D
from data.audio.gpt_tts_dataset import GptTtsCollater as C
from dlas.data.audio.gpt_tts_dataset import GptTtsCollater as C
from dlas.data.audio.gpt_tts_dataset import GptTtsDataset as D
collate = C(dataset_opt)
elif mode == 'unsupervised_audio':
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset as D
from dlas.data.audio.unsupervised_audio_dataset import \
    UnsupervisedAudioDataset as D
elif mode == 'unsupervised_audio_with_noise':
from data.audio.audio_with_noise_dataset import AudioWithNoiseDataset as D
from dlas.data.audio.audio_with_noise_dataset import \
    AudioWithNoiseDataset as D
elif mode == 'preprocessed_mel':
from data.audio.preprocessed_mel_dataset import PreprocessedMelDataset as D
from dlas.data.audio.preprocessed_mel_dataset import \
    PreprocessedMelDataset as D
elif mode == 'grand_conjoined_voice':
from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D
from data.zero_pad_dict_collate import ZeroPadDictCollate as C
from dlas.data.audio.grand_conjoined_dataset import \
    GrandConjoinedDataset as D
from dlas.data.zero_pad_dict_collate import ZeroPadDictCollate as C
if opt_get(dataset_opt, ['needs_collate'], False):
collate = C()
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
raise NotImplementedError(
    'Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)

if return_collate:

@@ -115,9 +128,10 @@ def create_dataset(dataset_opt, return_collate=False):
def get_dataset_debugger(dataset_opt):
mode = dataset_opt['mode']
if mode == 'paired_voice_audio':
from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
from dlas.data.audio.paired_voice_audio_dataset import \
    PairedVoiceDebugger
return PairedVoiceDebugger()
elif mode == 'fast_paired_voice_audio':
from data.audio.fast_paired_dataset import FastPairedVoiceDebugger
from dlas.data.audio.fast_paired_dataset import FastPairedVoiceDebugger
return FastPairedVoiceDebugger()
return None
return None

dlas/data/audio/__init__.py: new file (0 lines)
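The dispatch in create_dataset above picks a Dataset class from the 'mode' key of the options dict and can also return a matching collate function. A minimal usage sketch, assuming the dlas package is importable; the option keys mirror the __main__ test blocks later in this commit, and the values are placeholders (each mode defines its own required keys):

from dlas.data import create_dataloader, create_dataset

opt = {
    'mode': 'preprocessed_mel',   # any mode string handled by create_dataset
    'path': '/path/to/data',      # placeholder; per-mode keys vary
    'phase': 'train',
    'n_workers': 0,
    'batch_size': 16,
}
ds, collate = create_dataset(opt, return_collate=True)
dl = create_dataloader(ds, opt, collate_fn=collate)
for batch in dl:
    pass  # training / inspection code goes here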
@@ -2,18 +2,17 @@ import random
import sys
from math import pi

import librosa
import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm

import torch.nn.functional as F
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset, load_audio
from data.util import load_paths_from_cache, find_files_of_type, is_audio_file

# Just all ones.
from utils.util import opt_get
from dlas.data.audio.unsupervised_audio_dataset import (
    UnsupervisedAudioDataset, load_audio)
from dlas.data.util import (find_files_of_type, is_audio_file,
    load_paths_from_cache)
from dlas.utils.util import opt_get


def _integration_fn_fully_enabled(n):

@@ -23,7 +22,7 @@ def _integration_fn_fully_enabled(n):
# Randomly assigns up to 5 blocks of the output tensor the value '1'. Rest is zero
def _integration_fn_spiky(n):
fn = torch.zeros((n,))
spikes = random.randint(1,5)
spikes = random.randint(1, 5)
for _ in range(spikes):
sz = random.randint(n//8, n//2)
pos = random.randint(0, n)

@@ -35,18 +34,19 @@ def _integration_fn_spiky(n):
# Uses a sinusoidal ramp up and down (of random length) to a peak which is held for a random duration.
def _integration_fn_smooth(n):
center = random.randint(1, n-2)
max_duration=n-center-1
max_duration = n-center-1
duration = random.randint(max_duration//4, max_duration)
end = center+duration

ramp_up_sz = random.randint(n//16,n//4)
ramp_up = torch.sin(pi*torch.arange(0,ramp_up_sz)/(2*ramp_up_sz))
ramp_up_sz = random.randint(n//16, n//4)
ramp_up = torch.sin(pi*torch.arange(0, ramp_up_sz)/(2*ramp_up_sz))
if ramp_up_sz > center:
ramp_up = ramp_up[(ramp_up_sz-center):]
ramp_up_sz = center

ramp_down_sz = random.randint(n//16,n//4)
ramp_down = torch.flip(torch.sin(pi*torch.arange(0,ramp_down_sz)/(2*ramp_down_sz)), dims=[0])
ramp_down_sz = random.randint(n//16, n//4)
ramp_down = torch.flip(
    torch.sin(pi*torch.arange(0, ramp_down_sz)/(2*ramp_down_sz)), dims=[0])
if ramp_down_sz > (n-end):
ramp_down = ramp_down[:(n-end)]
ramp_down_sz = n-end

@@ -71,16 +71,22 @@ def load_rir(path, sr, max_sz):
Wraps a unsupervised_audio_dataset and applies noise to the output clips, then provides labels depending on what
noise was added.
'''


class AudioWithNoiseDataset(Dataset):
def __init__(self, opt):
self.underlying_dataset = UnsupervisedAudioDataset(opt)
self.env_noise_paths = load_paths_from_cache(opt['env_noise_paths'], opt['env_noise_cache'])
self.music_paths = load_paths_from_cache(opt['music_paths'], opt['music_cache'])
self.openair_paths = find_files_of_type('img', opt['openair_path'], qualifier=is_audio_file)[0]
self.env_noise_paths = load_paths_from_cache(
    opt['env_noise_paths'], opt['env_noise_cache'])
self.music_paths = load_paths_from_cache(
    opt['music_paths'], opt['music_cache'])
self.openair_paths = find_files_of_type(
    'img', opt['openair_path'], qualifier=is_audio_file)[0]
self.min_volume = opt_get(opt, ['min_noise_volume'], .2)
self.max_volume = opt_get(opt, ['max_noise_volume'], .5)
self.sampling_rate = self.underlying_dataset.sampling_rate
self.use_gpu_for_reverb_compute = opt_get(opt, ['use_gpu_for_reverb_compute'], True)
self.use_gpu_for_reverb_compute = opt_get(
    opt, ['use_gpu_for_reverb_compute'], True)
self.openair_kernels = None
self.current_item_fetch = 0
self.fetch_error_count = 0

@@ -90,7 +96,8 @@ class AudioWithNoiseDataset(Dataset):
# Load the openair reverbs as CUDA tensors.
self.openair_kernels = []
for oa in self.openair_paths:
self.openair_kernels.append(load_rir(oa, self.underlying_dataset.sampling_rate, self.underlying_dataset.sampling_rate*2).cuda())
self.openair_kernels.append(load_rir(
    oa, self.underlying_dataset.sampling_rate, self.underlying_dataset.sampling_rate*2).cuda())

def __getitem__(self, item):
if self.current_item_fetch != item:

@@ -113,10 +120,11 @@ class AudioWithNoiseDataset(Dataset):
clip = clip * clipvol

label = random.randint(0, 4)  # Current excludes GSM corruption.
#label = 3
# label = 3
if label > 0 and label < 4:  # 0 is basically "leave it alone"
aug_needed = True
augvol = (random.random() * (self.max_volume-self.min_volume) + self.min_volume)
augvol = (random.random() * (self.max_volume -
    self.min_volume) + self.min_volume)
if label == 1:
# Add environmental noise.
augpath = random.choice(self.env_noise_paths)

@@ -131,13 +139,15 @@ class AudioWithNoiseDataset(Dataset):
# This can take two forms:
if padding_room < 22000 or random.random() < .5:
# (1) The voices talk over one another. If there is no padding room, we always take this choice.
intg_fns = [_integration_fn_smooth, _integration_fn_fully_enabled]
intg_fns = [_integration_fn_smooth,
    _integration_fn_fully_enabled]
else:
# (2) There are simply two voices in the clip, separated from one another.
# This is a special case that does not use the same logic as the rest of the augmentations.
aug = load_audio(augpath, self.underlying_dataset.sampling_rate)
aug = load_audio(
    augpath, self.underlying_dataset.sampling_rate)
# Pad with some random silence
aug = F.pad(aug, (random.randint(20,4000), 0))
aug = F.pad(aug, (random.randint(20, 4000), 0))
# Fit what we can given the padding room we have.
aug = aug[:, :padding_room]
clip = torch.cat([clip, aug], dim=1)

@@ -146,7 +156,8 @@ class AudioWithNoiseDataset(Dataset):
out['clip_lengths'] = torch.tensor(clip.shape[-1])
aug_needed = False
if aug_needed:
aug = load_audio(augpath, self.underlying_dataset.sampling_rate)
aug = load_audio(
    augpath, self.underlying_dataset.sampling_rate)
if aug.shape[1] > clip.shape[1]:
n, cn = aug.shape[1], clip.shape[1]
gap = n-cn

@@ -157,7 +168,8 @@ class AudioWithNoiseDataset(Dataset):
if aug.shape[1] < clip.shape[1]:
gap = clip.shape[1] - aug.shape[1]
placement = random.randint(0, gap-1)
aug = torch.nn.functional.pad(aug, (placement, gap-placement))
aug = torch.nn.functional.pad(
    aug, (placement, gap-placement))
clip = clip + aug
elif label == 4:
# Perform reverb (to simulate being in a large room with an omni-mic). This is performed by convolving

@@ -166,19 +178,23 @@ class AudioWithNoiseDataset(Dataset):
rir = random.choice(self.openair_kernels)
else:
augpath = random.choice(self.openair_paths)
rir = load_rir(augpath, self.underlying_dataset.sampling_rate, clip.shape[-1])
rir = load_rir(
    augpath, self.underlying_dataset.sampling_rate, clip.shape[-1])
clip = torch.nn.functional.pad(clip, (rir.shape[1]-1, 0))
if self.use_gpu_for_reverb_compute:
clip = clip.cuda()
clip = torch.nn.functional.conv1d(clip.unsqueeze(0), rir.unsqueeze(0)).squeeze(0).cpu()
clip = torch.nn.functional.conv1d(
    clip.unsqueeze(0), rir.unsqueeze(0)).squeeze(0).cpu()
elif label == 5:
# Apply the GSM codec to simulate cellular phone audio.
clip = torchaudio.functional.apply_codec(clip, self.underlying_dataset.sampling_rate, format="gsm")
clip = torchaudio.functional.apply_codec(
    clip, self.underlying_dataset.sampling_rate, format="gsm")
except:
if self.fetch_error_count > 10:
print(f"Exception encountered processing {item}, re-trying because this is often just a failed aug.")
print(
    f"Exception encountered processing {item}, re-trying because this is often just a failed aug.")
print(sys.exc_info())
#raise  # Uncomment to surface exceptions.
# raise  # Uncomment to surface exceptions.
self.fetch_error_count += 1
return self[item]

@@ -187,7 +203,7 @@ class AudioWithNoiseDataset(Dataset):
clip = F.pad(clip, (0, padding_room))
out['clip'] = clip
out['label'] = label
#out['aug'] = aug
# out['aug'] = aug
out['augpath'] = augpath
out['augvol'] = augvol
out['clipvol'] = clipvol

@@ -216,14 +232,15 @@ if __name__ == '__main__':
'openair_path': 'D:\\data\\audio\\openair\\resampled',
'use_gpu_for_reverb_compute': False,
}
from data import create_dataset, create_dataloader, util
from data import create_dataloader, create_dataset, util

ds = create_dataset(params)
dl = create_dataloader(ds, params, pin_memory=False)
i = 0
for b in tqdm(dl):
for b_ in range(b['clip'].shape[0]):
#torchaudio.save(f'{i}_clip_{b_}_{b["label"][b_].item()}.wav', b['clip'][b_][:, :b['clip_lengths'][b_]], ds.sampling_rate)
#torchaudio.save(f'{i}_clip_{b_}_aug.wav', b['aug'][b_], ds.sampling_rate)
print(f'{i} aug path: {b["augpath"][b_]} aug volume: {b["augvol"][b_]} clip volume: {b["clipvol"][b_]}')
# torchaudio.save(f'{i}_clip_{b_}_{b["label"][b_].item()}.wav', b['clip'][b_][:, :b['clip_lengths'][b_]], ds.sampling_rate)
# torchaudio.save(f'{i}_clip_{b_}_aug.wav', b['aug'][b_], ds.sampling_rate)
print(
    f'{i} aug path: {b["augpath"][b_]} aug volume: {b["augvol"][b_]} clip volume: {b["clipvol"][b_]}')
i += 1

@@ -12,13 +12,15 @@ import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2CTCTokenizer

from data.audio.paired_voice_audio_dataset import CharacterTokenizer
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from utils.util import opt_get
from dlas.data.audio.paired_voice_audio_dataset import CharacterTokenizer
from dlas.data.audio.unsupervised_audio_dataset import (load_audio,
    load_similar_clips)
from dlas.utils.util import opt_get


def parse_tsv_aligned_codes(line, base_path):
fpt = line.strip().split('\t')

def convert_string_list_to_tensor(strlist):
if strlist.startswith('['):
strlist = strlist[1:]

@@ -43,6 +45,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):

The upshot is that this dataset loads extremely quickly and consumes almost no system memory.
"""

def __init__(self, hparams):
self.paths = hparams['path']
if not isinstance(self.paths, list):

@@ -52,26 +55,33 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.types = opt_get(hparams, ['types'], [0 for _ in self.paths])

self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
self.conditioning_candidates = opt_get(
    hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(
    hparams, ['conditioning_length'], 44100)
self.produce_ctc_metadata = opt_get(
    hparams, ['produce_ctc_metadata'], False)
self.debug_failures = opt_get(
    hparams, ['debug_loading_failures'], False)
self.text_cleaners = hparams.text_cleaners
self.sample_rate = hparams.sample_rate
self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050
self.max_wav_len = opt_get(hparams, ['max_wav_length'], None)
self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
self.load_aligned_codes = opt_get(
    hparams, ['load_aligned_codes'], False)
if self.max_wav_len is not None:
self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio
self.max_text_len = opt_get(hparams, ['max_text_length'], None)
assert self.max_wav_len is not None and self.max_text_len is not None
self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], False)
if self.use_bpe_tokenizer:
from data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(
    hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else:
self.tokenizer = CharacterTokenizer()
self.skipped_items = 0  # records how many items are skipped when accessing an index.
# records how many items are skipped when accessing an index.
self.skipped_items = 0

self.load_times = torch.zeros((256,))
self.load_ind = 0

@@ -110,7 +120,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
try:  # This can fail when seeking to a UTF-8 escape byte.
f.readline()
except:
return self.load_random_line(depth=depth + 1), type  # On failure, just recurse and try again.
# On failure, just recurse and try again.
return self.load_random_line(depth=depth + 1), type
l2 = f.readline()

if l2:

@@ -119,14 +130,16 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return parse_tsv_aligned_codes(l2, base_path), type
except:
print(f"error parsing random offset: {sys.exc_info()}")
return self.load_random_line(depth=depth+1), type  # On failure, just recurse and try again.
# On failure, just recurse and try again.
return self.load_random_line(depth=depth+1), type

def get_ctc_metadata(self, codes):
grouped = groupby(codes.tolist())
rcodes, repeats, seps = [], [], [0]
for val, group in grouped:
if val == 0:
seps[-1] = len(list(group))  # This is a very important distinction! It means the padding belongs to the character proceeding it.
# This is a very important distinction! It means the padding belongs to the character proceeding it.
seps[-1] = len(list(group))
else:
rcodes.append(val)
repeats.append(len(list(group)))

@@ -142,7 +155,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if rcodes.shape[0] < self.max_text_len:
gap = self.max_text_len - rcodes.shape[0]
rcodes = F.pad(rcodes, (0, gap))
repeats = F.pad(repeats, (0, gap), value=1)  # The minimum value for repeats is 1, hence this is the pad value too.
# The minimum value for repeats is 1, hence this is the pad value too.
repeats = F.pad(repeats, (0, gap), value=1)
seps = F.pad(seps, (0, gap))
elif rcodes.shape[0] > self.max_text_len:
rcodes = rcodes[:self.max_text_len]

@@ -165,7 +179,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if text is None or len(text.strip()) == 0:
raise ValueError
cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
    n=self.conditioning_candidates) if self.load_conditioning else (None, False)
    n=self.conditioning_candidates) if self.load_conditioning else (None, False)
except:
if self.skipped_items > 100:
raise  # Rethrow if we have nested too far.

@@ -179,12 +193,13 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.skipped_items = 0
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
if self.debug_failures:
print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0,len(self)-1)
print(
    f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0, len(self)-1)
return self[rv]
orig_output = wav.shape[-1]
orig_text_len = tseq.shape[0]

@@ -192,7 +207,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if wav.shape[-1] != self.max_wav_len:
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
# These codes are aligned to audio inputs, so make sure to pad them as well.
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
aligned_codes = F.pad(
    aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
if tseq.shape[0] != self.max_text_len:
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))

@@ -223,7 +239,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return res

def __len__(self):
return self.total_size_bytes // 1000  # 1000 cuts down a TSV file to the actual length pretty well.
# 1000 cuts down a TSV file to the actual length pretty well.
return self.total_size_bytes // 1000


class FastPairedVoiceDebugger:

@@ -243,7 +260,8 @@ class FastPairedVoiceDebugger:
if isinstance(state, dict):
self.total_items = opt_get(state, ['total_items'], 0)
self.loaded_items = opt_get(state, ['loaded_items'], 0)
self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
self.self_conditioning_items = opt_get(
    state, ['self_conditioning_items'], 0)

def update(self, batch):
self.total_items += batch['wav'].shape[0]

@@ -252,7 +270,8 @@ class FastPairedVoiceDebugger:
for filename in batch['filenames']:
self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
if 'conditioning' in batch.keys():
self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
self.self_conditioning_items += batch['conditioning_contains_self'].sum(
    ).item()

def get_debugging_map(self):
return {

@@ -269,13 +288,13 @@ if __name__ == '__main__':
params = {
'mode': 'fast_paired_voice_audio',
'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
'y:/libritts/train-clean-100/transcribed-oco.tsv',
'y:/libritts/train-clean-360/transcribed-oco.tsv',
'y:/clips/books1/transcribed-oco.tsv',
'y:/clips/books2/transcribed-oco.tsv',
'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
'y:/clips/podcasts-1/transcribed-oco.tsv',],
'types': [0,1,1,1,2,2,0],
'y:/libritts/train-clean-100/transcribed-oco.tsv',
'y:/libritts/train-clean-360/transcribed-oco.tsv',
'y:/clips/books1/transcribed-oco.tsv',
'y:/clips/books2/transcribed-oco.tsv',
'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
'y:/clips/podcasts-1/transcribed-oco.tsv',],
'types': [0, 1, 1, 1, 2, 2, 0],
'phase': 'train',
'n_workers': 0,
'batch_size': batch_sz,

@@ -289,11 +308,12 @@ if __name__ == '__main__':
'load_aligned_codes': True,
'produce_ctc_metadata': True,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset

def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav',
    b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)

@@ -304,8 +324,8 @@ if __name__ == '__main__':
max_pads, max_repeats = 0, 0
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
#max_pads = max(max_pads, b['ctc_pads'].max())
#max_repeats = max(max_repeats, b['ctc_repeats'].max())
# max_pads = max(max_pads, b['ctc_pads'].max())
# max_repeats = max(max_repeats, b['ctc_repeats'].max())
print(f'{i} {ib} {b["real_text"][ib]}')
save(b, i, ib, 'wav')
save(b, i, ib, 'conditioning', 0)

@@ -314,4 +334,3 @@ if __name__ == '__main__':
if i > 15:
break
print(max_pads, max_repeats)

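FastPairedVoiceDataset above never builds an index of its TSV files: __len__ is only an estimate derived from total file size (total_size_bytes // 1000), and each fetch seeks to a random byte offset, discards the partial line it landed in, and parses the next full line, retrying on decode or parse failures. A condensed, self-contained sketch of that sampling strategy with a hypothetical helper name, not the exact implementation:

import os
import random

def sample_random_tsv_line(path, max_tries=10):
    """Roughly uniform random line from a large TSV without indexing it."""
    size = os.path.getsize(path)
    for _ in range(max_tries):
        with open(path, 'rb') as f:
            f.seek(random.randint(0, max(size - 2, 0)))
            f.readline()            # discard the (likely partial) line we landed in
            raw = f.readline()      # the line actually used
        if not raw:                 # seeked too close to EOF; retry
            continue
        try:
            return raw.decode('utf-8').strip().split('\t')
        except UnicodeDecodeError:  # mirrors the retry-on-failure in the dataset
            continue
    raise RuntimeError(f'could not sample a line from {path}')

Because the length is estimated and every item is an independent random draw, an "epoch" over this dataset is statistical rather than an exact pass over the data.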
@@ -12,13 +12,15 @@ import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2Processor

from data.audio.paired_voice_audio_dataset import CharacterTokenizer
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from utils.util import opt_get
from dlas.data.audio.paired_voice_audio_dataset import CharacterTokenizer
from dlas.data.audio.unsupervised_audio_dataset import (load_audio,
    load_similar_clips)
from dlas.utils.util import opt_get


def parse_tsv_aligned_codes(line, base_path):
fpt = line.strip().split('\t')

def convert_string_list_to_tensor(strlist):
if strlist.startswith('['):
strlist = strlist[1:]

@@ -43,10 +45,12 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):

The upshot is that this dataset loads extremely quickly and consumes almost no system memory.
"""

def __init__(self, hparams):
self.paths = hparams['path']
phoneme_paths = hparams['phoneme_paths']
self.paths = [(p, False) for p in self.paths] + [(p, True) for p in phoneme_paths]
self.paths = [(p, False) for p in self.paths] + [(p, True)
    for p in phoneme_paths]

self.paths_size_bytes = [os.path.getsize(p) for p, _ in self.paths]
self.total_size_bytes = sum(self.paths_size_bytes)

@@ -54,28 +58,36 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):

self.normal_text_end_token = hparams['normal_text_end_token']
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
self.conditioning_candidates = opt_get(
    hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(
    hparams, ['conditioning_length'], 44100)
self.produce_ctc_metadata = opt_get(
    hparams, ['produce_ctc_metadata'], False)
self.debug_failures = opt_get(
    hparams, ['debug_loading_failures'], False)
self.text_cleaners = hparams.text_cleaners
self.sample_rate = hparams.sample_rate
self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050
self.max_wav_len = opt_get(hparams, ['max_wav_length'], None)
self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
self.load_aligned_codes = opt_get(
    hparams, ['load_aligned_codes'], False)
if self.max_wav_len is not None:
self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio
self.max_text_len = opt_get(hparams, ['max_text_length'], None)
assert self.max_wav_len is not None and self.max_text_len is not None
self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], False)
if self.use_bpe_tokenizer:
from data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(
    hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else:
self.tokenizer = CharacterTokenizer()
self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained(
    "facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
self.ipa_phoneme_tokenizer.do_phonemize = False
self.skipped_items = 0  # records how many items are skipped when accessing an index.
# records how many items are skipped when accessing an index.
self.skipped_items = 0

self.load_times = torch.zeros((256,))
self.load_ind = 0

@@ -117,7 +129,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
try:  # This can fail when seeking to a UTF-8 escape byte.
f.readline()
except:
return self.load_random_line(depth=depth + 1)  # On failure, just recurse and try again.
# On failure, just recurse and try again.
return self.load_random_line(depth=depth + 1)
l2 = f.readline()

if l2:

@@ -126,14 +139,16 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return parse_tsv_aligned_codes(l2, base_path), type, is_phonetic
except:
print(f"error parsing random offset: {sys.exc_info()}")
return self.load_random_line(depth=depth+1)  # On failure, just recurse and try again.
# On failure, just recurse and try again.
return self.load_random_line(depth=depth+1)

def get_ctc_metadata(self, codes):
grouped = groupby(codes.tolist())
rcodes, repeats, seps = [], [], [0]
for val, group in grouped:
if val == 0:
seps[-1] = len(list(group))  # This is a very important distinction! It means the padding belongs to the character proceeding it.
# This is a very important distinction! It means the padding belongs to the character proceeding it.
seps[-1] = len(list(group))
else:
rcodes.append(val)
repeats.append(len(list(group)))

@@ -149,7 +164,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if rcodes.shape[0] < self.max_text_len:
gap = self.max_text_len - rcodes.shape[0]
rcodes = F.pad(rcodes, (0, gap))
repeats = F.pad(repeats, (0, gap), value=1)  # The minimum value for repeats is 1, hence this is the pad value too.
# The minimum value for repeats is 1, hence this is the pad value too.
repeats = F.pad(repeats, (0, gap), value=1)
seps = F.pad(seps, (0, gap))
elif rcodes.shape[0] > self.max_text_len:
rcodes = rcodes[:self.max_text_len]

@@ -171,7 +187,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if text is None or len(text.strip()) == 0:
raise ValueError
cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
    n=self.conditioning_candidates) if self.load_conditioning else (None, False)
    n=self.conditioning_candidates) if self.load_conditioning else (None, False)
except:
if self.skipped_items > 100:
raise  # Rethrow if we have nested too far.

@@ -185,12 +201,13 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.skipped_items = 0
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
if self.debug_failures:
print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0,len(self)-1)
print(
    f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0, len(self)-1)
return self[rv]

# Shift phonetic token and aligned_code tokens over.

@@ -206,7 +223,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if wav.shape[-1] != self.max_wav_len:
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
# These codes are aligned to audio inputs, so make sure to pad them as well.
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
aligned_codes = F.pad(
    aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
if tseq.shape[0] != self.max_text_len:
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))

@@ -237,7 +255,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return res

def __len__(self):
return self.total_size_bytes // 1000  # 1000 cuts down a TSV file to the actual length pretty well.
# 1000 cuts down a TSV file to the actual length pretty well.
return self.total_size_bytes // 1000


class FastPairedVoiceDebugger:

@@ -257,7 +276,8 @@ class FastPairedVoiceDebugger:
if isinstance(state, dict):
self.total_items = opt_get(state, ['total_items'], 0)
self.loaded_items = opt_get(state, ['loaded_items'], 0)
self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
self.self_conditioning_items = opt_get(
    state, ['self_conditioning_items'], 0)

def update(self, batch):
self.total_items += batch['wav'].shape[0]

@@ -266,7 +286,8 @@ class FastPairedVoiceDebugger:
for filename in batch['filenames']:
self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
if 'conditioning' in batch.keys():
self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
self.self_conditioning_items += batch['conditioning_contains_self'].sum(
    ).item()

def get_debugging_map(self):
return {

@@ -284,7 +305,7 @@ if __name__ == '__main__':
'mode': 'fast_paired_voice_audio_with_phonemes',
'path': ['y:/libritts/train-clean-100/transcribed-oco.tsv',],
'phoneme_paths': ['y:/libritts/train-other-500/transcribed-phoneme-oco.tsv'],
'types': [0,0],
'types': [0, 0],
'normal_text_end_token': 256,
'phase': 'train',
'n_workers': 0,

@@ -299,11 +320,12 @@ if __name__ == '__main__':
'load_aligned_codes': False,
'debug_loading_failures': True,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset

def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav',
    b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)

@@ -314,14 +336,13 @@ if __name__ == '__main__':
max_pads, max_repeats = 0, 0
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
#max_pads = max(max_pads, b['ctc_pads'].max())
#max_repeats = max(max_repeats, b['ctc_repeats'].max())
# max_pads = max(max_pads, b['ctc_pads'].max())
# max_repeats = max(max_repeats, b['ctc_repeats'].max())
print(f'{i} {ib} {b["real_text"][ib]}')
#save(b, i, ib, 'wav')
#save(b, i, ib, 'conditioning', 0)
#save(b, i, ib, 'conditioning', 1)
# save(b, i, ib, 'wav')
# save(b, i, ib, 'conditioning', 0)
# save(b, i, ib, 'conditioning', 1)
pass
if i > 15:
break
print(max_pads, max_repeats)

@@ -6,9 +6,8 @@ import torch.utils.data
from torch import LongTensor
from tqdm import tqdm

from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2 import symbols
from models.audio.tts.tacotron2 import text_to_sequence
from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text, symbols,
    text_to_sequence)


class GptTtsDataset(torch.utils.data.Dataset):

@@ -21,7 +20,7 @@ class GptTtsDataset(torch.utils.data.Dataset):
def __init__(self, opt):
self.path = os.path.dirname(opt['path'])
self.audiopaths_and_text = load_filepaths_and_text(opt['path'])
self.text_cleaners=['english_cleaners']
self.text_cleaners = ['english_cleaners']

self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
self.MEL_START_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-3])

@@ -32,7 +31,8 @@ class GptTtsDataset(torch.utils.data.Dataset):
audiopath_and_text = self.audiopaths_and_text[index]
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
text = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
text = torch.cat([self.TEXT_START_TOKEN, text, self.TEXT_STOP_TOKEN], dim=0)
text = torch.cat([self.TEXT_START_TOKEN, text,
    self.TEXT_STOP_TOKEN], dim=0)

# Fetch quantized MELs
quant_path = audiopath.replace('wavs/', 'quantized_mels/') + '.pth'

@@ -57,8 +57,9 @@ class GptTtsCollater():

def __call__(self, batch):
text_lens = [len(x[0]) for x in batch]
#max_text_len = max(text_lens)
max_text_len = self.MAX_SYMBOLS_PER_PHRASE  # This forces all outputs to have the full 200 characters. Testing if this makes a difference.
# max_text_len = max(text_lens)
# This forces all outputs to have the full 200 characters. Testing if this makes a difference.
max_text_len = self.MAX_SYMBOLS_PER_PHRASE
mel_lens = [len(x[1]) for x in batch]
max_mel_len = max(mel_lens)
texts = []

@@ -70,7 +71,8 @@ class GptTtsCollater():
text = F.pad(text, (0, max_text_len-len(text)), value=0)
text = torch.where(text == 0, text_range_embedding, text)
texts.append(text)
qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN))
qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)),
    value=self.MEL_PAD_TOKEN))

filenames = [j[2] for j in batch]

@@ -96,7 +98,7 @@ if __name__ == '__main__':
'batch_size': 16,
'mel_vocab_size': 512,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset

ds, c = create_dataset(params, return_collate=True)
dl = create_dataloader(ds, params, collate_fn=c)

@@ -107,5 +109,5 @@ if __name__ == '__main__':
for b in tqdm(dl):
max_mel = max(max_mel, b['padded_qmel'].shape[2])
max_text = max(max_text, b['padded_text'].shape[1])
m=torch.stack(m)
m = torch.stack(m)
print(m.mean(), m.std())

@@ -7,14 +7,15 @@ import torchaudio
from munch import munchify
from tqdm import tqdm

from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset
from data.text.hf_datasets_wrapper import HfDataset
from utils.util import opt_get
from dlas.data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset
from dlas.data.text.hf_datasets_wrapper import HfDataset
from dlas.utils.util import opt_get


def build_paired_voice_dataset(args):
from data.audio.paired_voice_audio_dataset import TextWavLoader as D
from models.audio.tts.tacotron2 import create_hparams

from dlas.data.audio.paired_voice_audio_dataset import TextWavLoader as D
default_params = create_hparams()
default_params.update(args)
dataset_opt = munchify(default_params)

@@ -33,6 +34,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):

Performs tokenization at this level, ignoring any tokenization performed by upstream datasets.
"""

def __init__(self, opt):
sample_rate = 22050  # Fixed.
paired_dataset_args = opt['paired_dataset_args']

@@ -47,7 +49,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
self.max_solo_text_length = opt['max_solo_text_length']
self.collate = opt_get(opt, ['needs_collate'], False)
self.sample_rate = sample_rate
self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0)
self.num_conditioning_candidates = opt_get(
    opt, ['num_conditioning_candidates'], 0)
self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000)
load_conditioning = self.num_conditioning_candidates > 0

@@ -75,7 +78,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
def fetch_text_at(self, i):
try:
txt = self.text[i % len(self.text)]['text']
assert '*' not in txt  # This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There really isn't a linguistic use for this character anyways.
# This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There really isn't a linguistic use for this character anyways.
assert '*' not in txt
tok = self.speech_and_text.get_text(txt)
padding_required = self.max_solo_text_length - tok.shape[0]
if padding_required < 0:

@@ -137,7 +141,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
sp = self.speech[i % len(self.speech)]
# Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise.
sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
sp['clip_lengths'] = sp['clip_lengths'].clamp(0, self.max_solo_audio_length)
sp['clip_lengths'] = sp['clip_lengths'].clamp(
    0, self.max_solo_audio_length)
return self.optionally_add_conditioning_candidates({
'paired_audio': snt['wav'],
'paired_audio_lengths': snt['wav_lengths'],

@@ -205,7 +210,7 @@ if __name__ == '__main__':
'use_bpe_tokenizer': False,
},
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
os.remove('test_cache_delete_me2.pth')

ds, c = create_dataset(train_params, return_collate=True)

@@ -213,7 +218,8 @@ if __name__ == '__main__':

def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav',
    b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)

@@ -224,16 +230,17 @@ if __name__ == '__main__':
m = None
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
#save(b, i, ib, 'paired_audio')
#save(b, i, ib, 'paired_audio_conditioning', 0)
#save(b, i, ib, 'paired_audio_conditioning', 1)
print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
#save(b, i, ib, 'speech_audio')
#save(b, i, ib, 'speech_audio_conditioning', 0)
#save(b, i, ib, 'speech_audio_conditioning', 1)
#print(f'Text: {b["text_text"][ib]}')
#print(f'Text decoded: {decode(b, ib, "text_tokens")}')
# save(b, i, ib, 'paired_audio')
# save(b, i, ib, 'paired_audio_conditioning', 0)
# save(b, i, ib, 'paired_audio_conditioning', 1)
print(
    f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
print(
    f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
# save(b, i, ib, 'speech_audio')
# save(b, i, ib, 'speech_audio_conditioning', 0)
# save(b, i, ib, 'speech_audio_conditioning', 1)
# print(f'Text: {b["text_text"][ib]}')
# print(f'Text decoded: {decode(b, ib, "text_tokens")}')
if i > 5:
break

@ -7,32 +7,36 @@ import torch.utils.data
|
|||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from data.util import find_files_of_type, is_audio_file
|
||||
from models.audio.tts.tacotron2 import load_filepaths_and_text
|
||||
from models.audio.tts.tacotron2 import text_to_sequence
|
||||
from utils.util import opt_get
|
||||
from dlas.data.audio.unsupervised_audio_dataset import load_audio
|
||||
from dlas.data.util import find_files_of_type, is_audio_file
|
||||
from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text,
|
||||
text_to_sequence)
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
def load_tsv(filename):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
components = [line.strip().split('\t') for line in f]
|
||||
base = os.path.dirname(filename)
|
||||
filepaths_and_text = [[os.path.join(base, f'{component[1]}'), component[0]] for component in components]
|
||||
filepaths_and_text = [
|
||||
[os.path.join(base, f'{component[1]}'), component[0]] for component in components]
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def load_mozilla_cv(filename):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
components = [line.strip().split('\t') for line in f][1:] # First line is the header
|
||||
components = [line.strip().split('\t')
|
||||
for line in f][1:] # First line is the header
|
||||
base = os.path.dirname(filename)
|
||||
filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2]] for component in components]
|
||||
filepaths_and_text = [[os.path.join(
|
||||
base, f'clips/{component[1]}'), component[2]] for component in components]
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def load_voxpopuli(filename):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
lines = [line.strip().split('\t') for line in f][1:] # First line is the header
|
||||
lines = [line.strip().split('\t')
|
||||
for line in f][1:] # First line is the header
|
||||
base = os.path.dirname(filename)
|
||||
filepaths_and_text = []
|
||||
for line in lines:
|
||||
|
@ -40,7 +44,8 @@ def load_voxpopuli(filename):
|
|||
continue
|
||||
file, raw_text, norm_text, speaker_id, split, gender = line
|
||||
year = file[:4]
|
||||
filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
|
||||
filepaths_and_text.append(
|
||||
[os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
|
@ -56,8 +61,10 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
assert len(self.path) == len(fetcher_mode)
|
||||
|
||||
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
|
||||
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 3)
|
||||
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
|
||||
self.conditioning_candidates = opt_get(
|
||||
hparams, ['num_conditioning_candidates'], 3)
|
||||
self.conditioning_length = opt_get(
|
||||
hparams, ['conditioning_length'], 44100)
|
||||
self.audiopaths_and_text = []
|
||||
for p, fm in zip(self.path, fetcher_mode):
|
||||
if fm == 'lj' or fm == 'libritts':
|
||||
|
@ -65,10 +72,12 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
elif fm == 'tsv':
|
||||
fetcher_fn = load_tsv
|
||||
elif fm == 'mozilla_cv':
|
||||
assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv
|
||||
# Conditioning inputs are incompatible with mozilla_cv
|
||||
assert not self.load_conditioning
|
||||
fetcher_fn = load_mozilla_cv
|
||||
elif fm == 'voxpopuli':
|
||||
assert not self.load_conditioning # Conditioning inputs are incompatible with voxpopuli
|
||||
# Conditioning inputs are incompatible with voxpopuli
|
||||
assert not self.load_conditioning
|
||||
fetcher_fn = load_voxpopuli
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
@ -96,10 +105,13 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
return text_norm
|
||||
|
||||
def load_conditioning_candidates(self, path):
|
||||
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
|
||||
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
||||
candidates = find_files_of_type(
|
||||
'img', os.path.dirname(path), qualifier=is_audio_file)[0]
|
||||
# Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
||||
assert len(candidates) < 50000
|
||||
if len(candidates) == 0:
|
||||
print(f"No conditioning candidates found for {path} (not even the clip itself??)")
|
||||
print(
|
||||
f"No conditioning candidates found for {path} (not even the clip itself??)")
|
||||
raise NotImplementedError()
|
||||
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
|
||||
related_clips = []
|
||||
|
@ -110,25 +122,28 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
|
||||
elif gap > 0:
|
||||
rand_start = random.randint(0, gap)
|
||||
rel_clip = rel_clip[:, rand_start:rand_start+self.conditioning_length]
|
||||
rel_clip = rel_clip[:, rand_start:rand_start +
|
||||
self.conditioning_length]
|
||||
related_clips.append(rel_clip)
|
||||
return torch.stack(related_clips, dim=0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
try:
|
||||
tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
|
||||
cond = self.load_conditioning_candidates(self.audiopaths_and_text[index][0]) if self.load_conditioning else None
|
||||
tseq, wav, text, path = self.get_wav_text_pair(
|
||||
self.audiopaths_and_text[index])
|
||||
cond = self.load_conditioning_candidates(
|
||||
self.audiopaths_and_text[index][0]) if self.load_conditioning else None
|
||||
except:
|
||||
print(f"error loading {self.audiopaths_and_text[index][0]}")
|
||||
return self[index+1]
|
||||
if wav is None or \
|
||||
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
|
||||
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
||||
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
||||
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
||||
# It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result.
|
||||
#if wav is not None:
|
||||
# if wav is not None:
|
||||
# print(f"Exception {index} wav_len:{wav.shape[-1]} text_len:{tseq.shape[0]} fname: {path}")
|
||||
rv = random.randint(0,len(self)-1)
|
||||
rv = random.randint(0, len(self)-1)
|
||||
return self[rv]
|
||||
orig_output = wav.shape[-1]
|
||||
orig_text_len = tseq.shape[0]
|
||||
|
@ -157,6 +172,7 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
class TextMelCollate():
|
||||
""" Zero-pads model inputs and targets based on number of frames per step
|
||||
"""
|
||||
|
||||
def __call__(self, batch):
|
||||
"""Collate's training batch from normalized text and wav
|
||||
PARAMS
|
||||
|
@ -226,7 +242,7 @@ if __name__ == '__main__':
|
|||
'num_conditioning_candidates': 3,
|
||||
'conditioning_length': 44100,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
from data import create_dataloader, create_dataset
|
||||
|
||||
ds, c = create_dataset(params, return_collate=True)
|
||||
dl = create_dataloader(ds, params, collate_fn=c)
|
||||
|
@ -240,4 +256,5 @@ if __name__ == '__main__':
|
|||
print(f'{i} {ib} {b["real_text"][ib]}')
|
||||
torchaudio.save(f'{i}_clip_{ib}.wav', b['wav'][ib], ds.sample_rate)
|
||||
for c in range(3):
|
||||
torchaudio.save(f'{i}_clip_{ib}_cond{c}.wav', b['conditioning'][ib, c], ds.sample_rate)
|
||||
torchaudio.save(f'{i}_clip_{ib}_cond{c}.wav',
|
||||
b['conditioning'][ib, c], ds.sample_rate)
|
||||
|
|
|
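For reference, a minimal standalone sketch of the conditioning-clip handling shown above: each candidate clip is zero-padded when shorter than conditioning_length and randomly cropped when longer, then the clips are stacked. The function name stack_conditioning_clips and the (1, T) tensor inputs are illustrative assumptions, not part of this repository.

import random

import torch
import torch.nn.functional as F


def stack_conditioning_clips(clips, conditioning_length=44100):
    out = []
    for clip in clips:
        gap = clip.shape[-1] - conditioning_length
        if gap < 0:
            # Too short: right-pad with zeros up to the target length.
            clip = F.pad(clip, pad=(0, -gap))
        elif gap > 0:
            # Too long: take a random window of the target length.
            start = random.randint(0, gap)
            clip = clip[:, start:start + conditioning_length]
        out.append(clip)
    return torch.stack(out, dim=0)  # (num_clips, 1, conditioning_length)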
@ -8,10 +8,13 @@ import torch.utils.data
|
|||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
|
||||
from models.audio.tts.tacotron2 import load_filepaths_and_text, load_filepaths_and_text_type
|
||||
from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text
|
||||
from utils.util import opt_get
|
||||
from dlas.data.audio.unsupervised_audio_dataset import (load_audio,
|
||||
load_similar_clips)
|
||||
from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text,
|
||||
load_filepaths_and_text_type,
|
||||
sequence_to_text,
|
||||
text_to_sequence)
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
def load_tsv_type(filename, type):
|
||||
|
@ -24,10 +27,12 @@ def load_tsv_type(filename, type):
|
|||
if len(components) < 2:
|
||||
bad_lines += 1
|
||||
if bad_lines > 1000:
|
||||
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
|
||||
print(
|
||||
f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
|
||||
raise ValueError
|
||||
continue
|
||||
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]] + [type])
|
||||
filepaths_and_text.append(
|
||||
[os.path.join(base, f'{components[1]}'), components[0]] + [type])
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
|
@ -41,10 +46,12 @@ def load_tsv(filename):
|
|||
if len(components) < 2:
|
||||
bad_lines += 1
|
||||
if bad_lines > 1000:
|
||||
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
|
||||
print(
|
||||
f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
|
||||
raise ValueError
|
||||
continue
|
||||
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]])
|
||||
filepaths_and_text.append(
|
||||
[os.path.join(base, f'{components[1]}'), components[0]])
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
|
@ -67,10 +74,12 @@ def load_tsv_aligned_codes_type(filename, type):
|
|||
if len(components) < 3:
|
||||
bad_lines += 1
|
||||
if bad_lines > 1000:
|
||||
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
|
||||
print(
|
||||
f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
|
||||
raise ValueError
|
||||
continue
|
||||
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
|
||||
filepaths_and_text.append([os.path.join(
|
||||
base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
|
@ -84,24 +93,29 @@ def load_tsv_aligned_codes(filename):
|
|||
if len(components) < 3:
|
||||
bad_lines += 1
|
||||
if bad_lines > 1000:
|
||||
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
|
||||
print(
|
||||
f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
|
||||
raise ValueError
|
||||
continue
|
||||
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])])
|
||||
filepaths_and_text.append([os.path.join(
|
||||
base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])])
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def load_mozilla_cv(filename, type):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
components = [line.strip().split('\t') for line in f][1:] # First line is the header
|
||||
components = [line.strip().split('\t')
|
||||
for line in f][1:] # First line is the header
|
||||
base = os.path.dirname(filename)
|
||||
filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2], type] for component in components]
|
||||
filepaths_and_text = [[os.path.join(
|
||||
base, f'clips/{component[1]}'), component[2], type] for component in components]
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def load_voxpopuli(filename, type):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
lines = [line.strip().split('\t') for line in f][1:] # First line is the header
|
||||
lines = [line.strip().split('\t')
|
||||
for line in f][1:] # First line is the header
|
||||
base = os.path.dirname(filename)
|
||||
filepaths_and_text = []
|
||||
for line in lines:
|
||||
|
@ -109,7 +123,8 @@ def load_voxpopuli(filename, type):
|
|||
continue
|
||||
file, raw_text, norm_text, speaker_id, split, gender = line
|
||||
year = file[:4]
|
||||
filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
|
||||
filepaths_and_text.append(
|
||||
[os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
|
@ -134,11 +149,16 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
assert len(self.path) == len(fetcher_mode)
|
||||
|
||||
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
|
||||
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
|
||||
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
|
||||
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
|
||||
self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
|
||||
self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
|
||||
self.conditioning_candidates = opt_get(
|
||||
hparams, ['num_conditioning_candidates'], 1)
|
||||
self.conditioning_length = opt_get(
|
||||
hparams, ['conditioning_length'], 44100)
|
||||
self.debug_failures = opt_get(
|
||||
hparams, ['debug_loading_failures'], False)
|
||||
self.load_aligned_codes = opt_get(
|
||||
hparams, ['load_aligned_codes'], False)
|
||||
self.aligned_codes_to_audio_ratio = opt_get(
|
||||
hparams, ['aligned_codes_ratio'], 443)
|
||||
self.audiopaths_and_text = []
|
||||
for p, fm, type in zip(self.path, fetcher_mode, self.types):
|
||||
if fm == 'lj' or fm == 'libritts':
|
||||
|
@ -146,10 +166,12 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
elif fm == 'tsv':
|
||||
fetcher_fn = load_tsv_aligned_codes_type if self.load_aligned_codes else load_tsv_type
|
||||
elif fm == 'mozilla_cv':
|
||||
assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv
|
||||
# Conditioning inputs are incompatible with mozilla_cv
|
||||
assert not self.load_conditioning
|
||||
fetcher_fn = load_mozilla_cv
|
||||
elif fm == 'voxpopuli':
|
||||
assert not self.load_conditioning # Conditioning inputs are incompatible with voxpopuli
|
||||
# Conditioning inputs are incompatible with voxpopuli
|
||||
assert not self.load_conditioning
|
||||
fetcher_fn = load_voxpopuli
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
@ -165,11 +187,13 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
assert self.max_wav_len is not None and self.max_text_len is not None
|
||||
self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True)
|
||||
if self.use_bpe_tokenizer:
|
||||
from data.audio.voice_tokenizer import VoiceBpeTokenizer
|
||||
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
|
||||
from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer
|
||||
self.tokenizer = VoiceBpeTokenizer(opt_get(
|
||||
hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
|
||||
else:
|
||||
self.tokenizer = CharacterTokenizer()
|
||||
self.skipped_items = 0 # records how many items are skipped when accessing an index.
|
||||
# records how many items are skipped when accessing an index.
|
||||
self.skipped_items = 0
|
||||
|
||||
def get_wav_text_pair(self, audiopath_and_text):
|
||||
# separate filename and text
|
||||
|
@ -191,19 +215,21 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
def __getitem__(self, index):
|
||||
self.skipped_items += 1
|
||||
try:
|
||||
tseq, wav, text, path, type = self.get_wav_text_pair(self.audiopaths_and_text[index])
|
||||
tseq, wav, text, path, type = self.get_wav_text_pair(
|
||||
self.audiopaths_and_text[index])
|
||||
if text is None or len(text.strip()) == 0:
|
||||
raise ValueError
|
||||
if wav is None or wav.shape[-1] < (.6 * self.sample_rate):
|
||||
# Ultra short clips are also useless (and can cause problems within some models).
|
||||
raise ValueError
|
||||
cond, cond_is_self = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
|
||||
n=self.conditioning_candidates) if self.load_conditioning else (None, False)
|
||||
n=self.conditioning_candidates) if self.load_conditioning else (None, False)
|
||||
except:
|
||||
if self.skipped_items > 100:
|
||||
raise # Rethrow if we have nested too far.
|
||||
if self.debug_failures:
|
||||
print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
|
||||
print(
|
||||
f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
|
||||
return self[(index+1) % len(self)]
|
||||
|
||||
if self.load_aligned_codes:
|
||||
|
@ -213,12 +239,13 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
self.skipped_items = 0
|
||||
if wav is None or \
|
||||
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
|
||||
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
||||
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
|
||||
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
||||
# It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result.
|
||||
if self.debug_failures:
|
||||
print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
|
||||
rv = random.randint(0,len(self)-1)
|
||||
print(
|
||||
f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
|
||||
rv = random.randint(0, len(self)-1)
|
||||
return self[rv]
|
||||
orig_output = wav.shape[-1]
|
||||
orig_text_len = tseq.shape[0]
|
||||
|
@ -226,7 +253,8 @@ class TextWavLoader(torch.utils.data.Dataset):
|
|||
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
|
||||
if self.load_aligned_codes:
|
||||
# These codes are aligned to audio inputs, so make sure to pad them as well.
|
||||
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
|
||||
aligned_codes = F.pad(
|
||||
aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
|
||||
if tseq.shape[0] != self.max_text_len:
|
||||
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
|
||||
res = {
|
||||
|
@ -265,13 +293,15 @@ class PairedVoiceDebugger:
|
|||
if isinstance(state, dict):
|
||||
self.total_items = opt_get(state, ['total_items'], 0)
|
||||
self.loaded_items = opt_get(state, ['loaded_items'], 0)
|
||||
self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
|
||||
self.self_conditioning_items = opt_get(
|
||||
state, ['self_conditioning_items'], 0)
|
||||
|
||||
def update(self, batch):
|
||||
self.total_items += batch['wav'].shape[0]
|
||||
self.loaded_items += batch['skipped_items'].sum().item()
|
||||
if 'conditioning' in batch.keys():
|
||||
self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
|
||||
self.self_conditioning_items += batch['conditioning_contains_self'].sum(
|
||||
).item()
|
||||
|
||||
def get_debugging_map(self):
|
||||
return {
|
||||
|
@ -299,11 +329,12 @@ if __name__ == '__main__':
|
|||
'use_bpe_tokenizer': True,
|
||||
'load_aligned_codes': False,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
from data import create_dataloader, create_dataset
|
||||
|
||||
def save(b, i, ib, key, c=None):
|
||||
if c is not None:
|
||||
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
|
||||
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav',
|
||||
b[key][ib][c], 22050)
|
||||
else:
|
||||
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
|
||||
|
||||
|
@ -317,4 +348,3 @@ if __name__ == '__main__':
|
|||
save(b, i, ib, 'wav')
|
||||
if i > 5:
|
||||
break
|
||||
|
||||
|
|
|
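A hedged sketch of the two-column TSV layout that load_tsv above consumes, one "transcript<TAB>relative audio path" pair per line; the helper name parse_tsv_line and the sample values are illustrative only.

import os


def parse_tsv_line(line, base_dir):
    components = line.strip().split('\t')
    if len(components) < 2:
        return None  # counted as a "bad line" by the loader above
    transcript, rel_path = components[0], components[1]
    return os.path.join(base_dir, rel_path), transcript


print(parse_tsv_line('hello world\tclips/0001.wav', '/data/voice'))
# ('/data/voice/clips/0001.wav', 'hello world')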
@ -9,14 +9,15 @@ import torchaudio
|
|||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.util import opt_get
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
class PreprocessedMelDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, opt):
|
||||
path = opt['path']
|
||||
cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case.
|
||||
# Will fail when multiple paths specified, must be specified in this case.
|
||||
cache_path = opt['cache_path']
|
||||
if os.path.exists(cache_path):
|
||||
self.paths = torch.load(cache_path)
|
||||
else:
|
||||
|
@ -36,8 +37,8 @@ class PreprocessedMelDataset(torch.utils.data.Dataset):
|
|||
padding_needed = self.pad_to - mel.shape[-1]
|
||||
mask = torch.zeros_like(mel)
|
||||
if padding_needed > 0:
|
||||
mel = F.pad(mel, (0,padding_needed))
|
||||
mask = F.pad(mask, (0,padding_needed), value=1)
|
||||
mel = F.pad(mel, (0, padding_needed))
|
||||
mask = F.pad(mask, (0, padding_needed), value=1)
|
||||
|
||||
output = {
|
||||
'mel': mel,
|
||||
|
@ -62,13 +63,13 @@ if __name__ == '__main__':
|
|||
'n_workers': 0,
|
||||
'batch_size': 16,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
from data import create_dataloader, create_dataset
|
||||
|
||||
ds = create_dataset(params)
|
||||
dl = create_dataloader(ds, params)
|
||||
i = 0
|
||||
for b in tqdm(dl):
|
||||
#pass
|
||||
# pass
|
||||
torchvision.utils.save_image((b['mel'].unsqueeze(1)+1)/2, f'{i}.png')
|
||||
i += 1
|
||||
if i > 20:
|
||||
|
|
|
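A minimal sketch of the fixed-length padding used by PreprocessedMelDataset above: the mel is padded to pad_to frames and a mask marks the padded region with 1. pad_mel and the toy shapes are assumptions for illustration.

import torch
import torch.nn.functional as F


def pad_mel(mel, pad_to):
    # mel: (n_mels, T) with T <= pad_to
    padding_needed = pad_to - mel.shape[-1]
    mask = torch.zeros_like(mel)
    if padding_needed > 0:
        mel = F.pad(mel, (0, padding_needed))
        mask = F.pad(mask, (0, padding_needed), value=1)
    return mel, mask


mel, mask = pad_mel(torch.randn(80, 300), pad_to=400)
print(mel.shape, mask[:, 300:].min())  # torch.Size([80, 400]) tensor(1.)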
@ -3,15 +3,16 @@ import random
|
|||
import sys
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import torchaudio
|
||||
from audio2numpy import open_audio
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.util import find_files_of_type, is_audio_file, load_paths_from_cache
|
||||
from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
|
||||
from utils.util import opt_get
|
||||
from dlas.data.util import (find_files_of_type, is_audio_file,
|
||||
load_paths_from_cache)
|
||||
from dlas.models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
def load_audio(audiopath, sampling_rate):
|
||||
|
@ -53,17 +54,21 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, fallback_to_self=T
|
|||
similarities = torch.load(sim_path)
|
||||
fname = os.path.basename(path)
|
||||
if fname in similarities.keys():
|
||||
candidates = [os.path.join(os.path.dirname(path), s) for s in similarities[fname]]
|
||||
candidates = [os.path.join(os.path.dirname(path), s)
|
||||
for s in similarities[fname]]
|
||||
else:
|
||||
print(f'Similarities list found for {path} but {fname} was not in that list.')
|
||||
#candidates.append(path) # Always include self as a possible similar clip.
|
||||
print(
|
||||
f'Similarities list found for {path} but {fname} was not in that list.')
|
||||
# candidates.append(path) # Always include self as a possible similar clip.
|
||||
if len(candidates) == 0:
|
||||
if fallback_to_self:
|
||||
candidates = [path]
|
||||
else:
|
||||
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
|
||||
candidates = find_files_of_type(
|
||||
'img', os.path.dirname(path), qualifier=is_audio_file)[0]
|
||||
|
||||
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
||||
# Sanity check to ensure we aren't loading "related files" that aren't actually related.
|
||||
assert len(candidates) < 50000
|
||||
if len(candidates) == 0:
|
||||
print(f"No conditioning candidates found for {path}")
|
||||
raise NotImplementedError()
|
||||
|
@ -92,7 +97,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
|
||||
def __init__(self, opt):
|
||||
path = opt['path']
|
||||
cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case.
|
||||
# Will fail when multiple paths specified, must be specified in this case.
|
||||
cache_path = opt['cache_path']
|
||||
exclusions = []
|
||||
if 'exclusions' in opt.keys():
|
||||
for exc in opt['exclusions']:
|
||||
|
@ -102,7 +108,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
assert isinstance(ew, list)
|
||||
not_ew = opt_get(opt, ['not_endswith'], [])
|
||||
assert isinstance(not_ew, list)
|
||||
self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, endswith=ew, not_endswith=not_ew)
|
||||
self.audiopaths = load_paths_from_cache(
|
||||
path, cache_path, exclusions, endswith=ew, not_endswith=not_ew)
|
||||
|
||||
# Parse options
|
||||
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
|
||||
|
@ -121,7 +128,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
self.extra_samples = opt_get(opt, ['extra_samples'], 0)
|
||||
self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 44000)
|
||||
|
||||
self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)
|
||||
self.debug_loading_failures = opt_get(
|
||||
opt, ['debug_loading_failures'], True)
|
||||
|
||||
def get_audio_for_index(self, index):
|
||||
audiopath = self.audiopaths[index]
|
||||
|
@ -144,8 +152,9 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
alt_files, alt_is_self = self.get_related_audio_for_index(index)
|
||||
except:
|
||||
if self.debug_loading_failures:
|
||||
print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
|
||||
return self[random.randint(0,len(self))]
|
||||
print(
|
||||
f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
|
||||
return self[random.randint(0, len(self))]
|
||||
|
||||
# When generating resampled clips, skew is a bias that tries to spread them out from each other, reducing their
|
||||
# influence on one another.
|
||||
|
@ -157,10 +166,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
for sk in skew:
|
||||
if self.pad_to is not None:
|
||||
if audio_norm.shape[-1] <= self.pad_to:
|
||||
clips.append(torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1])))
|
||||
clips.append(torch.nn.functional.pad(
|
||||
audio_norm, (0, self.pad_to - audio_norm.shape[-1])))
|
||||
else:
|
||||
gap = audio_norm.shape[-1] - self.pad_to
|
||||
start = min(max(random.randint(0, gap-1) + sk * gap // 2, 0), gap-1)
|
||||
start = min(
|
||||
max(random.randint(0, gap-1) + sk * gap // 2, 0), gap-1)
|
||||
clips.append(audio_norm[:, start:start+self.pad_to])
|
||||
else:
|
||||
clips.append(audio_norm)
|
||||
|
@ -200,16 +211,18 @@ if __name__ == '__main__':
|
|||
'n_workers': 1,
|
||||
'batch_size': 16,
|
||||
}
|
||||
from data import create_dataset, create_dataloader
|
||||
from data import create_dataloader, create_dataset
|
||||
|
||||
ds = create_dataset(params)
|
||||
dl = create_dataloader(ds, params)
|
||||
i = 0
|
||||
for b in tqdm(dl):
|
||||
for b_ in range(b['clip'].shape[0]):
|
||||
#pass
|
||||
torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
|
||||
torchaudio.save(f'{i}_alt_clip_{b_}.wav', b['alt_clips'][b_], ds.sampling_rate)
|
||||
# pass
|
||||
torchaudio.save(f'{i}_clip_{b_}.wav',
|
||||
b['clip'][b_], ds.sampling_rate)
|
||||
torchaudio.save(f'{i}_alt_clip_{b_}.wav',
|
||||
b['alt_clips'][b_], ds.sampling_rate)
|
||||
i += 1
|
||||
if i > 200:
|
||||
break
|
||||
|
|
|
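A minimal sketch of the pad-or-crop-with-skew logic above, where skew biases the crop window so that multiple resampled clips land away from each other; the function name and shapes are assumptions.

import random

import torch
import torch.nn.functional as F


def crop_with_skew(audio, pad_to, skew):
    # audio: (1, T); skew: e.g. -1, 0 or 1
    if audio.shape[-1] <= pad_to:
        return F.pad(audio, (0, pad_to - audio.shape[-1]))
    gap = audio.shape[-1] - pad_to
    start = min(max(random.randint(0, gap - 1) + skew * gap // 2, 0), gap - 1)
    return audio[:, start:start + pad_to]


print(crop_with_skew(torch.randn(1, 100000), pad_to=44000, skew=-1).shape)
# torch.Size([1, 44000])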
@ -1,16 +1,17 @@
|
|||
import json
|
||||
import re
|
||||
|
||||
import torch
|
||||
import json
|
||||
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.models import BPE
|
||||
from tokenizers.pre_tokenizers import Whitespace
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
|
||||
from data.audio.paired_voice_audio_dataset import load_mozilla_cv, load_voxpopuli, load_tsv
|
||||
from models.audio.tts.tacotron2 import load_filepaths_and_text
|
||||
from models.audio.tts.tacotron2.text.cleaners import english_cleaners
|
||||
from dlas.data.audio.paired_voice_audio_dataset import (load_mozilla_cv,
|
||||
load_tsv,
|
||||
load_voxpopuli)
|
||||
from dlas.models.audio.tts.tacotron2 import load_filepaths_and_text
|
||||
from dlas.models.audio.tts.tacotron2.text.cleaners import english_cleaners
|
||||
|
||||
|
||||
def remove_extraneous_punctuation(word):
|
||||
|
@ -21,7 +22,8 @@ def remove_extraneous_punctuation(word):
|
|||
'—': '-', '`': '\'',
|
||||
'ʼ': '\''
|
||||
}
|
||||
replace = re.compile("|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL)
|
||||
replace = re.compile("|".join([re.escape(k) for k in sorted(
|
||||
replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL)
|
||||
word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
|
||||
|
||||
# TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
|
||||
|
@ -29,27 +31,32 @@ def remove_extraneous_punctuation(word):
|
|||
word = extraneous.sub('', word)
|
||||
return word
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
return text.lower()
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r'\s+')
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file, preprocess=None):
|
||||
|
@ -72,23 +79,24 @@ class VoiceBpeTokenizer:
|
|||
|
||||
kks = pykakasi.kakasi()
|
||||
results = kks.convert(txt)
|
||||
txt = " ".join([ result['kana'] for result in results ])
|
||||
txt = " ".join([result['kana'] for result in results])
|
||||
txt = basic_cleaners(txt)
|
||||
else:
|
||||
txt = english_cleaners(txt)
|
||||
|
||||
|
||||
return txt
|
||||
|
||||
def encode(self, txt):
|
||||
if self.preprocess:
|
||||
txt = self.preprocess_text(txt)
|
||||
txt = self.preprocess_text(txt)
|
||||
txt = txt.replace(' ', '[SPACE]')
|
||||
return self.tokenizer.encode(txt).ids
|
||||
|
||||
def decode(self, seq):
|
||||
if isinstance(seq, torch.Tensor):
|
||||
seq = seq.cpu().numpy()
|
||||
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
|
||||
txt = self.tokenizer.decode(
|
||||
seq, skip_special_tokens=False).replace(' ', '')
|
||||
txt = txt.replace('[SPACE]', ' ')
|
||||
txt = txt.replace('[STOP]', '')
|
||||
txt = txt.replace('[UNK]', '')
|
||||
|
@ -117,10 +125,11 @@ def build_text_file_from_priors(priors, output):
|
|||
def train():
|
||||
with open('all_texts.txt', 'r', encoding='utf-8') as at:
|
||||
ttsd = at.readlines()
|
||||
#bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train']
|
||||
# bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train']
|
||||
|
||||
#allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—–ʼ]+$')
|
||||
# allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—–ʼ]+$')
|
||||
allowed_characters_re = re.compile(r'^[a-z!:;"/, \-\(\)\.\'\?ʼ]+$')
|
||||
|
||||
def preprocess_word(word, report=False):
|
||||
word = english_cleaners(word)
|
||||
word = remove_extraneous_punctuation(word)
|
||||
|
@ -135,16 +144,19 @@ def train():
|
|||
for i in range(0, len(ttsd), batch_size):
|
||||
yield [preprocess_word(t, True) for t in ttsd[i:i+batch_size]]
|
||||
|
||||
#print("Processing bookcorpus.")
|
||||
#for i in range(0, len(bcd), batch_size):
|
||||
# print("Processing bookcorpus.")
|
||||
# for i in range(0, len(bcd), batch_size):
|
||||
# yield [preprocess_word(t) for t in bcd[i:i+batch_size]['text']]
|
||||
|
||||
trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]', '[SPACE]'], vocab_size=255)
|
||||
trainer = BpeTrainer(
|
||||
special_tokens=['[STOP]', '[UNK]', '[SPACE]'], vocab_size=255)
|
||||
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
|
||||
tokenizer.pre_tokenizer = Whitespace()
|
||||
tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(ttsd))#+len(bcd))
|
||||
tokenizer.train_from_iterator(
|
||||
batch_iterator(), trainer, length=len(ttsd)) # +len(bcd))
|
||||
|
||||
print(tokenizer.decode(tokenizer.encode("i was traveling throughhadslfghds the woods in 1235375t137{{}}").ids))
|
||||
print(tokenizer.decode(tokenizer.encode(
|
||||
"i was traveling throughhadslfghds the woods in 1235375t137{{}}").ids))
|
||||
|
||||
tokenizer.save('gpt_tts_tokenizer.json')
|
||||
|
||||
|
@ -171,5 +183,5 @@ if __name__ == '__main__':
|
|||
('Y:\\clips\\books2-transcribed.tsv', 'tsv'),
|
||||
('Y:\\clips\\podcasts-0-transcribed.tsv', 'tsv')], 'all_texts.txt')
|
||||
'''
|
||||
#train()
|
||||
# train()
|
||||
test()
|
||||
|
|
|
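A hedged sketch of the tokenizer-training flow above using the Hugging Face tokenizers package (Tokenizer, BPE, BpeTrainer and Whitespace are the same classes imported in this file); the toy corpus and the tiny vocab size are illustrative.

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

corpus = ['the cat sat on the mat', 'a quick brown fox']
trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]', '[SPACE]'], vocab_size=255)
tokenizer = Tokenizer(BPE(unk_token='[UNK]'))
tokenizer.pre_tokenizer = Whitespace()
tokenizer.train_from_iterator(corpus, trainer, length=len(corpus))

ids = tokenizer.encode('the quick cat').ids
print(ids, tokenizer.decode(ids))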
@ -3,19 +3,19 @@ import random
|
|||
import torch
|
||||
import torchaudio.sox_effects
|
||||
|
||||
from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
|
||||
from dlas.models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
|
||||
|
||||
|
||||
# Returns random double on [l,h] as a string
|
||||
def rdstr(l=0,h=1):
|
||||
def rdstr(l=0, h=1):
|
||||
assert h > l
|
||||
i=h-l
|
||||
i = h-l
|
||||
return str(random.random() * i + l)
|
||||
|
||||
|
||||
# Returns a randint on [s,e] as a string
|
||||
def rdi(e, s=0):
|
||||
return str(random.randint(s,e))
|
||||
return str(random.randint(s, e))
|
||||
|
||||
|
||||
class WavAugmentor:
|
||||
|
@ -43,12 +43,13 @@ class WavAugmentor:
|
|||
band_effect = random.choice(band_effects)
|
||||
'''
|
||||
volume_effects = [
|
||||
['loudness', rdi(10,-2)],
|
||||
['overdrive', rdi(20,0), rdi(20,0)],
|
||||
['loudness', rdi(10, -2)],
|
||||
['overdrive', rdi(20, 0), rdi(20, 0)],
|
||||
]
|
||||
vol_effect = random.choice(volume_effects)
|
||||
effects = [speed_effect, vol_effect]
|
||||
out, sr = torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)
|
||||
out, sr = torchaudio.sox_effects.apply_effects_tensor(
|
||||
wav, sample_rate, effects)
|
||||
# Add a variable amount of noise
|
||||
out = out + torch.rand_like(out) * random.random() * .03
|
||||
return out
|
||||
|
@ -60,4 +61,4 @@ if __name__ == '__main__':
|
|||
aug = WavAugmentor()
|
||||
for j in range(10):
|
||||
out = aug.augment(sample, 24000)
|
||||
torchaudio.save(f'out{j}.wav', out, 24000)
|
||||
torchaudio.save(f'out{j}.wav', out, 24000)
|
||||
|
|
|
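A hedged usage sketch of torchaudio.sox_effects.apply_effects_tensor as used by WavAugmentor above; the tempo/vol effect values and the fake waveform are illustrative and require a torchaudio build with sox support.

import torch
import torchaudio

wav = torch.rand(1, 24000) * 2 - 1       # one second of fake audio at 24 kHz
effects = [
    ['tempo', '1.1'],                     # speed up by 10%
    ['vol', '0.8'],                       # reduce volume
]
out, sr = torchaudio.sox_effects.apply_effects_tensor(wav, 24000, effects)
out = out + torch.rand_like(out) * 0.01   # small additive noise, as above
print(out.shape, sr)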
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from data import create_dataset
|
||||
|
||||
from dlas.data import create_dataset
|
||||
|
||||
|
||||
# Simple composite dataset that combines multiple other datasets.
|
||||
|
@ -31,4 +32,4 @@ class CombinedDataset(torch.utils.data.Dataset):
|
|||
return output
|
||||
|
||||
def __len__(self):
|
||||
return max(len(d) for d in self.datasets.values())
|
||||
return max(len(d) for d in self.datasets.values())
|
||||
|
|
|
@ -4,9 +4,10 @@ Support enlarging the dataset for *iteration-oriented* training, for saving time
|
|||
dataloader after each epoch
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.utils.data.sampler import Sampler
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
class DistIterSampler(Sampler):
|
||||
|
@ -30,17 +31,20 @@ class DistIterSampler(Sampler):
|
|||
def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
raise RuntimeError(
|
||||
"Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
raise RuntimeError(
|
||||
"Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
|
||||
self.num_samples = int(
|
||||
math.ceil(len(self.dataset) * ratio / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
|
||||
def __iter__(self):
|
||||
|
|
|
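For reference, the per-rank bookkeeping done by DistIterSampler above reduces to the arithmetic below; the function name and numbers are illustrative.

import math


def per_rank_samples(dataset_len, num_replicas, ratio=100):
    num_samples = int(math.ceil(dataset_len * ratio / num_replicas))
    total_size = num_samples * num_replicas
    return num_samples, total_size


print(per_rank_samples(dataset_len=1000, num_replicas=4, ratio=100))
# (25000, 100000): the dataset is virtually enlarged 100x and split across 4 ranks.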
@ -1,10 +1,13 @@
|
|||
import torch
|
||||
from torch.utils import data
|
||||
from data.images.image_corruptor import ImageCorruptor
|
||||
from data.images.chunk_with_reference import ChunkWithReference
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils import data
|
||||
|
||||
from dlas.data.images.chunk_with_reference import ChunkWithReference
|
||||
from dlas.data.images.image_corruptor import ImageCorruptor
|
||||
|
||||
|
||||
# Class whose purpose is to hold as much logic as can possibly be shared between datasets that operate on raw image
|
||||
# data and nothing else (which also have a very specific directory structure being used, as dictated by
|
||||
|
@ -13,13 +16,17 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
|||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
self.corruptor = ImageCorruptor(opt)
|
||||
self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
|
||||
self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
|
||||
self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys(
|
||||
) else None
|
||||
self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys(
|
||||
) else 1
|
||||
self.for_eval = opt['eval'] if 'eval' in opt.keys() else False
|
||||
self.scale = opt['scale'] if not self.for_eval else 1
|
||||
self.paths = opt['paths']
|
||||
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
|
||||
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
|
||||
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys(
|
||||
) else False
|
||||
# If we don't throw here, we get some really obscure errors.
|
||||
assert (self.target_hq_size // self.scale) % self.multiple == 0
|
||||
if not isinstance(self.paths, list):
|
||||
self.paths = [self.paths]
|
||||
self.weights = [1]
|
||||
|
@ -34,8 +41,10 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
|||
if os.path.exists(cache_path):
|
||||
chunks = torch.load(cache_path)
|
||||
else:
|
||||
print("Building chunk cache, this can take some time for large datasets..")
|
||||
chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
|
||||
print(
|
||||
"Building chunk cache, this can take some time for large datasets..")
|
||||
chunks = [ChunkWithReference(opt, d) for d in sorted(
|
||||
os.scandir(path), key=lambda e: e.name) if d.is_dir()]
|
||||
# Prune out chunks that have no images
|
||||
res = []
|
||||
for c in chunks:
|
||||
|
@ -60,7 +69,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
|||
for c in self.chunks:
|
||||
paths.extend(c.tiles)
|
||||
return paths
|
||||
|
||||
|
||||
# Utility method for translating a point when the dimensions of an image change.
|
||||
def resize_point(self, point, orig_dim, new_dim):
|
||||
oh, ow = orig_dim
|
||||
|
@ -78,20 +87,27 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
|||
for hq, hq_ref, hq_mask, hq_center in zip(imgs_hq, refs_hq, masks_hq, centers_hq):
|
||||
# It is assumed that the target size is a square.
|
||||
target_size = (self.target_hq_size, self.target_hq_size)
|
||||
hqs_adjusted.append(cv2.resize(hq, target_size, interpolation=cv2.INTER_AREA))
|
||||
hq_refs_adjusted.append(cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_AREA))
|
||||
hq_masks_adjusted.append(cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_AREA))
|
||||
hq_centers_adjusted.append(self.resize_point(hq_center, (h, w), target_size))
|
||||
hqs_adjusted.append(cv2.resize(
|
||||
hq, target_size, interpolation=cv2.INTER_AREA))
|
||||
hq_refs_adjusted.append(cv2.resize(
|
||||
hq_ref, target_size, interpolation=cv2.INTER_AREA))
|
||||
hq_masks_adjusted.append(cv2.resize(
|
||||
hq_mask, target_size, interpolation=cv2.INTER_AREA))
|
||||
hq_centers_adjusted.append(
|
||||
self.resize_point(hq_center, (h, w), target_size))
|
||||
h, w = self.target_hq_size, self.target_hq_size
|
||||
else:
|
||||
hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted = imgs_hq, refs_hq, masks_hq, centers_hq
|
||||
hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted] # This is done implicitly above..
|
||||
hq_multiple = self.multiple * self.scale # Multiple must apply to LQ image.
|
||||
# This is done implicitly above..
|
||||
hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted]
|
||||
# Multiple must apply to LQ image.
|
||||
hq_multiple = self.multiple * self.scale
|
||||
if h % hq_multiple != 0 or w % hq_multiple != 0:
|
||||
hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], []
|
||||
for hq, hq_ref, hq_mask, hq_center in zip(hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted):
|
||||
h, w = (h - h % hq_multiple), (w - w % hq_multiple)
|
||||
hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:2], (h, w)))
|
||||
hq_centers_conformed.append(
|
||||
self.resize_point(hq_center, hq.shape[:2], (h, w)))
|
||||
hqs_conformed.append(hq[:h, :w, :])
|
||||
hq_refs_conformed.append(hq_ref[:h, :w, :])
|
||||
hq_masks_conformed.append(hq_mask[:h, :w, :])
|
||||
|
@ -110,10 +126,14 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
|||
lms.append(hq_mask)
|
||||
lcs.append(hq_center)
|
||||
else:
|
||||
ls.append(cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
|
||||
lrs.append(cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
|
||||
lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
|
||||
lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2]))
|
||||
ls.append(cv2.resize(hq, (h // self.scale, w //
|
||||
self.scale), interpolation=cv2.INTER_AREA))
|
||||
lrs.append(cv2.resize(hq_ref, (h // self.scale, w //
|
||||
self.scale), interpolation=cv2.INTER_AREA))
|
||||
lms.append(cv2.resize(hq_mask, (h // self.scale, w //
|
||||
self.scale), interpolation=cv2.INTER_AREA))
|
||||
lcs.append(self.resize_point(
|
||||
hq_center, (h, w), ls[0].shape[:2]))
|
||||
# Corrupt the LQ image (only in eval mode)
|
||||
if not self.corrupt_before_downsize and not self.for_eval:
|
||||
ls = self.corruptor.corrupt_images(ls)
|
||||
|
|
|
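A hedged sketch of what the opt_get helper used throughout these dataset classes appears to do, inferred from its call sites (a nested dictionary lookup with a default); this is an assumption for illustration, not the repository's implementation.

def opt_get(opt, keys, default=None):
    # Walk the nested options dict by the given key path, falling back to default.
    value = opt
    for k in keys:
        if not isinstance(value, dict) or k not in value:
            return default
        value = value[k]
    return value


opt = {'datasets': {'train': {'conditioning_length': 44100}}}
print(opt_get(opt, ['datasets', 'train', 'conditioning_length'], 0))  # 44100
print(opt_get(opt, ['datasets', 'train', 'missing'], 0))              # 0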
@ -3,22 +3,20 @@ from time import time
|
|||
|
||||
import kornia
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.utils.data import Dataset
|
||||
from kornia import augmentation as augs, geometry
|
||||
from kornia import filters
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import torchvision
|
||||
from kornia import augmentation as augs
|
||||
from kornia import filters, geometry
|
||||
from torch.utils.data import Dataset
|
||||
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
|
||||
# inputs. These are then outputted as 'aug1' and 'aug2'.
|
||||
from tqdm import tqdm
|
||||
|
||||
from data import create_dataset
|
||||
from models.arch_util import PixelUnshuffle
|
||||
from utils.util import opt_get
|
||||
from dlas.data import create_dataset
|
||||
from dlas.models.arch_util import PixelUnshuffle
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
class RandomApply(nn.Module):
|
||||
|
@ -26,6 +24,7 @@ class RandomApply(nn.Module):
|
|||
super().__init__()
|
||||
self.fn = fn
|
||||
self.p = p
|
||||
|
||||
def forward(self, x):
|
||||
if random.random() > self.p:
|
||||
return x
|
||||
|
@ -39,9 +38,10 @@ class ByolDatasetWrapper(Dataset):
|
|||
self.cropped_img_size = opt['crop_size']
|
||||
self.key1 = opt_get(opt, ['key1'], 'hq')
|
||||
self.key2 = opt_get(opt, ['key2'], 'lq')
|
||||
for_sr = opt_get(opt, ['for_sr'], False) # When set, color alterations and blurs are disabled.
|
||||
# When set, color alterations and blurs are disabled.
|
||||
for_sr = opt_get(opt, ['for_sr'], False)
|
||||
|
||||
augmentations = [ \
|
||||
augmentations = [
|
||||
augs.RandomHorizontalFlip(),
|
||||
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
||||
if not for_sr:
|
||||
|
@ -51,12 +51,14 @@ class ByolDatasetWrapper(Dataset):
|
|||
if opt['normalize']:
|
||||
# The paper calls for normalization. Most datasets/models in this repo don't use this.
|
||||
# Recommend setting true if you want to train exactly like the paper.
|
||||
augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
|
||||
augmentations.append(augs.Normalize(mean=torch.tensor(
|
||||
[0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
|
||||
self.aug = nn.Sequential(*augmentations)
|
||||
|
||||
def __getitem__(self, item):
|
||||
item = self.wrapped_dataset[item]
|
||||
item.update({'aug1': self.aug(item[self.key1]).squeeze(dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
|
||||
item.update({'aug1': self.aug(item[self.key1]).squeeze(
|
||||
dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
|
@ -71,7 +73,7 @@ class DatasetRandomAugWrapper(Dataset):
|
|||
self.wrapped_dataset = create_dataset(opt['dataset'])
|
||||
self.cropped_img_size = opt['crop_size']
|
||||
self.includes_labels = opt['includes_labels']
|
||||
augmentations = [ \
|
||||
augmentations = [
|
||||
RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
|
||||
|
@ -87,18 +89,19 @@ class DatasetRandomAugWrapper(Dataset):
|
|||
dtypes = []
|
||||
for k in item.keys():
|
||||
if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3:
|
||||
assert item[k].shape[0] == 1 # Only supports a channel dim of 1.
|
||||
# Only supports a channel dim of 1.
|
||||
assert item[k].shape[0] == 1
|
||||
labels.append(k)
|
||||
dtypes.append(item[k].dtype)
|
||||
hq = torch.cat([hq, item[k].type(torch.float)], dim=0)
|
||||
hq = self.rrc(hq.unsqueeze(0)).squeeze(0)
|
||||
for i, k in enumerate(labels):
|
||||
# Strip out any label values that are not whole numbers.
|
||||
item[k] = hq[3+i:3+i+1,:,:]
|
||||
item[k] = hq[3+i:3+i+1, :, :]
|
||||
whole = (item[k].round() == item[k])
|
||||
item[k] = item[k] * whole
|
||||
item[k] = item[k].type(dtypes[i])
|
||||
item['lq'] = hq[:3,:,:]
|
||||
item['lq'] = hq[:3, :, :]
|
||||
item['hq'] = item['lq']
|
||||
return item
|
||||
|
||||
|
@ -137,14 +140,16 @@ def test_dataset_random_aug_wrapper():
|
|||
for k, v in o.items():
|
||||
# 'lq', 'hq', 'aug1', 'aug2',
|
||||
if k in ['hq']:
|
||||
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
torchvision.utils.save_image(
|
||||
v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
masked = v * (o['labels_mask'] * .5 + .5)
|
||||
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
|
||||
# torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
|
||||
# Pick a random (non-zero) label and spit it out with the textual label.
|
||||
if len(o['labels'].unique()) > 1:
|
||||
randlbl = np.random.choice(o['labels'].unique()[1:])
|
||||
moremask = v * ((1*(o['labels'] == randlbl))*.5+.5)
|
||||
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
|
||||
torchvision.utils.save_image(moremask.unsqueeze(
|
||||
0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
|
||||
|
||||
|
||||
def no_batch_interpolate(i, size, mode):
|
||||
|
@ -165,10 +170,10 @@ def snap(ref, other):
|
|||
# Pads a tensor with zeros so that it fits in a dxd square.
|
||||
def pad_to(im, d):
|
||||
if len(im.shape) == 3:
|
||||
pd = torch.zeros((im.shape[0],d,d))
|
||||
pd = torch.zeros((im.shape[0], d, d))
|
||||
pd[:, :im.shape[1], :im.shape[2]] = im
|
||||
else:
|
||||
pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device)
|
||||
pd = torch.zeros((im.shape[0], im.shape[1], d, d), device=im.device)
|
||||
pd[:, :, :im.shape[2], :im.shape[3]] = im
|
||||
return pd
|
||||
|
||||
|
@ -182,7 +187,8 @@ class RandomSharedRegionCrop(nn.Module):
|
|||
def __init__(self, multiple, jitter_range=0):
|
||||
super().__init__()
|
||||
self.multiple = multiple
|
||||
self.jitter_range = jitter_range # When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range
|
||||
# When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range
|
||||
self.jitter_range = jitter_range
|
||||
|
||||
def forward(self, i1, i2):
|
||||
assert i1.shape[-1] == i2.shape[-1]
|
||||
|
@ -218,19 +224,25 @@ class RandomSharedRegionCrop(nn.Module):
|
|||
|
||||
# Step 4
|
||||
m = self.multiple
|
||||
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
|
||||
jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds.
|
||||
jl, jt = random.randint(-self.jitter_range,
|
||||
self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||
# If the top of a patch is zero, a negative jitter will cause it to go negative.
|
||||
jt = jt if base_t != 0 else abs(jt)
|
||||
# Likewise, jitter shouldn't allow the patch to go over-bounds.
|
||||
jt = jt if (base_t+base_h)*m != i1.shape[1] else 0
|
||||
jl = jl if base_l != 0 else abs(jl)
|
||||
jl = jl if (base_l+base_w)*m != i1.shape[1] else 0
|
||||
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
|
||||
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt,
|
||||
base_l*m+jl:(base_l+base_w)*m+jl]
|
||||
p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
|
||||
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||
jl, jt = random.randint(-self.jitter_range,
|
||||
self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||
jt = jt if im2_t != 0 else abs(jt)
|
||||
jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0
|
||||
jl = jl if im2_l != 0 else abs(jl)
|
||||
jl = jl if (im2_l+im2_w)*m != i2.shape[1] else 0
|
||||
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
|
||||
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt,
|
||||
im2_l*m+jl:(im2_l+im2_w)*m+jl]
|
||||
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
|
||||
|
||||
# Step 5
|
||||
|
@ -246,14 +258,17 @@ class RandomSharedRegionCrop(nn.Module):
|
|||
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
|
||||
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
|
||||
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
|
||||
recompute_package = torch.tensor([d, base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)
|
||||
recompute_package = torch.tensor([d, base_h, base_w, i1_shared_t, i1_shared_l, im2_h,
|
||||
im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)
|
||||
|
||||
# Step 7
|
||||
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
|
||||
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
|
||||
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m,
|
||||
i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
|
||||
masked1 = pad_to(p1 * mask1, d*m)
|
||||
mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
|
||||
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
|
||||
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m,
|
||||
i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
|
||||
masked2 = pad_to(p2 * mask2, d*m)
|
||||
mask = torch.full((1, d*m, d*m), fill_value=.33)
|
||||
mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33
|
||||
|
@ -262,10 +277,13 @@ class RandomSharedRegionCrop(nn.Module):
|
|||
|
||||
# Step 8 - Rebuild shared regions for testing purposes.
|
||||
p1_shuf, p2_shuf = PixelUnshuffle(self.multiple)(p1_resized.unsqueeze(0)), \
|
||||
PixelUnshuffle(self.multiple)(p2_resized.unsqueeze(0))
|
||||
i1_shared, i2_shared = reconstructed_shared_regions(p1_shuf, p2_shuf, recompute_package.unsqueeze(0))
|
||||
i1_shared = pad_to(nn.PixelShuffle(self.multiple)(i1_shared).squeeze(0), d * m)
|
||||
i2_shared = pad_to(nn.PixelShuffle(self.multiple)(i2_shared).squeeze(0), d*m)
|
||||
PixelUnshuffle(self.multiple)(p2_resized.unsqueeze(0))
|
||||
i1_shared, i2_shared = reconstructed_shared_regions(
|
||||
p1_shuf, p2_shuf, recompute_package.unsqueeze(0))
|
||||
i1_shared = pad_to(nn.PixelShuffle(self.multiple)
|
||||
(i1_shared).squeeze(0), d * m)
|
||||
i2_shared = pad_to(nn.PixelShuffle(self.multiple)
|
||||
(i2_shared).squeeze(0), d*m)
|
||||
|
||||
return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg, i1_shared, i2_shared
|
||||
|
||||
|
@ -280,7 +298,8 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
|
|||
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
|
||||
# of conforming the recompute_package across the entire batch.
|
||||
for b in range(package.shape[0]):
|
||||
expected_dim, f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(package[b].tolist())
|
||||
expected_dim, f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(
|
||||
package[b].tolist())
|
||||
# If you are hitting this assert, you specified `latent_multiple` in your dataset config wrong.
|
||||
assert expected_dim == fea1.shape[2] and expected_dim == fea2.shape[2]
|
||||
|
||||
|
@ -292,8 +311,10 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
|
|||
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="nearest")
|
||||
f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="nearest")
|
||||
# Outputs must be padded so they can "get along" with each other.
|
||||
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
|
||||
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
|
||||
res1.append(
|
||||
pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
|
||||
res2.append(
|
||||
pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
|
||||
return torch.cat(res1, dim=0), torch.cat(res2, dim=0)
|
||||
|
||||
|
||||
|
@ -308,10 +329,11 @@ class StructuredCropDatasetWrapper(Dataset):
|
|||
super().__init__()
|
||||
self.wrapped_dataset = create_dataset(opt['dataset'])
|
||||
augmentations = [RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
|
||||
self.aug = nn.Sequential(*augmentations)
|
||||
self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0))
|
||||
self.rrc = RandomSharedRegionCrop(
|
||||
opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0))
|
||||
|
||||
def __getitem__(self, item):
|
||||
item = self.wrapped_dataset[item]
|
||||
|
@ -332,17 +354,17 @@ def test_structured_crop_dataset_wrapper():
|
|||
opt = {
|
||||
'dataset':
|
||||
{
|
||||
'mode': 'imagefolder',
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
|
||||
'weights': [1],
|
||||
'target_size': 256,
|
||||
'force_multiple': 32,
|
||||
'scale': 1,
|
||||
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
||||
'random_corruptions': ['noise-5', 'none'],
|
||||
'num_corrupts_per_image': 1,
|
||||
'corrupt_before_downsize': True,
|
||||
'mode': 'imagefolder',
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
|
||||
'weights': [1],
|
||||
'target_size': 256,
|
||||
'force_multiple': 32,
|
||||
'scale': 1,
|
||||
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
|
||||
'random_corruptions': ['noise-5', 'none'],
|
||||
'num_corrupts_per_image': 1,
|
||||
'corrupt_before_downsize': True,
|
||||
},
|
||||
'latent_multiple': 16,
|
||||
'jitter_range': 0,
|
||||
|
@ -353,16 +375,17 @@ def test_structured_crop_dataset_wrapper():
|
|||
os.makedirs("debug", exist_ok=True)
|
||||
for i in tqdm(range(0, len(ds))):
|
||||
o = ds[random.randint(0, len(ds)-1)]
|
||||
#for k, v in o.items():
|
||||
# 'lq', 'hq', 'aug1', 'aug2',
|
||||
#if k in [ 'aug_shared_view', 'masked1', 'masked2']:
|
||||
#torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
# for k, v in o.items():
|
||||
# 'lq', 'hq', 'aug1', 'aug2',
|
||||
# if k in [ 'aug_shared_view', 'masked1', 'masked2']:
|
||||
# torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
rcpkg = o['similar_region_dimensions']
|
||||
pixun = PixelUnshuffle(16)
|
||||
pixsh = nn.PixelShuffle(16)
|
||||
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
|
||||
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
|
||||
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
|
||||
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(
|
||||
0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
|
||||
# torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
|
||||
# torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
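A self-contained sketch of the RandomApply gate defined near the top of the wrapper above: it wraps any callable augmentation and applies it with probability p. The gated call self.fn(x) is not visible in the hunk above and is filled in from the standard BYOL pattern; the dropout example is illustrative.

import random

import torch
import torch.nn as nn


class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x          # skipped with probability 1 - p
        return self.fn(x)     # assumed gated branch, as in the BYOL reference


aug = RandomApply(nn.Dropout(p=0.5), p=0.8)   # applied ~80% of the time
print(aug(torch.ones(2, 4)).shape)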
@ -1,17 +1,19 @@
|
|||
import os.path as osp
|
||||
from data import util
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from dlas.data import util
|
||||
# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
|
||||
from utils.util import opt_get
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
class ChunkWithReference:
|
||||
def __init__(self, opt, path):
|
||||
self.path = path.path
|
||||
self.tiles, _ = util.find_files_of_type('img', self.path)
|
||||
self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False)
|
||||
self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(
|
||||
opt, ['needs_metadata'], False)
|
||||
self.need_ref = opt_get(opt, ['need_ref'], False)
|
||||
if 'ignore_first' in opt.keys():
|
||||
self.tiles = self.tiles[opt['ignore_first']:]
|
||||
|
@ -41,12 +43,14 @@ class ChunkWithReference:
|
|||
else:
|
||||
center = torch.tensor([128, 128], dtype=torch.long)
|
||||
tile_width = 256
|
||||
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
|
||||
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
|
||||
mask = np.full(tile.shape[:2] + (1,),
|
||||
fill_value=.1, dtype=tile.dtype)
|
||||
mask[center[0] - tile_width // 2:center[0] + tile_width // 2,
|
||||
center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
|
||||
else:
|
||||
ref = np.zeros_like(tile)
|
||||
mask = np.zeros(tile.shape[:2] + (1,))
|
||||
center = (0,0)
|
||||
center = (0, 0)
|
||||
|
||||
return tile, ref, center, mask, self.tiles[item]
|
||||
|
||||
|
|
|
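A minimal sketch of the default reference mask built above: a 0.1-weight mask with a full-weight square of side tile_width centred at center; the toy tile size is an assumption.

import numpy as np


def build_center_mask(tile_hw=(512, 512), center=(128, 128), tile_width=256):
    mask = np.full(tile_hw + (1,), fill_value=.1, dtype=np.float32)
    half = tile_width // 2
    mask[center[0] - half:center[0] + half,
         center[1] - half:center[1] + half] = 1
    return mask


print(build_center_mask().shape)  # (512, 512, 1)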
@ -1,14 +1,15 @@
|
|||
# A copy of the cifar dataset from torch which also returns coarse labels.
|
||||
|
||||
from PIL import Image
|
||||
import os
|
||||
import os.path
|
||||
import numpy as np
|
||||
import pickle
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
|
||||
from torchvision.datasets.utils import (check_integrity,
|
||||
download_and_extract_archive)
|
||||
|
||||
|
||||
class CIFAR10(VisionDataset):
|
||||
|
@ -104,7 +105,8 @@ class CIFAR10(VisionDataset):
|
|||
with open(path, 'rb') as infile:
|
||||
data = pickle.load(infile, encoding='latin1')
|
||||
self.classes = data[self.meta['key']]
|
||||
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
|
||||
self.class_to_idx = {_class: i for i,
|
||||
_class in enumerate(self.classes)}
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
"""
|
||||
|
@ -147,7 +149,8 @@ class CIFAR10(VisionDataset):
|
|||
if self._check_integrity():
|
||||
print('Files already downloaded and verified')
|
||||
return
|
||||
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
|
||||
download_and_extract_archive(
|
||||
self.url, self.root, filename=self.filename, md5=self.tgz_md5)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "Split: {}".format("Train" if self.train is True else "Test")
|
||||
|
|
|
@ -1,12 +1,14 @@
import random
import numpy as np
from io import BytesIO

import cv2
import numpy as np
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image, ImageOps
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, ImageOps

import dlas.data.util as util


# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to
@ -16,6 +18,7 @@ class FullImageDataset(data.Dataset):
|
|||
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
|
||||
If only GT images are provided, generate LQ images on-the-fly.
|
||||
"""
|
||||
|
||||
def get_lq_path(self, i):
|
||||
which_lq = random.randint(0, len(self.paths_LQ)-1)
|
||||
return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
|
||||
|
@ -27,19 +30,23 @@ class FullImageDataset(data.Dataset):
|
|||
self.paths_LQ, self.paths_GT = None, None
|
||||
self.sizes_LQ, self.sizes_GT = None, None
|
||||
self.LQ_env, self.GT_env = None, None
|
||||
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
|
||||
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys(
|
||||
) else 1
|
||||
|
||||
self.paths_GT, self.sizes_GT = util.find_files_of_type(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
|
||||
self.paths_GT, self.sizes_GT = util.find_files_of_type(
|
||||
self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
|
||||
if 'dataroot_LQ' in opt.keys():
|
||||
self.paths_LQ = []
|
||||
if isinstance(opt['dataroot_LQ'], list):
|
||||
# Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
|
||||
# we want the model to learn them all.
|
||||
for dr_lq in opt['dataroot_LQ']:
|
||||
lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, dr_lq)
|
||||
lq_path, self.sizes_LQ = util.find_files_of_type(
|
||||
self.data_type, dr_lq)
|
||||
self.paths_LQ.append(lq_path)
|
||||
else:
|
||||
lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, opt['dataroot_LQ'])
|
||||
lq_path, self.sizes_LQ = util.find_files_of_type(
|
||||
self.data_type, opt['dataroot_LQ'])
|
||||
self.paths_LQ.append(lq_path)
|
||||
|
||||
assert self.paths_GT, 'Error: GT path is empty.'
|
||||
|
@ -48,7 +55,8 @@ class FullImageDataset(data.Dataset):
|
|||
def motion_blur(self, image, size, angle):
|
||||
k = np.zeros((size, size), dtype=np.float32)
|
||||
k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
|
||||
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
|
||||
k = cv2.warpAffine(k, cv2.getRotationMatrix2D(
|
||||
(size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
|
||||
k = k * (1.0 / np.sum(k))
|
||||
return cv2.filter2D(image, -1, k)
|
||||
|
||||
|
@ -100,8 +108,10 @@ class FullImageDataset(data.Dataset):
|
|||
if 'fixed_size' in self.opt.keys() and self.opt['fixed_size']:
|
||||
square_size = target_sz
|
||||
else:
|
||||
tile_expansion_dev = self.opt['tile_scale_normal_stddev'] if 'tile_scale_normal_stddev' in self.opt.keys() else .17
|
||||
square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=tile_expansion_dev)), 1.0))
|
||||
tile_expansion_dev = self.opt['tile_scale_normal_stddev'] if 'tile_scale_normal_stddev' in self.opt.keys(
|
||||
) else .17
|
||||
square_size = int(target_sz + possible_sizes_above_target *
|
||||
min(np.abs(np.random.normal(scale=tile_expansion_dev)), 1.0))
|
||||
|
||||
# Pick the left,top coords to draw the patch from
|
||||
left = self.pick_along_range(w, square_size, .3)
|
||||
|
@ -110,11 +120,15 @@ class FullImageDataset(data.Dataset):
|
|||
mask = np.zeros((h, w, 1), dtype=image.dtype)
|
||||
mask[top:top+square_size, left:left+square_size] = 1
|
||||
patch = image[top:top+square_size, left:left+square_size, :]
|
||||
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
|
||||
center = torch.tensor(
|
||||
[top + square_size // 2, left + square_size // 2], dtype=torch.long)
|
||||
|
||||
patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||
image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||
mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||
patch = cv2.resize(patch, (target_sz, target_sz),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
image = cv2.resize(image, (target_sz, target_sz),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
mask = cv2.resize(mask, (target_sz, target_sz),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
center = self.resize_point(center, (h, w), image.shape[:2])
|
||||
|
||||
return patch, image, mask, center
|
||||
|
@ -127,19 +141,24 @@ class FullImageDataset(data.Dataset):
|
|||
assert H >= GT_size and W >= GT_size
|
||||
|
||||
LQ_size = GT_size // scale
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if self.opt['use_blurring']:
|
||||
# Pick randomly between gaussian, motion, or no blur.
|
||||
blur_det = random.randint(0, 100)
|
||||
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
|
||||
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys(
|
||||
) else self.opt['blur_magnitude']
|
||||
blur_magnitude = max(1, int(blur_magnitude*strength))
|
||||
if blur_det < 40:
|
||||
blur_sig = int(random.randrange(0, int(blur_magnitude)))
|
||||
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
|
||||
img_LQ = cv2.GaussianBlur(
|
||||
img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
|
||||
elif blur_det < 70:
|
||||
img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))
|
||||
img_LQ = self.motion_blur(img_LQ, random.randrange(
|
||||
1, int(blur_magnitude) * 3), random.randint(0, 360))
|
||||
|
||||
return img_GT, img_LQ
|
||||
|
||||
|
@ -174,13 +193,15 @@ class FullImageDataset(data.Dataset):
|
|||
# Gaussian Blur (point or motion)
|
||||
blur_magnitude = 3
|
||||
blur_sig = int(random.randrange(0, int(blur_magnitude)))
|
||||
image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig)
|
||||
image = cv2.GaussianBlur(
|
||||
image, (blur_magnitude, blur_magnitude), blur_sig)
|
||||
elif 2 in aug_code:
|
||||
# Median Blur
|
||||
image = cv2.medianBlur(image, 3)
|
||||
elif 3 in aug_code:
|
||||
# Motion blur
|
||||
image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360))
|
||||
image = self.motion_blur(
|
||||
image, random.randrange(1, 9), random.randint(0, 360))
|
||||
elif 4 in aug_code:
|
||||
# Smooth blur
|
||||
image = cv2.blur(image, ksize=3)
|
||||
|
@ -217,15 +238,19 @@ class FullImageDataset(data.Dataset):
|
|||
full_path = self.paths_GT[index % len(self.paths_GT)]
|
||||
LQ_path = full_path
|
||||
img_full = util.read_img(None, full_path, None)
|
||||
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
|
||||
img_full = util.channel_convert(
|
||||
img_full.shape[2], 'RGB', [img_full])[0]
|
||||
if self.opt['phase'] == 'train':
|
||||
img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
||||
img_full = util.augment(
|
||||
[img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
||||
img_full = self.get_square_image(img_full)
|
||||
img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
|
||||
img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(
|
||||
img_full)
|
||||
else:
|
||||
img_GT, gt_fullsize_ref = img_full, img_full
|
||||
gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype)
|
||||
gt_center = torch.tensor([img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long)
|
||||
gt_center = torch.tensor(
|
||||
[img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long)
|
||||
orig_gt_dim = gt_fullsize_ref.shape[:2]
|
||||
|
||||
# get LQ image
|
||||
|
@ -233,13 +258,17 @@ class FullImageDataset(data.Dataset):
|
|||
LQ_path = self.get_lq_path(index)
|
||||
img_lq_full = util.read_img(None, LQ_path, None)
|
||||
if self.opt['phase'] == 'train':
|
||||
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
||||
img_lq_full = util.augment(
|
||||
[img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
||||
img_lq_full = self.get_square_image(img_lq_full)
|
||||
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
|
||||
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(
|
||||
img_lq_full, lq=True)
|
||||
else:
|
||||
img_LQ, lq_fullsize_ref = img_lq_full, img_lq_full
|
||||
lq_mask = np.ones(img_lq_full.shape[:2], dtype=lq_fullsize_ref.dtype)
|
||||
lq_center = torch.tensor([img_lq_full.shape[0] // 2, img_lq_full.shape[1] // 2], dtype=torch.long)
|
||||
lq_mask = np.ones(
|
||||
img_lq_full.shape[:2], dtype=lq_fullsize_ref.dtype)
|
||||
lq_center = torch.tensor(
|
||||
[img_lq_full.shape[0] // 2, img_lq_full.shape[1] // 2], dtype=torch.long)
|
||||
else: # down-sampling on-the-fly
|
||||
# randomly scale during training
|
||||
if self.opt['phase'] == 'train':
|
||||
|
@ -258,7 +287,8 @@ class FullImageDataset(data.Dataset):
|
|||
|
||||
H_s = _mod(H_s, random_scale, scale, GT_size)
|
||||
W_s = _mod(W_s, random_scale, scale, GT_size)
|
||||
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
|
||||
img_GT = cv2.resize(img_GT, (W_s, H_s),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
if img_GT.ndim == 2:
|
||||
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
|
@ -266,10 +296,12 @@ class FullImageDataset(data.Dataset):
|
|||
|
||||
# using matlab imresize
|
||||
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
|
||||
lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
|
||||
lq_fullsize_ref = util.imresize_np(
|
||||
gt_fullsize_ref, 1 / scale, True)
|
||||
if img_LQ.ndim == 2:
|
||||
img_LQ = np.expand_dims(img_LQ, axis=2)
|
||||
lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
|
||||
lq_mask, lq_center = gt_mask, self.resize_point(
|
||||
gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
|
||||
orig_lq_dim = lq_fullsize_ref.shape[:2]
|
||||
|
||||
# Enforce force_resize constraints via clipping.
|
||||
|
@ -285,15 +317,20 @@ class FullImageDataset(data.Dataset):
|
|||
|
||||
if self.opt['phase'] == 'train':
|
||||
img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
|
||||
gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
|
||||
gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(
|
||||
gt_fullsize_ref, lq_fullsize_ref, strength=.2)
|
||||
|
||||
# Scale masks.
|
||||
lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
|
||||
gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
|
||||
lq_mask = cv2.resize(
|
||||
lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
|
||||
gt_mask = cv2.resize(
|
||||
gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# Scale center coords
|
||||
lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
|
||||
gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])
|
||||
lq_center = self.resize_point(
|
||||
lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
|
||||
gt_center = self.resize_point(
|
||||
gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if img_GT.shape[2] == 3:
|
||||
|
@ -303,16 +340,20 @@ class FullImageDataset(data.Dataset):
|
|||
gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
|
||||
#if self.opt['phase'] == 'train':
|
||||
#img_LQ = self.pil_augment(img_LQ)
|
||||
#lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
|
||||
# if self.opt['phase'] == 'train':
|
||||
# img_LQ = self.pil_augment(img_LQ)
|
||||
# lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
|
||||
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
|
||||
img_LQ = F.to_tensor(img_LQ)
|
||||
lq_fullsize_ref = F.to_tensor(lq_fullsize_ref)
|
||||
lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
|
||||
gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)
|
||||
lq_mask = torch.from_numpy(
|
||||
np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
|
||||
gt_mask = torch.from_numpy(
|
||||
np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)
|
||||
|
||||
if 'lq_noise' in self.opt.keys():
|
||||
lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
|
||||
|
@ -331,6 +372,7 @@ class FullImageDataset(data.Dataset):
|
|||
def __len__(self):
|
||||
return len(self.paths_GT)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
'''
|
||||
opt = {
|
||||
|
@ -365,10 +407,10 @@ if __name__ == '__main__':
|
|||
o = ds[i]
|
||||
for k, v in o.items():
|
||||
if 'path' not in k:
|
||||
#if 'full' in k:
|
||||
#masked = v[:3, :, :] * v[3]
|
||||
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
|
||||
#v = v[:3, :, :]
|
||||
#import torchvision
|
||||
#torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
# if 'full' in k:
|
||||
# masked = v[:3, :, :] * v[3]
|
||||
# torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
|
||||
# v = v[:3, :, :]
|
||||
# import torchvision
|
||||
# torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
pass
|
||||
|
|
|
@ -1,20 +1,18 @@
import functools
import random
from io import BytesIO
from math import cos, pi

import cv2
import kornia
import numpy as np
import torch
from kornia.augmentation import ColorJitter

from data.util import read_img
from kornia.augmentation import ColorJitter
from PIL import Image
from io import BytesIO


# Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
from utils.util import opt_get
from dlas.utils.util import opt_get

'''
if __name__ == '__main__':
@ -29,9 +27,9 @@ if __name__ == '__main__':
|
|||
def kornia_color_jitter_numpy(img, setting):
|
||||
if setting * 255 > 1:
|
||||
# I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
|
||||
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
|
||||
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
|
||||
img = ColorJitter(setting, setting, setting, setting)(img)
|
||||
img = img.squeeze(0).permute(1,2,0).numpy()
|
||||
img = img.squeeze(0).permute(1, 2, 0).numpy()
|
||||
return img
|
||||
|
||||
|
||||
|
@ -41,14 +39,18 @@ class ImageCorruptor:
|
|||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
self.reset_random()
|
||||
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
|
||||
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
|
||||
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
|
||||
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys(
|
||||
) else 1
|
||||
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else [
|
||||
]
|
||||
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys(
|
||||
) else 0
|
||||
self.cosine_bias = opt_get(opt, ['cosine_bias'], True)
|
||||
if self.num_corrupts == 0:
|
||||
return
|
||||
else:
|
||||
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
|
||||
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else [
|
||||
]
|
||||
|
||||
def reset_random(self):
|
||||
if 'random_seed' in self.opt.keys():
|
||||
|
@ -75,7 +77,8 @@ class ImageCorruptor:
|
|||
if self.num_corrupts == 0:
|
||||
augmentations = []
|
||||
else:
|
||||
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
|
||||
augmentations = random.choices(
|
||||
self.random_corruptions, k=self.num_corrupts)
|
||||
|
||||
# Sources of entropy
|
||||
corrupted_imgs = []
|
||||
|
@ -99,7 +102,6 @@ class ImageCorruptor:
|
|||
img = ufn(img)
|
||||
corrupted_imgs.append(img)
|
||||
|
||||
|
||||
if return_entropy:
|
||||
return corrupted_imgs, entropy
|
||||
else:
|
||||
|
@ -119,11 +121,11 @@ class ImageCorruptor:
|
|||
setting = rand_val * (hi_end - lo_end) + lo_end
|
||||
img = kornia_color_jitter_numpy(img, setting)
|
||||
elif 'gaussian_blur' in aug:
|
||||
img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
|
||||
img = cv2.GaussianBlur(img, (0, 0), self.blur_scale*rand_val*1.5)
|
||||
elif 'motion_blur' in aug:
|
||||
# Motion blur
|
||||
intensity = self.blur_scale*rand_val * 3 + 1
|
||||
angle = random.randint(0,360)
|
||||
angle = random.randint(0, 360)
|
||||
k = np.zeros((intensity, intensity), dtype=np.float32)
|
||||
k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
|
||||
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0),
|
||||
|
@ -145,10 +147,13 @@ class ImageCorruptor:
|
|||
else:
|
||||
scale = 4
|
||||
if scale > 1:
|
||||
interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
|
||||
mode = random.randint(0,4) % len(interpolation_modes)
|
||||
interpolation_modes = [
|
||||
cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
|
||||
mode = random.randint(0, 4) % len(interpolation_modes)
|
||||
# Downsample first, then upsample using the random mode.
|
||||
img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=mode)
|
||||
img = cv2.resize(img, dsize=(
|
||||
img.shape[1]//scale, img.shape[0]//scale), interpolation=mode)
|
||||
|
||||
def lq_resampling_undo_fn(scale, img):
|
||||
return cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=cv2.INTER_LINEAR)
|
||||
undo_fn = functools.partial(lq_resampling_undo_fn, scale)
|
||||
|
@ -171,22 +176,23 @@ class ImageCorruptor:
|
|||
elif 'jpeg' in aug:
|
||||
if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
|
||||
if aug == 'jpeg':
|
||||
lo=10
|
||||
range=20
|
||||
lo = 10
|
||||
range = 20
|
||||
elif aug == 'jpeg-low':
|
||||
lo=15
|
||||
range=10
|
||||
lo = 15
|
||||
range = 10
|
||||
elif aug == 'jpeg-medium':
|
||||
lo=23
|
||||
range=25
|
||||
lo = 23
|
||||
range = 25
|
||||
elif aug == 'jpeg-broad':
|
||||
lo=15
|
||||
range=60
|
||||
lo = 15
|
||||
range = 60
|
||||
elif aug == 'jpeg-normal':
|
||||
lo=47
|
||||
range=35
|
||||
lo = 47
|
||||
range = 35
|
||||
else:
|
||||
raise NotImplementedError("specified jpeg corruption doesn't exist")
|
||||
raise NotImplementedError(
|
||||
"specified jpeg corruption doesn't exist")
|
||||
# JPEG compression
|
||||
qf = (int((1-rand_val)*range) + lo)
|
||||
# Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
|
||||
|
@ -195,14 +201,15 @@ class ImageCorruptor:
|
|||
buffer = BytesIO()
|
||||
img.save(buffer, "JPEG", quality=qf, optimize=True)
|
||||
buffer.seek(0)
|
||||
jpeg_img_bytes = np.asarray(bytearray(buffer.read()), dtype="uint8")
|
||||
jpeg_img_bytes = np.asarray(
|
||||
bytearray(buffer.read()), dtype="uint8")
|
||||
img = read_img("buffer", jpeg_img_bytes, rgb=True)
|
||||
elif 'saturation' in aug:
|
||||
# Lightening / saturation
|
||||
saturation = rand_val * .3
|
||||
img = np.clip(img + saturation, a_max=1, a_min=0)
|
||||
elif 'greyscale' in aug:
|
||||
img = np.tile(np.mean(img, axis=2, keepdims=True), [1,1,3])
|
||||
img = np.tile(np.mean(img, axis=2, keepdims=True), [1, 1, 3])
|
||||
elif 'none' not in aug:
|
||||
raise NotImplementedError("Augmentation doesn't exist")
|
||||
|
||||
|
|
|
@ -1,21 +1,21 @@
|
|||
import functools
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
|
||||
import torchvision
|
||||
from data import util
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import Normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
from data import util
|
||||
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
|
||||
from data.images.image_corruptor import ImageCorruptor, kornia_color_jitter_numpy
|
||||
from data.images.image_label_parser import VsNetImageLabeler
|
||||
from utils.util import opt_get
|
||||
from dlas.data.images.image_corruptor import (ImageCorruptor,
|
||||
kornia_color_jitter_numpy)
|
||||
from dlas.data.images.image_label_parser import VsNetImageLabeler
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
def ndarray_center_crop(crop, img):
|
||||
|
|
|
@ -50,15 +50,15 @@ class VsNetImageLabeler:
|
|||
|
||||
def get_labels_as_tensor(self, hq, img_key, resize_factor):
|
||||
_, h, w = hq.shape
|
||||
labels = torch.zeros((1,h,w), dtype=torch.long)
|
||||
mask = torch.zeros((1,h,w), dtype=torch.float)
|
||||
labels = torch.zeros((1, h, w), dtype=torch.long)
|
||||
mask = torch.zeros((1, h, w), dtype=torch.float)
|
||||
lbl_list = self.labeled_images[img_key]
|
||||
for patch_lbl in lbl_list:
|
||||
t, l, h, w = patch_lbl['patch_top'] // resize_factor, patch_lbl['patch_left'] // resize_factor, \
|
||||
patch_lbl['patch_height'] // resize_factor, patch_lbl['patch_width'] // resize_factor
|
||||
patch_lbl['patch_height'] // resize_factor, patch_lbl['patch_width'] // resize_factor
|
||||
val = patch_lbl['labelValue']
|
||||
labels[:,t:t+h,l:l+w] = val
|
||||
mask[:,t:t+h,l:l+w] = 1.0
|
||||
labels[:, t:t+h, l:l+w] = val
|
||||
mask[:, t:t+h, l:l+w] = 1.0
|
||||
return labels, mask, self.str_labels
|
||||
|
||||
def add_label(self, binding, img_name, top, left, dim):
|
||||
|
@ -100,22 +100,23 @@ class CompactJsonLabeler:
|
|||
assert self.config == parsed['config']
|
||||
assert self.labels == parsed['labels']
|
||||
assert self.label_map == parsed['label_map']
|
||||
self.images.update(parsed['images']) # This will overwrite existing images, which is acceptable.
|
||||
# This will overwrite existing images, which is acceptable.
|
||||
self.images.update(parsed['images'])
|
||||
|
||||
def get_labeled_paths(self, base_path):
|
||||
return [os.path.join(base_path, pth) for pth in self.images.keys()]
|
||||
|
||||
def get_labels_as_tensor(self, hq, img_key, resize_factor):
|
||||
_, h, w = hq.shape
|
||||
labels = torch.zeros((1,h,w), dtype=torch.long)
|
||||
mask = torch.zeros((1,h,w), dtype=torch.float)
|
||||
labels = torch.zeros((1, h, w), dtype=torch.long)
|
||||
mask = torch.zeros((1, h, w), dtype=torch.float)
|
||||
lbl_list = self.images[img_key]
|
||||
for patch_lbl in lbl_list:
|
||||
t, l, h, w = patch_lbl['top'] // resize_factor, patch_lbl['left'] // resize_factor, \
|
||||
self.config['dim'] // resize_factor, self.config['dim'] // resize_factor
|
||||
self.config['dim'] // resize_factor, self.config['dim'] // resize_factor
|
||||
val = patch_lbl['labelValue']
|
||||
labels[:,t:t+h,l:l+w] = val
|
||||
mask[:,t:t+h,l:l+w] = 1.0
|
||||
labels[:, t:t+h, l:l+w] = val
|
||||
mask[:, t:t+h, l:l+w] = 1.0
|
||||
return labels, mask, self.str_labels
|
||||
|
||||
def add_label(self, binding, img_name, top, left, dim):
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
import torch
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
|
||||
|
||||
|
||||
|
@ -15,33 +14,42 @@ class ImagePairWithCorrespondingPointsDataset(Dataset):
|
|||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
self.path = opt['path']
|
||||
self.pairs = list(filter(lambda f: not os.path.isdir(f), os.listdir(self.path)))
|
||||
self.pairs = list(
|
||||
filter(lambda f: not os.path.isdir(f), os.listdir(self.path)))
|
||||
self.transforms = transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
transforms.Normalize(
|
||||
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
self.size = opt['size']
|
||||
|
||||
|
||||
def __getitem__(self, item):
|
||||
dir = self.pairs[item]
|
||||
img1 = self.transforms(Image.open(os.path.join(self.path, dir, "1.jpg")))
|
||||
img2 = self.transforms(Image.open(os.path.join(self.path, dir, "2.jpg")))
|
||||
coords1, coords2 = torch.load(os.path.join(self.path, dir, "coords.pth"))
|
||||
img1 = self.transforms(Image.open(
|
||||
os.path.join(self.path, dir, "1.jpg")))
|
||||
img2 = self.transforms(Image.open(
|
||||
os.path.join(self.path, dir, "2.jpg")))
|
||||
coords1, coords2 = torch.load(
|
||||
os.path.join(self.path, dir, "coords.pth"))
|
||||
assert img1.shape[-2] == img1.shape[-1]
|
||||
assert img2.shape[-2] == img2.shape[-1]
|
||||
if img1.shape[-1] != self.size:
|
||||
scale = img1.shape[-1] / self.size
|
||||
assert(int(scale) == scale) # We will only downsample to even resolutions.
|
||||
# We will only downsample to even resolutions.
|
||||
assert (int(scale) == scale)
|
||||
scale = 1 / scale
|
||||
img1 = torch.nn.functional.interpolate(img1.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
|
||||
img1 = torch.nn.functional.interpolate(img1.unsqueeze(
|
||||
0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
|
||||
coords1 = [int(c * scale) for c in coords1]
|
||||
if img2.shape[-1] != self.size:
|
||||
scale = img2.shape[-1] / self.size
|
||||
assert(int(scale) == scale) # We will only downsample to even resolutions.
|
||||
# We will only downsample to even resolutions.
|
||||
assert (int(scale) == scale)
|
||||
scale = 1 / scale
|
||||
img2 = torch.nn.functional.interpolate(img2.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
|
||||
img2 = torch.nn.functional.interpolate(img2.unsqueeze(
|
||||
0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
|
||||
coords2 = [int(c * scale) for c in coords2]
|
||||
coords1 = (coords1[1], coords1[0]) # The UI puts these out backwards (x,y). Flip them.
|
||||
# The UI puts these out backwards (x,y). Flip them.
|
||||
coords1 = (coords1[1], coords1[0])
|
||||
coords2 = (coords2[1], coords2[0])
|
||||
return {
|
||||
'img1': img1,
|
||||
|
@ -53,6 +61,7 @@ class ImagePairWithCorrespondingPointsDataset(Dataset):
|
|||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'path': 'F:\\dlas\\codes\\scripts\\ui\\image_pair_labeler\\results',
|
||||
|
@ -60,13 +69,14 @@ if __name__ == '__main__':
|
|||
}
|
||||
output_path = '..'
|
||||
|
||||
ds = DataLoader(ImagePairWithCorrespondingPointsDataset(opt), shuffle=True, num_workers=0)
|
||||
ds = DataLoader(ImagePairWithCorrespondingPointsDataset(
|
||||
opt), shuffle=True, num_workers=0)
|
||||
for i, d in tqdm(enumerate(ds)):
|
||||
i1 = d['img1']
|
||||
i2 = d['img2']
|
||||
c1 = d['coords1']
|
||||
c2 = d['coords2']
|
||||
i1[:,:,c1[0]-3:c1[0]+3,c1[1]-3:c1[1]+3] = 0
|
||||
i2[:,:,c2[0]-3:c2[0]+3,c2[1]-3:c2[1]+3] = 0
|
||||
i1[:, :, c1[0]-3:c1[0]+3, c1[1]-3:c1[1]+3] = 0
|
||||
i2[:, :, c2[0]-3:c2[0]+3, c2[1]-3:c2[1]+3] = 0
|
||||
torchvision.utils.save_image(i1, f'{output_path}\\{i}_1.png')
|
||||
torchvision.utils.save_image(i2, f'{output_path}\\{i}_2.png')
|
||||
torchvision.utils.save_image(i2, f'{output_path}\\{i}_2.png')
|
||||
|
|
|
@ -1,8 +1,12 @@
from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset
import os.path as osp
from bisect import bisect_left

import numpy as np
import torch
from bisect import bisect_left
import os.path as osp

from dlas.data.images.base_unsupervised_image_dataset import \
BaseUnsupervisedImageDataset


class MultiFrameDataset(BaseUnsupervisedImageDataset):
def __init__(self, opt):
@ -30,7 +34,8 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
|
|||
for i in range(self.num_frames):
|
||||
idx = search_idx + i
|
||||
if idx < 0 or idx >= len(self.chunks) or chunk_offset < 0 or chunk_offset >= len(self.chunks[idx]):
|
||||
print("Chunk reference indexing failed for %s." % (im_name,), search_idx, i, chunk_offset, self.num_frames)
|
||||
print("Chunk reference indexing failed for %s." %
|
||||
(im_name,), search_idx, i, chunk_offset, self.num_frames)
|
||||
h, r, c, m, p = self.chunks[search_idx + i][chunk_offset]
|
||||
hqs.append(h)
|
||||
refs.append(r)
|
||||
|
@ -41,24 +46,32 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
|
|||
|
||||
def __getitem__(self, item):
|
||||
chunk_ind = bisect_left(self.starting_indices, item)
|
||||
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
||||
hqs, refs, masks, centers, path = self.get_sequential_image_paths_from(chunk_ind, item-self.starting_indices[chunk_ind])
|
||||
chunk_ind = chunk_ind if chunk_ind < len(
|
||||
self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
||||
hqs, refs, masks, centers, path = self.get_sequential_image_paths_from(
|
||||
chunk_ind, item-self.starting_indices[chunk_ind])
|
||||
|
||||
hs, hrs, hms, hcs = self.resize_hq(hqs, refs, masks, centers)
|
||||
ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs)
|
||||
|
||||
# Convert to torch tensor
|
||||
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
|
||||
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
|
||||
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1)
|
||||
hq = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
|
||||
hq_ref = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
|
||||
hq_mask = torch.from_numpy(
|
||||
np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1)
|
||||
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
||||
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
||||
lq_mask = torch.from_numpy(
|
||||
np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
|
||||
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
|
||||
|
||||
return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
|
||||
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
|
||||
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -84,7 +97,7 @@ if __name__ == '__main__':
|
|||
for i in range(len(ds)):
|
||||
import random
|
||||
k = 'lq'
|
||||
element = ds[random.randint(0,len(ds))]
|
||||
element = ds[random.randint(0, len(ds))]
|
||||
base_file = osp.basename(element["GT_path"])
|
||||
o = element[k].unsqueeze(0)
|
||||
if bs < 32:
|
||||
|
@ -99,8 +112,9 @@ if __name__ == '__main__':
|
|||
b, fr, f, h, w = batch.shape
|
||||
for j in range(fr):
|
||||
import torchvision
|
||||
base=osp.basename(base_file)
|
||||
torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
|
||||
base = osp.basename(base_file)
|
||||
torchvision.utils.save_image(
|
||||
batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
|
||||
|
||||
bs = 0
|
||||
batch = None
|
||||
|
|
|
@ -1,12 +1,13 @@
import random
import numpy as np

import cv2
import data.util as util
import numpy as np
import torch
import torch.utils.data as data
import data.util as util

# Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes.
from data.images.image_corruptor import ImageCorruptor
from dlas.data.images.image_corruptor import ImageCorruptor


# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
@ -27,6 +28,7 @@ def get_square_image(image):
|
|||
left = max(int(center + offset * (center - 2)), 0)
|
||||
return image[:, left:left + h, :]
|
||||
|
||||
|
||||
class MultiScaleDataset(data.Dataset):
|
||||
def __init__(self, opt):
|
||||
super(MultiScaleDataset, self).__init__()
|
||||
|
@ -36,10 +38,10 @@ class MultiScaleDataset(data.Dataset):
|
|||
self.num_scales = self.opt['num_scales']
|
||||
self.hq_size_cap = self.tile_size * 2 ** self.num_scales
|
||||
self.scale = self.opt['scale']
|
||||
self.paths_hq, self.sizes_hq = util.find_files_of_type(self.data_type, opt['paths'], [1 for _ in opt['paths']])
|
||||
self.paths_hq, self.sizes_hq = util.find_files_of_type(
|
||||
self.data_type, opt['paths'], [1 for _ in opt['paths']])
|
||||
self.corruptor = ImageCorruptor(opt)
|
||||
|
||||
|
||||
def recursively_extract_patches(self, input_img, result_list, depth):
|
||||
if depth >= self.num_scales:
|
||||
return
|
||||
|
@ -49,7 +51,8 @@ class MultiScaleDataset(data.Dataset):
|
|||
input_img[:patch_size, patch_size:],
|
||||
input_img[patch_size:, :patch_size],
|
||||
input_img[patch_size:, patch_size:]]
|
||||
result_list.extend([cv2.resize(p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA) for p in patches])
|
||||
result_list.extend([cv2.resize(
|
||||
p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA) for p in patches])
|
||||
for p in patches:
|
||||
self.recursively_extract_patches(p, result_list, depth+1)
|
||||
|
||||
|
@ -57,30 +60,41 @@ class MultiScaleDataset(data.Dataset):
|
|||
# get full size image
|
||||
full_path = self.paths_hq[index % len(self.paths_hq)]
|
||||
loaded_img = util.read_img(None, full_path, None)
|
||||
img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0]
|
||||
img_full1 = util.channel_convert(
|
||||
loaded_img.shape[2], 'RGB', [loaded_img])[0]
|
||||
img_full2 = util.augment([img_full1], True, True)[0]
|
||||
img_full3 = get_square_image(img_full2)
|
||||
# This error crops up from time to time. I suspect an issue with util.read_img.
|
||||
if img_full3.shape[0] == 0 or img_full3.shape[1] == 0:
|
||||
print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))
|
||||
print("Error with image: %s. Loaded image shape: %s" % (full_path, str(
|
||||
loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))
|
||||
# Attempt to recover by just using a fixed array of zeros, which the downstream networks should be fine training against, within reason.
|
||||
img_full3 = np.zeros((1024,1024,3), dtype=np.int)
|
||||
img_full = cv2.resize(img_full3, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
|
||||
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||
img_full3 = np.zeros((1024, 1024, 3), dtype=np.int)
|
||||
img_full = cv2.resize(
|
||||
img_full3, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
|
||||
patches_hq = [cv2.resize(
|
||||
img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||
self.recursively_extract_patches(img_full, patches_hq, 1)
|
||||
# Image corruption is applied against the full size image for this dataset.
|
||||
img_corrupted = self.corruptor.corrupt_images([img_full])[0]
|
||||
patches_hq_corrupted = [cv2.resize(img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||
self.recursively_extract_patches(img_corrupted, patches_hq_corrupted, 1)
|
||||
patches_hq_corrupted = [cv2.resize(
|
||||
img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
|
||||
self.recursively_extract_patches(
|
||||
img_corrupted, patches_hq_corrupted, 1)
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
if patches_hq[0].shape[2] == 3:
|
||||
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq]
|
||||
patches_hq_corrupted = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted]
|
||||
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
|
||||
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB)
|
||||
for p in patches_hq]
|
||||
patches_hq_corrupted = [cv2.cvtColor(
|
||||
p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted]
|
||||
patches_hq = [torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
|
||||
patches_hq = torch.stack(patches_hq, dim=0)
|
||||
patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted]
|
||||
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
|
||||
patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted]
|
||||
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(
|
||||
0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
|
||||
patches_lq = torch.stack(patches_lq, dim=0)
|
||||
|
||||
d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
|
||||
|
@ -89,6 +103,7 @@ class MultiScaleDataset(data.Dataset):
|
|||
def __len__(self):
|
||||
return len(self.paths_hq)
|
||||
|
||||
|
||||
class MultiscaleTreeNode:
|
||||
def __init__(self, index, parent, i):
|
||||
self.index = index
|
||||
|
@ -117,7 +132,8 @@ def build_multiscale_patch_index_map(depth):
|
|||
|
||||
|
||||
def _build_multiscale_patch_index_map(depth, ind, node, leaves):
|
||||
subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i)) for i in range(4)]
|
||||
subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i))
|
||||
for i in range(4)]
|
||||
ind += 4
|
||||
if depth == 1:
|
||||
leaves.extend(subnodes)
|
||||
|
@ -146,7 +162,7 @@ if __name__ == '__main__':
|
|||
os.makedirs("debug", exist_ok=True)
|
||||
multiscale_tree = build_multiscale_patch_index_map(4)
|
||||
for i in range(500, len(ds)):
|
||||
quadrant=2
|
||||
quadrant = 2
|
||||
print(i)
|
||||
o = ds[random.randint(0, len(ds))]
|
||||
tree_ind = random.randint(0, len(multiscale_tree))
|
||||
|
@ -155,9 +171,10 @@ if __name__ == '__main__':
|
|||
continue
|
||||
depth = 0
|
||||
node = multiscale_tree[tree_ind]
|
||||
#for j, img in enumerate(v):
|
||||
# for j, img in enumerate(v):
|
||||
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
|
||||
while node is not None:
|
||||
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth))
|
||||
torchvision.utils.save_image(v[node.index].unsqueeze(
|
||||
0), "debug/%i_%s_%i.png" % (i, k, depth))
|
||||
depth += 1
|
||||
node = node.parent
|
||||
node = node.parent
|
||||
|
|
|
@ -1,8 +1,12 @@
from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset
import os.path as osp
from bisect import bisect_left

import numpy as np
import torch
from bisect import bisect_left
import os.path as osp

from dlas.data.images.base_unsupervised_image_dataset import \
BaseUnsupervisedImageDataset


class PairedFrameDataset(BaseUnsupervisedImageDataset):
def __init__(self, opt):
@ -26,24 +30,32 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
|
|||
|
||||
def __getitem__(self, item):
|
||||
chunk_ind = bisect_left(self.starting_indices, item)
|
||||
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
||||
hqs, refs, masks, centers, path = self.get_pair(chunk_ind, item-self.starting_indices[chunk_ind])
|
||||
chunk_ind = chunk_ind if chunk_ind < len(
|
||||
self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
||||
hqs, refs, masks, centers, path = self.get_pair(
|
||||
chunk_ind, item-self.starting_indices[chunk_ind])
|
||||
|
||||
hs, hrs, hms, hcs = self.resize_hq(hqs, refs, masks, centers)
|
||||
ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs)
|
||||
|
||||
# Convert to torch tensor
|
||||
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
|
||||
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
|
||||
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).squeeze().unsqueeze(dim=1)
|
||||
hq = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
|
||||
hq_ref = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
|
||||
hq_mask = torch.from_numpy(np.ascontiguousarray(
|
||||
np.stack(hms))).squeeze().unsqueeze(dim=1)
|
||||
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
||||
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
||||
lq_mask = torch.from_numpy(np.ascontiguousarray(
|
||||
np.stack(lms))).squeeze().unsqueeze(dim=1)
|
||||
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
|
||||
|
||||
return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
|
||||
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
|
||||
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -51,7 +63,7 @@ if __name__ == '__main__':
|
|||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\validation'],
|
||||
'weights': [1],
|
||||
#'target_size': 128,
|
||||
# 'target_size': 128,
|
||||
'force_multiple': 32,
|
||||
'scale': 2,
|
||||
'eval': False,
|
||||
|
@ -69,7 +81,7 @@ if __name__ == '__main__':
|
|||
for i in range(len(ds)):
|
||||
import random
|
||||
k = 'lq'
|
||||
element = ds[random.randint(0,len(ds))]
|
||||
element = ds[random.randint(0, len(ds))]
|
||||
base_file = osp.basename(element["GT_path"])
|
||||
o = element[k].unsqueeze(0)
|
||||
if bs < 2:
|
||||
|
@ -84,8 +96,9 @@ if __name__ == '__main__':
|
|||
b, fr, f, h, w = batch.shape
|
||||
for j in range(fr):
|
||||
import torchvision
|
||||
base=osp.basename(base_file)
|
||||
torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
|
||||
base = osp.basename(base_file)
|
||||
torchvision.utils.save_image(
|
||||
batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
|
||||
|
||||
bs = 0
|
||||
batch = None
|
||||
|
|
|
@ -1,8 +1,11 @@
import random
from bisect import bisect_left

import numpy as np
import torch
from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset

from dlas.data.images.base_unsupervised_image_dataset import \
BaseUnsupervisedImageDataset


# Builds a dataset composed of a set of folders. Each folder represents a single high resolution image that has been
@ -14,30 +17,40 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
|
|||
def get_paths(self):
|
||||
for i in range(len(self)):
|
||||
chunk_ind = bisect_left(self.starting_indices, i)
|
||||
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1
|
||||
chunk_ind = chunk_ind if chunk_ind < len(
|
||||
self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1
|
||||
yield self.chunks[chunk_ind].tiles[i-self.starting_indices[chunk_ind]]
|
||||
|
||||
def __getitem__(self, item):
|
||||
chunk_ind = bisect_left(self.starting_indices, item)
|
||||
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
||||
hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item-self.starting_indices[chunk_ind]]
|
||||
chunk_ind = chunk_ind if chunk_ind < len(
|
||||
self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
||||
hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item -
|
||||
self.starting_indices[chunk_ind]]
|
||||
|
||||
hs, hrs, hms, hcs = self.resize_hq([hq], [hq_ref], [hq_mask], [hq_center])
|
||||
hs, hrs, hms, hcs = self.resize_hq(
|
||||
[hq], [hq_ref], [hq_mask], [hq_center])
|
||||
ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs)
|
||||
|
||||
# Convert to torch tensor
|
||||
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
|
||||
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(hrs[0], (2, 0, 1)))).float()
|
||||
hq_mask = torch.from_numpy(np.ascontiguousarray(hms[0])).unsqueeze(dim=0)
|
||||
hq = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(hs[0], (2, 0, 1)))).float()
|
||||
hq_ref = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(hrs[0], (2, 0, 1)))).float()
|
||||
hq_mask = torch.from_numpy(
|
||||
np.ascontiguousarray(hms[0])).unsqueeze(dim=0)
|
||||
hq_ref = torch.cat([hq_ref, hq_mask], dim=0)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(lrs[0], (2, 0, 1)))).float()
|
||||
lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(ls[0], (2, 0, 1)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(lrs[0], (2, 0, 1)))).float()
|
||||
lq_mask = torch.from_numpy(
|
||||
np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
|
||||
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
|
||||
|
||||
return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
|
||||
'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
|
||||
'LQ_path': path, 'GT_path': path}
|
||||
'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
|
||||
'LQ_path': path, 'GT_path': path}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -60,13 +73,14 @@ if __name__ == '__main__':
|
|||
os.makedirs("debug", exist_ok=True)
|
||||
for i in range(0, len(ds)):
|
||||
o = ds[random.randint(0, len(ds))]
|
||||
#for k, v in o.items():
|
||||
# for k, v in o.items():
|
||||
k = 'lq'
|
||||
v = o[k]
|
||||
#if 'LQ' in k and 'path' not in k and 'center' not in k:
|
||||
#if 'full' in k:
|
||||
#masked = v[:3, :, :] * v[3]
|
||||
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
|
||||
#v = v[:3, :, :]
|
||||
# if 'LQ' in k and 'path' not in k and 'center' not in k:
|
||||
# if 'full' in k:
|
||||
# masked = v[:3, :, :] * v[3]
|
||||
# torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
|
||||
# v = v[:3, :, :]
|
||||
import torchvision
|
||||
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
torchvision.utils.save_image(
|
||||
v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
|
|
|
@ -1,15 +1,15 @@
from functools import partial
from pathlib import Path
from random import random

import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils import data
from torchvision import transforms
import torch.nn as nn
from pathlib import Path

import models.image_generation.stylegan.stylegan2_lucidrains as sg2
import dlas.models.image_generation.stylegan.stylegan2_lucidrains as sg2


def convert_transparent_to_rgb(image):
@ -29,6 +29,7 @@ def resize_to_minimum_size(min_size, image):
|
|||
return torchvision.transforms.functional.resize(image, min_size)
|
||||
return image
|
||||
|
||||
|
||||
class RandomApply(nn.Module):
|
||||
def __init__(self, prob, fn, fn_else=lambda x: x):
|
||||
super().__init__()
|
||||
|
@ -59,7 +60,8 @@ class expand_greyscale(object):
|
|||
color = tensor[:1].expand(3, -1, -1)
|
||||
alpha = tensor[1:]
|
||||
else:
|
||||
raise Exception(f'image with invalid number of channels given {channels}')
|
||||
raise Exception(
|
||||
f'image with invalid number of channels given {channels}')
|
||||
|
||||
if not sg2.exists(alpha) and self.transparent:
|
||||
alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
|
||||
|
@ -73,17 +75,21 @@ class Stylegan2Dataset(data.Dataset):
|
|||
EXTS = ['jpg', 'jpeg', 'png', 'webp']
|
||||
self.folder = opt['path']
|
||||
self.image_size = opt['target_size']
|
||||
self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')]
|
||||
self.paths = [p for ext in EXTS for p in Path(
|
||||
f'{self.folder}').glob(f'**/*.{ext}')]
|
||||
aug_prob = opt['aug_prob']
|
||||
transparent = opt['transparent'] if 'transparent' in opt.keys() else False
|
||||
assert len(self.paths) > 0, f'No images were found in {self.folder} for training'
|
||||
transparent = opt['transparent'] if 'transparent' in opt.keys(
|
||||
) else False
|
||||
assert len(
|
||||
self.paths) > 0, f'No images were found in {self.folder} for training'
|
||||
|
||||
convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent
|
||||
num_channels = 3 if not transparent else 4
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Lambda(convert_image_fn),
|
||||
transforms.Lambda(partial(resize_to_minimum_size, self.image_size)),
|
||||
transforms.Lambda(
|
||||
partial(resize_to_minimum_size, self.image_size)),
|
||||
transforms.Resize(self.image_size),
|
||||
RandomApply(aug_prob, transforms.RandomResizedCrop(self.image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
|
||||
transforms.CenterCrop(self.image_size)),
|
||||
|
|
|
@ -1,9 +1,10 @@
import PIL.Image
import zipfile

import PIL.Image
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torchvision.transforms import Compose, Normalize, Resize, ToTensor


class ZipFileDataset(torch.utils.data.Dataset):
@ -14,9 +15,10 @@ class ZipFileDataset(torch.utils.data.Dataset):
|
|||
self.resolution = opt['resolution']
|
||||
self.paired_mode = opt['paired_mode']
|
||||
self.transforms = Compose([ToTensor(),
|
||||
Resize(self.resolution),
|
||||
Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
||||
])
|
||||
Resize(self.resolution),
|
||||
Normalize((0.485, 0.456, 0.406),
|
||||
(0.229, 0.224, 0.225))
|
||||
])
|
||||
self.zip = None
|
||||
|
||||
def __len__(self):
|
||||
|
@ -49,10 +51,12 @@ class ZipFileDataset(torch.utils.data.Dataset):
|
|||
aname = fname.replace('1.jpg', '0.jpg')
|
||||
out['alt_hq'] = self.load_image(aname)
|
||||
except:
|
||||
print(f"Error loading {fname} from zipfile. Attempting to recover by loading next element.")
|
||||
print(
|
||||
f"Error loading {fname} from zipfile. Attempting to recover by loading next element.")
|
||||
return self[i+1]
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'path': 'E:\\4k6k\\datasets\\images\\youtube-imagenet-paired\\output.zip',
|
||||
|
@ -65,4 +69,3 @@ if __name__ == '__main__':
|
|||
for i, d in enumerate(loader):
|
||||
torchvision.utils.save_image(d['hq'], f'{i}_hq.png')
|
||||
torchvision.utils.save_image(d['alt_hq'], f'{i}_althq.png')
|
||||
|
||||
|
|
0
dlas/data/text/__init__.py
Normal file
@ -1,18 +1,20 @@
|
|||
from torch.utils.data import Dataset
|
||||
import datasets
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class HfDataset(Dataset):
|
||||
"""
|
||||
Simple wrapper for a HuggingFace dataset that can re-map keys if desired.
|
||||
"""
|
||||
|
||||
def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'):
|
||||
self.hfd = []
|
||||
for corpus in corpi:
|
||||
dataset_name, config = corpus
|
||||
if config == '' or config == 'None':
|
||||
config = None
|
||||
self.hfd.append(datasets.load_dataset(dataset_name, config, cache_dir=cache_path)[dataset_spec_key])
|
||||
self.hfd.append(datasets.load_dataset(
|
||||
dataset_name, config, cache_dir=cache_path)[dataset_spec_key])
|
||||
self.key_maps = key_maps
|
||||
|
||||
def __getitem__(self, item):
|
||||
|
@ -32,5 +34,6 @@ class HfDataset(Dataset):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']], dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache')
|
||||
d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']],
|
||||
dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache')
|
||||
print(d[5])
|
||||
|
|
|
@ -1,10 +1,10 @@
from torch.utils.data import Dataset
import torchvision.transforms as T
from torch.utils.data import Dataset
from torchvision import datasets

# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
from data.images.cifar import CIFAR100, CIFAR10
from utils.util import opt_get
from dlas.data.images.cifar import CIFAR10, CIFAR100
from dlas.utils.util import opt_get


class TorchDataset(Dataset):
@ -17,7 +17,8 @@ class TorchDataset(Dataset):
|
|||
"imagenet": datasets.ImageNet,
|
||||
"imagefolder": datasets.ImageFolder
|
||||
}
|
||||
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[
|
||||
0.229, 0.224, 0.225])
|
||||
if opt_get(opt, ['random_crop'], False):
|
||||
transforms = [
|
||||
T.RandomResizedCrop(opt['image_size']),
|
||||
|
@ -34,7 +35,8 @@ class TorchDataset(Dataset):
|
|||
normalize,
|
||||
]
|
||||
transforms = T.Compose(transforms)
|
||||
self.dataset = DATASET_MAP[opt['dataset']](transform=transforms, **opt['kwargs'])
|
||||
self.dataset = DATASET_MAP[opt['dataset']](
|
||||
transform=transforms, **opt['kwargs'])
|
||||
self.len = opt_get(opt, ['fixed_len'], len(self.dataset))
|
||||
self.offset = opt_get(opt, ['offset'], 0)
|
||||
|
||||
|
@ -53,6 +55,7 @@ class TorchDataset(Dataset):
|
|||
def __len__(self):
|
||||
return self.len-self.offset
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = {
|
||||
'flip': True,
|
||||
|
|
|
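Not part of the diff: a hedged sketch of the opt dict the TorchDataset wrapper above consumes, using only the keys visible in these hunks ('dataset', 'kwargs', 'random_crop', 'image_size', 'fixed_len', 'offset'). The concrete values are illustrative.

# Illustrative configuration only; 'imagefolder' is one of the DATASET_MAP keys shown above.
opt = {
    'dataset': 'imagefolder',
    'kwargs': {'root': '/path/to/images'},  # forwarded to torchvision.datasets.ImageFolder
    'random_crop': True,                    # enables T.RandomResizedCrop(opt['image_size'])
    'image_size': 224,
    'fixed_len': 10000,                     # optional cap on the reported length
    'offset': 0,                            # optional number of leading items to skip
}
# With these values, __len__() above reports fixed_len - offset = 10000.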
@@ -1,21 +1,22 @@
import os
import glob
import math
import os
import pickle
import random

import cv2
import numpy
import numpy as np
import glob
import torch
import torchvision
import cv2

####################
# Files & IO
####################

###################### get image path list ######################
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.webp', '.WEBP']
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png',
'.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.webp', '.WEBP']

def torch2cv(tensor):
|
||||
|
@ -30,7 +31,8 @@ def torch2cv(tensor):
|
|||
|
||||
def cv2torch(cv, batchify=True):
|
||||
cv = cv2.cvtColor(cv, cv2.COLOR_BGR2RGB)
|
||||
tens = torch.from_numpy(np.ascontiguousarray(np.transpose(cv, (2, 0, 1)))).float()
|
||||
tens = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(cv, (2, 0, 1)))).float()
|
||||
if batchify:
|
||||
tens = tens.unsqueeze(0)
|
||||
return tens
|
||||
|
@ -65,7 +67,8 @@ def _get_paths_from_images(path, qualifier=is_image_file):
|
|||
|
||||
def _get_paths_from_lmdb(dataroot):
|
||||
"""get image path list from lmdb meta info"""
|
||||
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
|
||||
meta_info = pickle.load(
|
||||
open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
|
||||
paths = meta_info['keys']
|
||||
sizes = meta_info['resolution']
|
||||
if len(sizes) == 1:
|
||||
|
@ -123,7 +126,8 @@ def read_img(env, path, size=None, rgb=False):
|
|||
# Indirect open then process to support unicode files.
|
||||
stream = open(path, "rb")
|
||||
bytes = bytearray(stream.read())
|
||||
img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8),
|
||||
cv2.IMREAD_UNCHANGED)
|
||||
elif env == 'lmdb':
|
||||
img = _read_img_lmdb(env, path, size)
|
||||
elif env == 'buffer':
|
||||
|
@ -161,7 +165,8 @@ def read_img_seq(path):
|
|||
# stack to Torch tensor
|
||||
imgs = np.stack(img_l, axis=0)
|
||||
imgs = imgs[:, :, :, [2, 1, 0]]
|
||||
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
|
||||
imgs = torch.from_numpy(np.ascontiguousarray(
|
||||
np.transpose(imgs, (0, 3, 1, 2)))).float()
|
||||
return imgs
|
||||
|
||||
|
||||
|
@ -478,9 +483,12 @@ def imresize(img, scale, antialiasing=True):
|
|||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[0, i, :] = img_aug[0, idx:idx + kernel_width,
|
||||
:].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[1, i, :] = img_aug[1, idx:idx + kernel_width,
|
||||
:].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[2, i, :] = img_aug[2, idx:idx + kernel_width,
|
||||
:].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
|
@ -501,9 +509,12 @@ def imresize(img, scale, antialiasing=True):
|
|||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
|
||||
out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
|
||||
out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
|
||||
out_2[0, :, i] = out_1_aug[0, :, idx:idx +
|
||||
kernel_width].mv(weights_W[i])
|
||||
out_2[1, :, i] = out_1_aug[1, :, idx:idx +
|
||||
kernel_width].mv(weights_W[i])
|
||||
out_2[2, :, i] = out_1_aug[2, :, idx:idx +
|
||||
kernel_width].mv(weights_W[i])
|
||||
|
||||
return out_2
|
||||
|
||||
|
@ -548,9 +559,12 @@ def imresize_np(img, scale, antialiasing=True):
|
|||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[i, :, 0] = img_aug[idx:idx + kernel_width,
|
||||
:, 0].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[i, :, 1] = img_aug[idx:idx + kernel_width,
|
||||
:, 1].transpose(0, 1).mv(weights_H[i])
|
||||
out_1[i, :, 2] = img_aug[idx:idx + kernel_width,
|
||||
:, 2].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
|
@ -571,9 +585,12 @@ def imresize_np(img, scale, antialiasing=True):
|
|||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
|
||||
out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
|
||||
out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
|
||||
out_2[:, i, 0] = out_1_aug[:, idx:idx +
|
||||
kernel_width, 0].mv(weights_W[i])
|
||||
out_2[:, i, 1] = out_1_aug[:, idx:idx +
|
||||
kernel_width, 1].mv(weights_W[i])
|
||||
out_2[:, i, 2] = out_1_aug[:, idx:idx +
|
||||
kernel_width, 2].mv(weights_W[i])
|
||||
|
||||
return out_2.numpy()
|
||||
|
||||
|
@ -587,7 +604,8 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not
|
|||
print(f"Building cache for contents of {paths}..")
|
||||
output = []
|
||||
for p in paths:
|
||||
output.extend(find_files_of_type('img', p, qualifier=is_audio_file)[0])
|
||||
output.extend(find_files_of_type(
|
||||
'img', p, qualifier=is_audio_file)[0])
|
||||
if exclusion_list is not None and len(exclusion_list) > 0:
|
||||
print(f"Removing exclusion lists..")
|
||||
before = len(output)
|
||||
|
@ -597,6 +615,7 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not
|
|||
print(f"Excluded {before-len(output)} files.")
|
||||
if endswith is not None:
|
||||
before = len(output)
|
||||
|
||||
def filter_fn(p):
|
||||
for e in endswith:
|
||||
if not p.endswith(e):
|
||||
|
@ -606,7 +625,8 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not
|
|||
return False
|
||||
return True
|
||||
output = list(filter(filter_fn, output))
|
||||
print(f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
|
||||
print(
|
||||
f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
|
||||
print("Done.")
|
||||
torch.save(output, cache_path)
|
||||
return output
|
||||
|
@ -617,7 +637,8 @@ if __name__ == '__main__':
|
|||
# read images
|
||||
img = cv2.imread('test.png')
|
||||
img = img * 1.0 / 255
|
||||
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
||||
img = torch.from_numpy(np.transpose(
|
||||
img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
||||
# imresize
|
||||
scale = 1 / 4
|
||||
import time
|
||||
|
|
|
@@ -7,12 +7,14 @@ class ZeroPadDictCollate():
Given a list of dictionary outputs with torch.Tensors from a Dataset, iterates through each one, finds the longest
tensor, and zero pads all the other tensors together.
"""

def collate_tensors(self, batch, key):
result = []
largest_dims = [0 for _ in range(len(batch[0][key].shape))]
for elem in batch:
result.append(elem[key])
largest_dims = [max(current_largest, new_consideration) for current_largest, new_consideration in zip(largest_dims, elem[key].shape)]
largest_dims = [max(current_largest, new_consideration)
for current_largest, new_consideration in zip(largest_dims, elem[key].shape)]
# Now pad each tensor by the largest dimension.
for i in range(len(result)):
padding_tuple = ()

@@ -24,7 +26,6 @@ class ZeroPadDictCollate():

return torch.stack(result, dim=0)


def collate_into_list(self, batch, key):
result = []
for elem in batch:

@@ -42,4 +43,4 @@ class ZeroPadDictCollate():
collated[key] = torch.stack([b[key] for b in batch])
else:
collated[key] = self.collate_into_list(batch, key)
return collated
return collated
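Not part of the commit: a self-contained sketch of the zero-padding behaviour the ZeroPadDictCollate docstring describes (pad every tensor in a batch up to the per-dimension maximum, then stack). It is written independently and only mirrors the documented idea, not the class's actual code.

import torch
import torch.nn.functional as F

def pad_and_stack(tensors):
    # Per-dimension maximum across the batch (all tensors must share a rank).
    largest = [max(sizes) for sizes in zip(*[t.shape for t in tensors])]
    padded = []
    for t in tensors:
        # F.pad takes (before, after) pairs starting from the last dimension.
        pad = []
        for dim_size, target in zip(reversed(t.shape), reversed(largest)):
            pad.extend([0, target - dim_size])
        padded.append(F.pad(t, pad))
    return torch.stack(padded, dim=0)

batch = [torch.ones(2, 5), torch.ones(2, 3)]
out = pad_and_stack(batch)
print(out.shape)      # torch.Size([2, 2, 5])
print(out[1, :, 3:])  # zeros where the shorter tensor was padded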
@@ -3,13 +3,11 @@ from abc import abstractmethod

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
from math import sqrt
import torch.nn.init as init

from utils.util import checkpoint
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.utils.util import checkpoint


def exists(val):

@@ -21,14 +19,14 @@ def default(val, d):


def l2norm(t):
return F.normalize(t, p = 2, dim = -1)
return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories, eps = 1e-5):
def laplace_smoothing(x, n_categories, eps=1e-5):
return (x + eps) / (x.sum() + n_categories * eps)


@@ -36,9 +34,9 @@ def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device

if num_samples >= num:
indices = torch.randperm(num_samples, device = device)[:num]
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device = device)
indices = torch.randint(0, num_samples, (num,), device=device)

return samples[indices]
|
|||
b, c, *_spatial = x.shape
|
||||
x = x.reshape(b, c, -1) # NC(HW)
|
||||
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None,
|
||||
:, :x.shape[-1]].to(x.dtype) # NC(HW+1)
|
||||
x = self.qkv_proj(x)
|
||||
x = self.attention(x)
|
||||
x = self.c_proj(x)
|
||||
|
@ -296,7 +295,8 @@ class Upsample(nn.Module):
|
|||
if dims == 1:
|
||||
ksize = 5
|
||||
pad = 2
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)
|
||||
self.conv = conv_nd(dims, self.channels,
|
||||
self.out_channels, ksize, padding=pad)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
|
@ -346,6 +346,7 @@ class cGLU(nn.Module):
|
|||
"""
|
||||
Gated GELU for channel-first architectures.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in, dim_out=None):
|
||||
super().__init__()
|
||||
dim_out = dim_in if dim_out is None else dim_out
|
||||
|
@ -395,7 +396,8 @@ class ResBlock(nn.Module):
|
|||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
|
||||
conv_nd(dims, channels, self.out_channels,
|
||||
kernel_size, padding=padding),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
@ -414,7 +416,8 @@ class ResBlock(nn.Module):
|
|||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, self.out_channels, self.out_channels,
|
||||
kernel_size, padding=padding)
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -425,7 +428,8 @@ class ResBlock(nn.Module):
|
|||
dims, channels, self.out_channels, kernel_size, padding=padding
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
|
@ -466,10 +470,10 @@ def build_local_attention_mask(n, l, fixed_region=0):
|
|||
A mask that can be applied to AttentionBlock to achieve local attention.
|
||||
"""
|
||||
assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
|
||||
o = torch.arange(0,n)
|
||||
c = o.unsqueeze(-1).repeat(1,n)
|
||||
r = o.unsqueeze(0).repeat(n,1)
|
||||
localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
|
||||
o = torch.arange(0, n)
|
||||
c = o.unsqueeze(-1).repeat(1, n)
|
||||
r = o.unsqueeze(0).repeat(n, 1)
|
||||
localized = ((-(r-c).abs())+l).clamp(0, l-1) / (l-1)
|
||||
localized[:fixed_region] = 1
|
||||
localized[:, :fixed_region] = 1
|
||||
mask = localized > 0
|
||||
|
@ -477,7 +481,7 @@ def build_local_attention_mask(n, l, fixed_region=0):
|
|||
|
||||
|
||||
def test_local_attention_mask():
|
||||
print(build_local_attention_mask(9,4,1))
|
||||
print(build_local_attention_mask(9, 4, 1))
|
||||
|
||||
|
||||
class RelativeQKBias(nn.Module):
|
||||
|
@ -487,17 +491,18 @@ class RelativeQKBias(nn.Module):
|
|||
|
||||
If symmetric=False, a different bias is applied to each side of the input element, otherwise the bias is symmetric.
|
||||
"""
|
||||
|
||||
def __init__(self, l, max_positions=4000, symmetric=True):
|
||||
super().__init__()
|
||||
if symmetric:
|
||||
self.emb = nn.Parameter(torch.randn(l+1) * .01)
|
||||
o = torch.arange(0,max_positions)
|
||||
c = o.unsqueeze(-1).repeat(1,max_positions)
|
||||
r = o.unsqueeze(0).repeat(max_positions,1)
|
||||
M = ((-(r-c).abs())+l).clamp(0,l)
|
||||
o = torch.arange(0, max_positions)
|
||||
c = o.unsqueeze(-1).repeat(1, max_positions)
|
||||
r = o.unsqueeze(0).repeat(max_positions, 1)
|
||||
M = ((-(r-c).abs())+l).clamp(0, l)
|
||||
else:
|
||||
self.emb = nn.Parameter(torch.randn(l*2+2) * .01)
|
||||
a = torch.arange(0,max_positions)
|
||||
a = torch.arange(0, max_positions)
|
||||
c = a.unsqueeze(-1) - a
|
||||
m = (c >= -l).logical_and(c <= l)
|
||||
M = (l+c+1)*m
|
||||
|
@ -508,7 +513,7 @@ class RelativeQKBias(nn.Module):
|
|||
# return self.emb[self.M[:n, :n]].view(1,n,n)
|
||||
# However, indexing operations like this have horrible efficiency on GPUs: https://github.com/pytorch/pytorch/issues/15245
|
||||
# So, enter this horrible, equivalent mess:
|
||||
return torch.gather(self.emb.unsqueeze(-1).repeat(1,n), 0, self.M[:n,:n]).view(1,n,n)
|
||||
return torch.gather(self.emb.unsqueeze(-1).repeat(1, n), 0, self.M[:n, :n]).view(1, n, n)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
@ -550,7 +555,8 @@ class AttentionBlock(nn.Module):
|
|||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
|
||||
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(
|
||||
1, channels, out_channels, 1)
|
||||
self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
|
||||
|
||||
def forward(self, x, mask=None, qk_bias=None):
|
||||
|
@ -572,7 +578,7 @@ class AttentionBlock(nn.Module):
|
|||
b, c, *spatial = x.shape
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
|
||||
mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1)
|
||||
if mask.shape[1] != x.shape[-1]:
|
||||
mask = mask[:, :x.shape[-1], :x.shape[-1]]
|
||||
|
||||
|
@ -606,7 +612,8 @@ class QKVAttentionLegacy(nn.Module):
|
|||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3,
|
||||
length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = torch.einsum(
|
||||
"bct,bcs->bts", q * scale, k * scale
|
||||
|
@ -651,7 +658,8 @@ class QKVAttention(nn.Module):
|
|||
mask = mask.repeat(self.n_heads, 1, 1)
|
||||
weight[mask.logical_not()] = -torch.inf
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
a = torch.einsum("bts,bcs->bct", weight,
|
||||
v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
|
||||
|
@ -678,7 +686,8 @@ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
|
|||
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
|
||||
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
|
||||
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
|
||||
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
|
||||
output = F.grid_sample(
|
||||
x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -690,7 +699,8 @@ class PixelUnshuffle(nn.Module):
|
|||
def forward(self, x):
|
||||
(b, f, w, h) = x.shape
|
||||
x = x.contiguous().view(b, f, w // self.r, self.r, h // self.r, self.r)
|
||||
x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, f * (self.r ** 2), w // self.r, h // self.r)
|
||||
x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(
|
||||
b, f * (self.r ** 2), w // self.r, h // self.r)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -704,6 +714,8 @@ def silu(input):
|
|||
|
||||
# create a class wrapper from PyTorch nn.Module, so
|
||||
# the function now can be easily used in models
|
||||
|
||||
|
||||
class SiLU(nn.Module):
|
||||
'''
|
||||
Applies the Sigmoid Linear Unit (SiLU) function element-wise:
|
||||
|
@ -720,11 +732,12 @@ class SiLU(nn.Module):
|
|||
>>> input = torch.randn(2)
|
||||
>>> output = m(input)
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
'''
|
||||
Init method.
|
||||
'''
|
||||
super().__init__() # init the base class
|
||||
super().__init__() # init the base class
|
||||
|
||||
def forward(self, input):
|
||||
'''
|
||||
|
@ -735,12 +748,15 @@ class SiLU(nn.Module):
|
|||
|
||||
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
|
||||
|
||||
class ConvBnRelu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True):
|
||||
super(ConvBnRelu, self).__init__()
|
||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||
assert kernel_size in padding_map.keys()
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
|
||||
stride, padding_map[kernel_size], bias=bias)
|
||||
if norm:
|
||||
self.bn = nn.BatchNorm2d(filters_out)
|
||||
else:
|
||||
|
@ -753,7 +769,8 @@ class ConvBnRelu(nn.Module):
|
|||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
@ -770,12 +787,15 @@ class ConvBnRelu(nn.Module):
|
|||
|
||||
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
|
||||
|
||||
class ConvBnSilu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
|
||||
super(ConvBnSilu, self).__init__()
|
||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||
assert kernel_size in padding_map.keys()
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
|
||||
stride, padding_map[kernel_size], bias=bias)
|
||||
if norm:
|
||||
self.bn = nn.BatchNorm2d(filters_out)
|
||||
else:
|
||||
|
@ -788,7 +808,8 @@ class ConvBnSilu(nn.Module):
|
|||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
|
||||
m.weight.data *= weight_init_factor
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
@ -808,12 +829,15 @@ class ConvBnSilu(nn.Module):
|
|||
|
||||
''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
|
||||
|
||||
class ConvBnLelu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
|
||||
super(ConvBnLelu, self).__init__()
|
||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||
assert kernel_size in padding_map.keys()
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
|
||||
stride, padding_map[kernel_size], bias=bias)
|
||||
if norm:
|
||||
self.bn = nn.BatchNorm2d(filters_out)
|
||||
else:
|
||||
|
@ -847,12 +871,15 @@ class ConvBnLelu(nn.Module):
|
|||
|
||||
''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
|
||||
|
||||
class ConvGnLelu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
|
||||
super(ConvGnLelu, self).__init__()
|
||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||
assert kernel_size in padding_map.keys()
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
|
||||
stride, padding_map[kernel_size], bias=bias)
|
||||
if norm:
|
||||
self.gn = nn.GroupNorm(num_groups, filters_out)
|
||||
else:
|
||||
|
@ -886,12 +913,15 @@ class ConvGnLelu(nn.Module):
|
|||
|
||||
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
|
||||
|
||||
class ConvGnSilu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
|
||||
super(ConvGnSilu, self).__init__()
|
||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||
assert kernel_size in padding_map.keys()
|
||||
self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||
self.conv = convnd(filters_in, filters_out, kernel_size,
|
||||
stride, padding_map[kernel_size], bias=bias)
|
||||
if norm:
|
||||
self.gn = nn.GroupNorm(num_groups, filters_out)
|
||||
else:
|
||||
|
@ -904,7 +934,8 @@ class ConvGnSilu(nn.Module):
|
|||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, convnd):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
|
||||
m.weight.data *= weight_init_factor
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
@ -924,12 +955,15 @@ class ConvGnSilu(nn.Module):
|
|||
|
||||
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
|
||||
kernel sizes. '''
|
||||
|
||||
|
||||
class ConvBnRelu(nn.Module):
|
||||
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
|
||||
super(ConvBnRelu, self).__init__()
|
||||
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
|
||||
assert kernel_size in padding_map.keys()
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
|
||||
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
|
||||
stride, padding_map[kernel_size], bias=bias)
|
||||
if norm:
|
||||
self.bn = nn.BatchNorm2d(filters_out)
|
||||
else:
|
||||
|
@ -942,7 +976,8 @@ class ConvBnRelu(nn.Module):
|
|||
# Init params.
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
|
||||
m.weight.data *= weight_init_factor
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
@ -969,7 +1004,8 @@ class MultiConvBlock(nn.Module):
|
|||
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
|
||||
[ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
|
||||
[ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
|
||||
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float))
|
||||
self.scale = nn.Parameter(torch.full(
|
||||
(1,), fill_value=scale_init, dtype=torch.float))
|
||||
self.bias = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x, noise=None):
|
||||
|
@ -988,10 +1024,14 @@ class ExpansionBlock(nn.Module):
|
|||
super(ExpansionBlock, self).__init__()
|
||||
if filters_out is None:
|
||||
filters_out = filters_in // 2
|
||||
self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
|
||||
self.conjoin = block(filters_out*2, filters_out, kernel_size=3, bias=False, activation=True, norm=False)
|
||||
self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.decimate = block(
|
||||
filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.process_passthrough = block(
|
||||
filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
|
||||
self.conjoin = block(filters_out*2, filters_out,
|
||||
kernel_size=3, bias=False, activation=True, norm=False)
|
||||
self.process = block(filters_out, filters_out,
|
||||
kernel_size=3, bias=False, activation=True, norm=True)
|
||||
|
||||
# input is the feature signal with shape (b, f, w, h)
|
||||
# passthrough is the structure signal with shape (b, f/2, w*2, h*2)
|
||||
|
@ -1012,10 +1052,14 @@ class ExpansionBlock2(nn.Module):
|
|||
super(ExpansionBlock2, self).__init__()
|
||||
if filters_out is None:
|
||||
filters_out = filters_in // 2
|
||||
self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
|
||||
self.conjoin = block(filters_out*2, filters_out*2, kernel_size=3, bias=False, activation=True, norm=False)
|
||||
self.reduce = block(filters_out*2, filters_out, kernel_size=3, bias=False, activation=True, norm=True)
|
||||
self.decimate = block(
|
||||
filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
|
||||
self.process_passthrough = block(
|
||||
filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
|
||||
self.conjoin = block(filters_out*2, filters_out*2,
|
||||
kernel_size=3, bias=False, activation=True, norm=False)
|
||||
self.reduce = block(filters_out*2, filters_out,
|
||||
kernel_size=3, bias=False, activation=True, norm=True)
|
||||
|
||||
# input is the feature signal with shape (b, f, w, h)
|
||||
# passthrough is the structure signal with shape (b, f/2, w*2, h*2)
|
||||
|
@ -1036,8 +1080,10 @@ class ConjoinBlock(nn.Module):
|
|||
filters_out = filters_in
|
||||
if filters_pt is None:
|
||||
filters_pt = filters_in
|
||||
self.process = block(filters_in + filters_pt, filters_in + filters_pt, kernel_size=3, bias=False, activation=True, norm=norm)
|
||||
self.decimate = block(filters_in + filters_pt, filters_out, kernel_size=1, bias=False, activation=False, norm=norm)
|
||||
self.process = block(filters_in + filters_pt, filters_in + filters_pt,
|
||||
kernel_size=3, bias=False, activation=True, norm=norm)
|
||||
self.decimate = block(filters_in + filters_pt, filters_out,
|
||||
kernel_size=1, bias=False, activation=False, norm=norm)
|
||||
|
||||
def forward(self, input, passthrough):
|
||||
x = torch.cat([input, passthrough], dim=1)
|
||||
|
@ -1053,7 +1099,8 @@ class ReferenceJoinBlock(nn.Module):
|
|||
scale_init=residual_weight_init_factor, norm=False,
|
||||
weight_init_factor=residual_weight_init_factor)
|
||||
if join:
|
||||
self.join_conv = block(nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
|
||||
self.join_conv = block(
|
||||
nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
|
||||
else:
|
||||
self.join_conv = None
|
||||
|
||||
|
@ -1070,7 +1117,8 @@ class ReferenceJoinBlock(nn.Module):
|
|||
class UpconvBlock(nn.Module):
|
||||
def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
|
||||
super(UpconvBlock, self).__init__()
|
||||
self.process = block(filters_in, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)
|
||||
self.process = block(filters_in, filters_out, kernel_size=3,
|
||||
bias=bias, activation=activation, norm=norm)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
|
@ -1083,21 +1131,29 @@ class FinalUpsampleBlock2x(nn.Module):
|
|||
super(FinalUpsampleBlock2x, self).__init__()
|
||||
if scale == 2:
|
||||
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
|
||||
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
|
||||
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
|
||||
UpconvBlock(
|
||||
nf, nf // 2, block=block, norm=False, activation=True, bias=True),
|
||||
block(nf // 2, nf // 2, kernel_size=3,
|
||||
norm=False, activation=False, bias=True),
|
||||
block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
|
||||
else:
|
||||
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
|
||||
UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True),
|
||||
block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True),
|
||||
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
|
||||
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
|
||||
UpconvBlock(
|
||||
nf, nf, block=block, norm=False, activation=True, bias=True),
|
||||
block(nf, nf, kernel_size=3, norm=False,
|
||||
activation=False, bias=True),
|
||||
UpconvBlock(
|
||||
nf, nf // 2, block=block, norm=False, activation=True, bias=True),
|
||||
block(nf // 2, nf // 2, kernel_size=3,
|
||||
norm=False, activation=False, bias=True),
|
||||
block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
|
||||
|
||||
def forward(self, x):
|
||||
return self.chain(x)
|
||||
|
||||
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
|
||||
|
||||
|
||||
def gather_2d(input, index):
|
||||
b, c, h, w = input.shape
|
||||
nodim = input.view(b, c, h * w)
|
||||
|
|
|
@@ -3,13 +3,14 @@ from itertools import groupby
import torch
import torch.nn as nn
from transformers import Wav2Vec2ForCTC
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention, Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Attention,
Wav2Vec2Model)

from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.tts.tacotron2.text import sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.data.audio.unsupervised_audio_dataset import load_audio
from dlas.models.audio.tts.tacotron2.text import sequence_to_text
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get


def only_letters(string):

@@ -22,6 +23,7 @@ class Wav2VecFeatureExtractor(nn.Module):
Basic wrapper that only does feature extraction. Useful to build out this portion of the model so it can be
operated through DDP.
"""

def __init__(self, basis_model='facebook/wav2vec2-large'):
super().__init__()
w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)

@@ -34,7 +36,8 @@ class Wav2VecFeatureExtractor(nn.Module):
def forward(self, audio, wav_lengths):
with torch.no_grad():
audio = audio[:, :, :wav_lengths.max()]
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
audio_norm = (audio - audio.mean()) / \
torch.sqrt(audio.var() + 1e-7)
return self.extractor(audio_norm.squeeze(1))
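Not part of the diff: a short sketch of the waveform normalization in the Wav2VecFeatureExtractor.forward hunk above, zero-mean unit-variance scaling computed over the whole batch before the clip is handed to wav2vec2. Shapes are illustrative.

import torch

audio = torch.randn(2, 1, 16000) * 0.05 + 0.01    # fake (batch, 1, samples) waveforms
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
print(audio_norm.mean(), audio_norm.var())        # approximately 0 and 1
# The extractor is then applied to audio_norm.squeeze(1), i.e. shape (batch, samples).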
@ -42,13 +45,14 @@ class Wav2VecWrapper(nn.Module):
|
|||
"""
|
||||
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
|
||||
checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
|
||||
remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1,
|
||||
ramp_dropout_max=.5, layer_drop_pct=.1):
|
||||
super().__init__()
|
||||
self.provide_attention_mask = provide_attention_mask
|
||||
|
||||
|
||||
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
|
||||
# Perform some surgery to get the model we actually want.
|
||||
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
|
||||
|
@ -99,7 +103,8 @@ class Wav2VecWrapper(nn.Module):
|
|||
unaligned_tokens[b, text_lengths[b]:] = -100
|
||||
|
||||
model_inp = fea_extractor if self.remove_feature_extractor else audio
|
||||
outputs = self.w2v(input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)
|
||||
outputs = self.w2v(
|
||||
input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)
|
||||
|
||||
if self.output_wer:
|
||||
self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
|
||||
|
@ -126,11 +131,15 @@ class Wav2VecWrapper(nn.Module):
|
|||
pred_strings = []
|
||||
for last_labels, last_pred in zip(self.last_labels, self.last_pred):
|
||||
last_labels[last_labels == -100] = 0
|
||||
label_strings.extend([only_letters(sequence_to_text(lbl)) for lbl in last_labels])
|
||||
pred_strings.extend([only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in last_pred])
|
||||
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
|
||||
label_strings.extend(
|
||||
[only_letters(sequence_to_text(lbl)) for lbl in last_labels])
|
||||
pred_strings.extend([only_letters(sequence_to_text(
|
||||
self.decode_ctc(pred))) for pred in last_pred])
|
||||
wer = wer_metric.compute(
|
||||
predictions=pred_strings, references=label_strings)
|
||||
res['wer'] = wer
|
||||
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
|
||||
print(
|
||||
f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
|
||||
if self.ramp_dropout_mode:
|
||||
res['dropout_rate'] = self.current_dropout_rate
|
||||
return res
|
||||
|
@ -149,7 +158,8 @@ class Wav2VecWrapper(nn.Module):
|
|||
def update_for_step(self, step, *args):
|
||||
if self.ramp_dropout_mode and step % 10 == 0:
|
||||
dropout_gap = self.ramp_dropout_max - self.ramp_dropout_min
|
||||
new_dropout_rate = self.ramp_dropout_min + dropout_gap * min(step / self.ramp_dropout_end, 1)
|
||||
new_dropout_rate = self.ramp_dropout_min + \
|
||||
dropout_gap * min(step / self.ramp_dropout_end, 1)
|
||||
self.current_dropout_rate = new_dropout_rate
|
||||
for name, module in self.w2v.named_modules():
|
||||
if isinstance(module, nn.Dropout):
|
||||
|
@ -187,14 +197,18 @@ def register_wav2vec2(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
|
||||
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
|
||||
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h',
|
||||
freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
|
||||
w2v.update_for_step(8000)
|
||||
fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
|
||||
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
|
||||
w2v.get_debug_values(0,"")
|
||||
fea = fe(torch.randn(2, 1, 50000), torch.tensor([20000, 30000]))
|
||||
loss = w2v(torch.randn(2, 1, 50000), torch.randint(0, 40, (2, 70)),
|
||||
torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
|
||||
w2v.get_debug_values(0, "")
|
||||
|
||||
sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
|
||||
sd = torch.load(
|
||||
'../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
|
||||
w2v.load_state_dict(sd)
|
||||
pred = w2v.inference(load_audio('Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000).unsqueeze(0))
|
||||
pred = w2v.inference(load_audio(
|
||||
'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000).unsqueeze(0))
|
||||
res = sequence_to_text(pred[0])
|
||||
print(res)
|
||||
|
|
|
@@ -1,12 +1,12 @@
from typing import Any, Callable, List, Optional, Type, Union

import torch
from torch import Tensor
import torch.nn as nn
from torch import Tensor

from trainer.networks import register_model
from utils.util import opt_get
from typing import Type, Any, Callable, Union, List, Optional
import torch_intermediary as ml

import dlas.torch_intermediary as ml
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
@ -42,9 +42,11 @@ class BasicBlock(nn.Module):
|
|||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm1d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
raise NotImplementedError(
|
||||
"Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
|
@ -177,7 +179,8 @@ class ResNet(nn.Module):
|
|||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
@ -188,9 +191,11 @@ class ResNet(nn.Module):
|
|||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
|
||||
# type: ignore[arg-type]
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
|
||||
# type: ignore[arg-type]
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
|
||||
stride: int = 1, dilate: bool = False) -> nn.Sequential:
|
||||
|
@ -386,5 +391,5 @@ def register_audio_resnet(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
m = resnet34()
|
||||
o = m(torch.randn((1,1,48000)))
|
||||
print(o.shape)
|
||||
o = m(torch.randn((1, 1, 48000)))
|
||||
print(o.shape)
|
||||
|
|
|
@ -9,13 +9,14 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import distributed
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
||||
_compute_mask_indices, _sample_negative_indices)
|
||||
|
||||
from models.arch_util import ResBlock
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import torch_intermediary as ml
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import ResBlock
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
class Mel2Vec2FeatureProjection(nn.Module):
|
||||
|
@ -243,6 +244,8 @@ class Wav2Vec2SamePadLayer(nn.Module):
|
|||
|
||||
|
||||
from torch.nn.utils.weight_norm import WeightNorm
|
||||
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
# save and delete all weightnorm weights on self
|
||||
weights = {}
|
||||
|
|
|
@ -2,13 +2,13 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import AttentionBlock, ResBlock
|
||||
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
from models.lucidrains.x_transformers import Encoder
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get, ceil_multiple, print_network
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import AttentionBlock, ResBlock
|
||||
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
from dlas.models.lucidrains.x_transformers import Encoder
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import ceil_multiple, opt_get, print_network
|
||||
|
||||
|
||||
class ConditioningEncoder(nn.Module):
|
||||
|
@ -22,23 +22,23 @@ class ConditioningEncoder(nn.Module):
|
|||
super().__init__()
|
||||
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
|
||||
self.attn = Encoder(
|
||||
dim=embedding_dim,
|
||||
depth=attn_blocks,
|
||||
heads=num_attn_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=2,
|
||||
do_checkpointing=do_checkpointing
|
||||
)
|
||||
dim=embedding_dim,
|
||||
depth=attn_blocks,
|
||||
heads=num_attn_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=2,
|
||||
do_checkpointing=do_checkpointing
|
||||
)
|
||||
self.dim = embedding_dim
|
||||
|
||||
def forward(self, x):
|
||||
h = self.init(x).permute(0,2,1)
|
||||
h = self.attn(h).permute(0,2,1)
|
||||
h = self.init(x).permute(0, 2, 1)
|
||||
h = self.attn(h).permute(0, 2, 1)
|
||||
return h.mean(-1)
|
||||
|
||||
|
||||
|
@ -46,7 +46,7 @@ class ConditioningAR(nn.Module):
|
|||
def __init__(self, dim, layers, dropout=0, num_vectors=8192, cond_free_percent=.15, fp16=False):
|
||||
super().__init__()
|
||||
self.cond_encoder = ConditioningEncoder(256, dim)
|
||||
self.cond_free_emb = nn.Parameter(torch.randn(1,dim))
|
||||
self.cond_free_emb = nn.Parameter(torch.randn(1, dim))
|
||||
self.unconditioned_percentage = cond_free_percent
|
||||
self.fp16 = fp16
|
||||
|
||||
|
@ -65,13 +65,16 @@ class ConditioningAR(nn.Module):
|
|||
|
||||
cond = self.cond_encoder(conditioning)
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((cond.shape[0],1), device=cond.device) < self.unconditioned_percentage
|
||||
cond = torch.where(unconditioned_batches, self.cond_free_emb.repeat(cond.shape[0],1), cond)
|
||||
unconditioned_batches = torch.rand(
|
||||
(cond.shape[0], 1), device=cond.device) < self.unconditioned_percentage
|
||||
cond = torch.where(unconditioned_batches,
|
||||
self.cond_free_emb.repeat(cond.shape[0], 1), cond)
|
||||
unused_params.append(self.cond_free_emb)
|
||||
|
||||
h = self.embeddings(cheater_codes)
|
||||
h = torch.cat([cond.unsqueeze(1), h], dim=1)
|
||||
targets = cheater_codes # Since we padded above by 1, the input alignment works.
|
||||
# Since we padded above by 1, the input alignment works.
|
||||
targets = cheater_codes
|
||||
|
||||
with torch.autocast(cheater_codes.device.type, enabled=self.fp16):
|
||||
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
|
||||
|
@ -79,12 +82,13 @@ class ConditioningAR(nn.Module):
|
|||
if return_latent:
|
||||
return h.float()
|
||||
|
||||
logits = self.head(h[:,:-1]).permute(0,2,1)
|
||||
logits = self.head(h[:, :-1]).permute(0, 2, 1)
|
||||
loss = F.cross_entropy(logits, targets, reduction="none")
|
||||
|
||||
# Perform masking
|
||||
if code_lengths is not None:
|
||||
mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
|
||||
mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(
|
||||
0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
|
||||
loss = loss * mask
|
||||
loss = loss.mean()
|
||||
|
||||
|
@ -114,14 +118,13 @@ def test_ar():
|
|||
model = ConditioningAR(512, 8, cond_free_percent=.5)
|
||||
print_network(model)
|
||||
|
||||
codes = torch.randint(0,8192, (2,400))
|
||||
cond = torch.randn(2,256,400)
|
||||
cl = torch.tensor([200,10])
|
||||
codes[1,10:] = 2
|
||||
codes = torch.randint(0, 8192, (2, 400))
|
||||
cond = torch.randn(2, 256, 400)
|
||||
cl = torch.tensor([200, 10])
|
||||
codes[1, 10:] = 2
|
||||
model(codes, cond, cl)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ar()
|
||||
|
|
|
@ -13,17 +13,16 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from math import sqrt
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_intermediary as ml
|
||||
|
||||
from math import sqrt
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from trainer.networks import register_model
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.trainer.networks import register_model
|
||||
|
||||
Linear = ml.Linear
|
||||
ConvTranspose2d = nn.ConvTranspose2d
|
||||
|
@ -43,7 +42,8 @@ def silu(x):
|
|||
class DiffusionEmbedding(nn.Module):
|
||||
def __init__(self, max_steps):
|
||||
super().__init__()
|
||||
self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False)
|
||||
self.register_buffer('embedding', self._build_embedding(
|
||||
max_steps), persistent=False)
|
||||
self.projection1 = Linear(128, 512)
|
||||
self.projection2 = Linear(512, 512)
|
||||
|
||||
|
@ -76,8 +76,10 @@ class DiffusionEmbedding(nn.Module):
|
|||
class SpectrogramUpsampler(nn.Module):
|
||||
def __init__(self, n_mels):
|
||||
super().__init__()
|
||||
self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
|
||||
self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
|
||||
self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[
|
||||
1, 16], padding=[1, 8])
|
||||
self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[
|
||||
1, 16], padding=[1, 8])
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.unsqueeze(x, 1)
|
||||
|
@ -98,27 +100,32 @@ class ResidualBlock(nn.Module):
|
|||
:param uncond: disable spectrogram conditional
|
||||
'''
|
||||
super().__init__()
|
||||
self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
|
||||
self.dilated_conv = Conv1d(
|
||||
residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
|
||||
self.diffusion_projection = Linear(512, residual_channels)
|
||||
if not uncond: # conditional model
|
||||
self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
|
||||
self.conditioner_projection = Conv1d(
|
||||
n_mels, 2 * residual_channels, 1)
|
||||
else: # unconditional model
|
||||
self.conditioner_projection = None
|
||||
|
||||
self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
|
||||
self.output_projection = Conv1d(
|
||||
residual_channels, 2 * residual_channels, 1)
|
||||
|
||||
def forward(self, x, diffusion_step, conditioner=None):
|
||||
assert (conditioner is None and self.conditioner_projection is None) or \
|
||||
(conditioner is not None and self.conditioner_projection is not None)
|
||||
|
||||
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
|
||||
diffusion_step = self.diffusion_projection(
|
||||
diffusion_step).unsqueeze(-1)
|
||||
y = x + diffusion_step
|
||||
if self.conditioner_projection is None: # using a unconditional model
|
||||
y = self.dilated_conv(y)
|
||||
else:
|
||||
y = self.dilated_conv(y)
|
||||
conditioner = self.conditioner_projection(conditioner)
|
||||
conditioner = F.interpolate(conditioner, size=y.shape[-1], mode='nearest')
|
||||
conditioner = F.interpolate(
|
||||
conditioner, size=y.shape[-1], mode='nearest')
|
||||
y = y + conditioner
|
||||
|
||||
gate, filter = torch.chunk(y, 2, dim=1)
|
||||
|
@ -141,7 +148,8 @@ class DiffWave(nn.Module):
|
|||
self.spectrogram_upsampler = SpectrogramUpsampler(n_mels)
|
||||
|
||||
self.residual_layers = nn.ModuleList([
|
||||
ResidualBlock(n_mels, residual_channels, 2 ** (i % dilation_cycle_length), uncond=unconditional)
|
||||
ResidualBlock(n_mels, residual_channels, 2 ** (i %
|
||||
dilation_cycle_length), uncond=unconditional)
|
||||
for i in range(residual_layers)
|
||||
])
|
||||
self.skip_projection = Conv1d(residual_channels, residual_channels, 1)
|
||||
|
@ -177,4 +185,5 @@ def register_diffwave(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
model = DiffWave()
|
||||
model(torch.randn(2,1,65536), torch.tensor([500,3999]), torch.randn(2,128,256))
|
||||
model(torch.randn(2, 1, 65536), torch.tensor(
|
||||
[500, 3999]), torch.randn(2, 128, 256))
|
||||
|
|
|
@ -3,10 +3,10 @@ import torch.nn.functional as F
|
|||
from torch import nn
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
|
||||
from models.arch_util import AttentionBlock, ResBlock
|
||||
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get, ceil_multiple, print_network
|
||||
from dlas.models.arch_util import AttentionBlock, ResBlock
|
||||
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import ceil_multiple, opt_get, print_network
|
||||
|
||||
|
||||
class ResEncoder16x(nn.Module):
|
||||
|
@ -18,20 +18,29 @@ class ResEncoder16x(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
attn = []
|
||||
|
||||
def edim(m):
|
||||
dd = min(spec_dim + m * 128, hidden_dim)
|
||||
return ceil_multiple(dd, 8)
|
||||
self.downsampler = nn.Sequential(
|
||||
ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(4), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1,
|
||||
down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1,
|
||||
down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(3), out_channels=edim(3), use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1,
|
||||
down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(4), out_channels=edim(4), use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
|
||||
self.encoder = nn.Sequential(
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
nn.GroupNorm(8, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Conv1d(hidden_dim, embedding_dim, 1),
|
||||
|
|
|
@ -4,18 +4,22 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import ResBlock
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import ResBlock
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (AttentionBlock,
|
||||
TimestepBlock,
|
||||
TimestepEmbedSequential)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
return t.dtype == torch.float
|
||||
|
||||
|
||||
def is_sequence(t):
|
||||
return t.dtype == torch.long
|
||||
|
||||
|
@ -24,7 +28,8 @@ class MultiGroupEmbedding(nn.Module):
|
|||
def __init__(self, tokens, groups, dim):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
self.m = nn.ModuleList(
|
||||
[ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
|
||||
def forward(self, x):
|
||||
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||
|
@ -56,7 +61,8 @@ class TimestepResBlock(TimestepBlock):
|
|||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
|
||||
conv_nd(dims, channels, self.out_channels,
|
||||
eff_kernel, padding=eff_padding),
|
||||
)
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
|
@ -71,14 +77,16 @@ class TimestepResBlock(TimestepBlock):
|
|||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, self.out_channels, self.out_channels,
|
||||
kernel_size, padding=padding)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
|
@ -111,8 +119,10 @@ class TimestepResBlock(TimestepBlock):
|
|||
class DiffusionLayer(TimestepBlock):
|
||||
def __init__(self, model_channels, dropout, num_heads):
|
||||
super().__init__()
|
||||
self.resblk = TimestepResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
|
||||
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
||||
self.resblk = TimestepResBlock(
|
||||
model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
|
||||
self.attn = AttentionBlock(
|
||||
model_channels, num_heads, relative_pos_embeddings=True)
|
||||
|
||||
def forward(self, x, time_emb):
|
||||
y = self.resblk(x, time_emb)
|
||||
|
@ -134,7 +144,8 @@ class FlatDiffusion(nn.Module):
|
|||
num_heads=8,
|
||||
# Parameters for regularization.
|
||||
layer_drop=.1,
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
train_mel_head=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -163,37 +174,52 @@ class FlatDiffusion(nn.Module):
|
|||
# nn.Embedding
|
||||
self.embeddings = ml.Embedding(token_count, model_channels)
|
||||
else:
|
||||
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
|
||||
self.embeddings = MultiGroupEmbedding(
|
||||
token_count, in_groups, model_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
)
|
||||
self.code_converter = nn.Sequential(
|
||||
ResBlock(dims=1, channels=model_channels, dropout=dropout),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
ResBlock(dims=1, channels=model_channels, dropout=dropout),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
ResBlock(dims=1, channels=model_channels, dropout=dropout),
|
||||
)
|
||||
self.code_norm = normalization(model_channels)
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
||||
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
|
||||
nn.Conv1d(
|
||||
model_channels, model_channels*2, 3, padding=1, stride=2),
|
||||
AttentionBlock(
|
||||
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(
|
||||
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(
|
||||
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(
|
||||
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, model_channels, 1))
|
||||
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
||||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
)
|
||||
self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
|
||||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
self.integrating_conv = nn.Conv1d(
|
||||
model_channels*2, model_channels, kernel_size=1)
|
||||
self.mel_head = nn.Conv1d(
|
||||
model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
|
||||
[TimestepResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
|
||||
|
@ -201,7 +227,8 @@ class FlatDiffusion(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(conv_nd(1, model_channels,
|
||||
out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
if train_mel_head:
|
||||
|
@ -231,7 +258,8 @@ class FlatDiffusion(nn.Module):
|
|||
conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.contextual_embedder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
cond_emb = conds.mean(dim=-1)
|
||||
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
||||
|
@ -239,18 +267,21 @@ class FlatDiffusion(nn.Module):
|
|||
code_emb = self.latent_conditioner(aligned_conditioning)
|
||||
else:
|
||||
code_emb = self.embeddings(aligned_conditioning)
|
||||
code_emb = code_emb.permute(0,2,1)
|
||||
code_emb = code_emb.permute(0, 2, 1)
|
||||
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
unconditioned_batches = torch.zeros(
|
||||
(code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
||||
code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
||||
expanded_code_emb = F.interpolate(
|
||||
code_emb, size=expected_seq_len, mode='nearest')
|
||||
expanded_code_emb = self.code_converter(expanded_code_emb)
|
||||
expanded_code_emb = self.code_norm(expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
expanded_code_emb = self.code_norm(
|
||||
expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
|
||||
if not return_code_pred:
|
||||
return expanded_code_emb
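# Illustrative sketch (not part of this diff): the random masking above swaps whole
# batch elements onto `unconditioned_embedding` during training, which is what makes
# classifier-free guidance possible at sampling time. The function below only shows
# the standard guidance combination; `eps_cond`, `eps_uncond` and `guidance_scale`
# are assumed names, not identifiers from this repository.
def classifier_free_guidance(eps_cond, eps_uncond, guidance_scale=2.0):
    # guidance_scale == 1.0 recovers the purely conditional prediction;
    # larger values push the sample toward the conditioning signal.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)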
|
||||
|
@ -260,7 +291,6 @@ class FlatDiffusion(nn.Module):
|
|||
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
||||
return expanded_code_emb, mel_pred
|
||||
|
||||
|
||||
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
@ -273,27 +303,36 @@ class FlatDiffusion(nn.Module):
|
|||
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert precomputed_aligned_embeddings is not None or (codes is not None and conditioning_input is not None)
|
||||
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
||||
assert precomputed_aligned_embeddings is not None or (
|
||||
codes is not None and conditioning_input is not None)
|
||||
# These two are mutually exclusive.
|
||||
assert not (
|
||||
return_code_pred and precomputed_aligned_embeddings is not None)
|
||||
|
||||
unused_params = []
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
code_emb = self.unconditioned_embedding.repeat(
|
||||
x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(
|
||||
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
else:
|
||||
if precomputed_aligned_embeddings is not None:
|
||||
code_emb = precomputed_aligned_embeddings
|
||||
else:
|
||||
code_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], True)
|
||||
code_emb, mel_pred = self.timestep_independent(
|
||||
codes, conditioning_input, x.shape[-1], True)
|
||||
if is_latent(codes):
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(
|
||||
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
else:
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
unused_params.extend(
|
||||
list(self.latent_conditioner.parameters()))
|
||||
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
time_emb = self.time_embed(
|
||||
timestep_embedding(timesteps, self.model_channels))
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
x = self.inp_block(x)
|
||||
x = torch.cat([x, code_emb], dim=1)
|
||||
|
@ -325,10 +364,12 @@ class FlatDiffusion(nn.Module):
|
|||
conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.contextual_embedder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
return conds.mean(dim=-1)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_flat_diffusion(opt_net, opt):
|
||||
return FlatDiffusion(**opt_net['kwargs'])
|
||||
|
@@ -336,13 +377,13 @@ def register_flat_diffusion(opt_net, opt):

if __name__ == '__main__':
clip = torch.randn(2, 256, 400)
aligned_latent = torch.randn(2,388,512)
aligned_sequence = torch.randint(0,8,(2,100,8))
aligned_latent = torch.randn(2, 388, 512)
aligned_sequence = torch.randint(0, 8, (2, 100, 8))
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True)
model = FlatDiffusion(
512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True)
# Test with latent aligned conditioning
#o = model(clip, ts, aligned_latent, cond)
# o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence, cond, return_code_pred=True)

@@ -1,17 +1,17 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import torch_intermediary as ml

from models.arch_util import AttentionBlock, ResBlock
from models.audio.music.music_quantizer import MusicQuantizer
from models.audio.music.music_quantizer2 import MusicQuantizer2
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.lucidrains.x_transformers import Encoder
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import opt_get, checkpoint, ceil_multiple, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import AttentionBlock, ResBlock
from dlas.models.audio.music.music_quantizer import MusicQuantizer
from dlas.models.audio.music.music_quantizer2 import MusicQuantizer2
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
from dlas.models.lucidrains.x_transformers import Encoder
from dlas.models.vqvae.vqvae import Quantize
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, checkpoint, opt_get, print_network


class ConditioningEncoder(nn.Module):
@ -22,9 +22,11 @@ class ConditioningEncoder(nn.Module):
|
|||
num_attn_heads=4):
|
||||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=3, stride=2, padding=1)
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim,
|
||||
kernel_size=3, stride=2, padding=1)
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_activation=True))
|
||||
attn.append(AttentionBlock(embedding_dim,
|
||||
num_attn_heads, do_activation=True))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
@ -43,14 +45,20 @@ class UpperConditioningEncoder(nn.Module):
|
|||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Sequential(nn.Conv1d(spec_dim, min(spec_dim+128, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
nn.Conv1d(min(spec_dim+128, embedding_dim), min(spec_dim+256, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
nn.Conv1d(min(spec_dim+256, embedding_dim), min(spec_dim+384, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
nn.Conv1d(min(spec_dim+384, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
ResBlock(min(spec_dim+512, embedding_dim), dims=1),
|
||||
nn.Conv1d(min(spec_dim+512, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
nn.Conv1d(min(spec_dim+128, embedding_dim), min(
|
||||
spec_dim+256, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
nn.Conv1d(min(spec_dim+256, embedding_dim), min(
|
||||
spec_dim+384, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
nn.Conv1d(min(spec_dim+384, embedding_dim), min(
|
||||
spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
ResBlock(
|
||||
min(spec_dim+512, embedding_dim), dims=1),
|
||||
nn.Conv1d(min(spec_dim+512, embedding_dim), min(
|
||||
spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
|
||||
ResBlock(min(spec_dim+512, embedding_dim), dims=1))
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_activation=True))
|
||||
attn.append(AttentionBlock(embedding_dim,
|
||||
num_attn_heads, do_activation=True))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
@ -67,21 +75,31 @@ class UpperQuantizer(nn.Module):
|
|||
num_tokens):
|
||||
super().__init__()
|
||||
attn = []
|
||||
|
||||
def edim(m):
|
||||
dd = max(embedding_dim//m, 128, spec_dim)
|
||||
return ceil_multiple(dd, 8)
|
||||
self.encoder = nn.Sequential(
|
||||
ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True),
|
||||
ResBlock(spec_dim, out_channels=edim(6),
|
||||
use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(6), out_channels=edim(5),
|
||||
use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(5), out_channels=edim(4),
|
||||
use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(4), out_channels=edim(3),
|
||||
use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1),
|
||||
ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(3), out_channels=edim(2),
|
||||
use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1),
|
||||
ResBlock(edim(2), out_channels=embedding_dim, use_conv=True, dims=1, down=True),
|
||||
ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
|
||||
ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
|
||||
ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
|
||||
ResBlock(edim(2), out_channels=embedding_dim,
|
||||
use_conv=True, dims=1, down=True),
|
||||
ResBlock(embedding_dim, out_channels=embedding_dim,
|
||||
use_conv=True, dims=1),
|
||||
ResBlock(embedding_dim, out_channels=embedding_dim,
|
||||
use_conv=True, dims=1),
|
||||
ResBlock(embedding_dim, out_channels=embedding_dim,
|
||||
use_conv=True, dims=1),
|
||||
nn.GroupNorm(8, embedding_dim)
|
||||
)
|
||||
self.quantizer = Quantize(embedding_dim, num_tokens)
|
||||
|
@ -95,7 +113,7 @@ class UpperQuantizer(nn.Module):
|
|||
h = x
|
||||
for lyr in self.encoder:
|
||||
h = lyr(h)
|
||||
h = h.permute(0,2,1)
|
||||
h = h.permute(0, 2, 1)
|
||||
h_quant, commitment_loss, codes = self.quantizer(h)
|
||||
self.log_codes(codes)
|
||||
return h_quant, commitment_loss
|
||||
|
@ -105,7 +123,8 @@ class UpperQuantizer(nn.Module):
|
|||
if self.internal_step % 10 == 0:
|
||||
codes = codes.flatten()
|
||||
l = codes.shape[0]
|
||||
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
i = self.code_ind if (
|
||||
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
self.codes[i:i+l] = codes.cpu()
|
||||
self.code_ind = self.code_ind + l
|
||||
if self.code_ind >= self.codes.shape[0]:
|
||||
|
@ -122,7 +141,8 @@ class GptMusicLower(nn.Module):
|
|||
self.freeze_upper_until = freeze_upper_until
|
||||
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
|
||||
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
|
||||
self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
|
||||
self.target_quantizers = nn.ModuleList(
|
||||
[DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
|
||||
self.upper_quantizer = UpperQuantizer(256, dim, num_upper_vectors)
|
||||
self.fp16 = fp16
|
||||
self.internal_step = 0
|
||||
|
@ -132,14 +152,17 @@ class GptMusicLower(nn.Module):
|
|||
p.DO_NOT_TRAIN = True
|
||||
p.requires_grad = False
|
||||
|
||||
self.conditioning_encoder = ConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)
|
||||
self.conditioning_encoder = ConditioningEncoder(
|
||||
256, dim, attn_blocks=4, num_attn_heads=dim//64)
|
||||
|
||||
self.gpt = GPT2Model(self.config)
|
||||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
# nn.Embedding
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
|
||||
self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
|
||||
self.embeddings = nn.ModuleList(
|
||||
[ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
|
||||
self.heads = nn.ModuleList(
|
||||
[ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
|
||||
|
||||
def forward(self, mel, conditioning, return_latent=False):
|
||||
unused_params = []
|
||||
|
@ -159,14 +182,17 @@ class GptMusicLower(nn.Module):
|
|||
unused_params.extend(list(self.upper_quantizer.parameters()))
|
||||
else:
|
||||
self.upper_quantizer = self.upper_quantizer.train()
|
||||
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
|
||||
upper_vector = F.interpolate(upper_vector.permute(0,2,1), size=codes.shape[1], mode='linear')
|
||||
upper_vector = upper_vector.permute(0,2,1)
|
||||
upper_vector, upper_diversity = self.upper_quantizer(
|
||||
mel, return_decoder_latent=True)
|
||||
upper_vector = F.interpolate(upper_vector.permute(
|
||||
0, 2, 1), size=codes.shape[1], mode='linear')
|
||||
upper_vector = upper_vector.permute(0, 2, 1)
|
||||
|
||||
inputs = codes[:, :-1]
|
||||
targets = codes
|
||||
upper_vector = upper_vector[:, :-1]
|
||||
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
h = [embedding(inputs[:, :, i])
|
||||
for i, embedding in enumerate(self.embeddings)]
|
||||
h = torch.cat(h, dim=-1) + upper_vector
|
||||
|
||||
with torch.autocast(mel.device.type, enabled=self.fp16):
|
||||
|
@ -183,8 +209,8 @@ class GptMusicLower(nn.Module):
|
|||
|
||||
losses = 0
|
||||
for i, head in enumerate(self.heads):
|
||||
logits = head(h).permute(0,2,1)
|
||||
loss = F.cross_entropy(logits, targets[:,:,i])
|
||||
logits = head(h).permute(0, 2, 1)
|
||||
loss = F.cross_entropy(logits, targets[:, :, i])
|
||||
losses = losses + loss
|
||||
|
||||
unused_adder = 0
|
||||
|
@ -221,11 +247,15 @@ class GptMusicUpper(nn.Module):
|
|||
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
|
||||
use_cache=False)
|
||||
self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim,
|
||||
max(512,dim-128),
|
||||
max(512,dim-256),
|
||||
max(512,dim-384),
|
||||
max(512,dim-512),
|
||||
max(512,dim-512)], codevector_dim=dim,
|
||||
max(512,
|
||||
dim-128),
|
||||
max(512,
|
||||
dim-256),
|
||||
max(512,
|
||||
dim-384),
|
||||
max(512,
|
||||
dim-512),
|
||||
max(512, dim-512)], codevector_dim=dim,
|
||||
codebook_size=num_upper_vectors, codebook_groups=num_upper_groups,
|
||||
expressive_downsamples=True)
|
||||
# Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
|
||||
|
@ -235,15 +265,17 @@ class GptMusicUpper(nn.Module):
|
|||
p.DO_NOT_TRAIN = True
|
||||
p.requires_grad = False
|
||||
|
||||
self.conditioning_encoder = UpperConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)
|
||||
self.conditioning_encoder = UpperConditioningEncoder(
|
||||
256, dim, attn_blocks=4, num_attn_heads=dim//64)
|
||||
|
||||
self.gpt = GPT2Model(self.config)
|
||||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
# nn.Embedding
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
|
||||
self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
|
||||
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(
|
||||
num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
|
||||
self.heads = nn.ModuleList(
|
||||
[ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
|
||||
|
||||
def forward(self, mel, conditioning, return_latent=False):
|
||||
with torch.no_grad():
|
||||
|
@ -252,7 +284,8 @@ class GptMusicUpper(nn.Module):
|
|||
|
||||
inputs = codes[:, :-1]
|
||||
targets = codes
|
||||
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
h = [embedding(inputs[:, :, i])
|
||||
for i, embedding in enumerate(self.embeddings)]
|
||||
h = torch.cat(h, dim=-1)
|
||||
|
||||
with torch.autocast(mel.device.type, enabled=self.fp16):
|
||||
|
@ -269,8 +302,8 @@ class GptMusicUpper(nn.Module):
|
|||
|
||||
losses = 0
|
||||
for i, head in enumerate(self.heads):
|
||||
logits = head(h).permute(0,2,1)
|
||||
loss = F.cross_entropy(logits, targets[:,:,i])
|
||||
logits = head(h).permute(0, 2, 1)
|
||||
loss = F.cross_entropy(logits, targets[:, :, i])
|
||||
losses = losses + loss
|
||||
|
||||
return losses / self.num_groups
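# Illustrative sketch (not part of this diff): each code group gets its own
# embedding of width dim // num_groups and its own output head, and the per-group
# cross-entropy losses are averaged exactly as above. `hidden`, `codes` and `heads`
# are assumed names for this sketch only.
import torch.nn.functional as F

def grouped_code_loss(hidden, codes, heads):
    # hidden: (B, S, dim) transformer output; codes: (B, S, G) integer targets.
    losses = 0
    for g, head in enumerate(heads):
        logits = head(hidden).permute(0, 2, 1)              # (B, vocab, S)
        losses = losses + F.cross_entropy(logits, codes[:, :, g])
    return losses / len(heads)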
|
||||
|
@ -293,6 +326,7 @@ class GptMusicUpper(nn.Module):
|
|||
def register_music_gpt_lower(opt_net, opt):
|
||||
return GptMusicLower(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
@register_model
|
||||
def register_music_gpt_upper(opt_net, opt):
|
||||
return GptMusicUpper(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
@ -301,11 +335,11 @@ def register_music_gpt_upper(opt_net, opt):
|
|||
def test_lower():
|
||||
model = GptMusicLower(dim=512, layers=12, fp16=False, freeze_upper_until=1000,
|
||||
num_target_vectors=8192, num_upper_vectors=8192, num_vaes=4,
|
||||
vqargs= {
|
||||
'positional_dims': 1, 'channels': 64,
|
||||
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
|
||||
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
|
||||
})
|
||||
vqargs={
|
||||
'positional_dims': 1, 'channels': 64,
|
||||
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
|
||||
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
|
||||
})
|
||||
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
|
||||
|
@ -316,7 +350,7 @@ def test_lower():
|
|||
torch.save(model.state_dict(), 'sample.pth')
|
||||
print_network(model)
|
||||
|
||||
mel = torch.randn(2,256,400)
|
||||
mel = torch.randn(2, 256, 400)
|
||||
model(mel, mel)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
@ -335,14 +369,15 @@ def test_lower():
|
|||
|
||||
def test_upper():
|
||||
lower = GptMusicLower(512, 12)
|
||||
lower.load_state_dict(torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth'))
|
||||
lower.load_state_dict(torch.load(
|
||||
'D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth'))
|
||||
model = GptMusicUpper(512, 12)
|
||||
model.upper_quantizer.load_state_dict(lower.upper_quantizer.state_dict())
|
||||
torch.save(model.state_dict(), 'sample.pth')
|
||||
mel = torch.randn(2,256,2500)
|
||||
mel = torch.randn(2, 256, 2500)
|
||||
model(mel, mel)
|
||||
model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_lower()
|
||||
test_lower()
|
||||
|
|
|
@@ -2,12 +2,12 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import torch_intermediary as ml

from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from trainer.networks import register_model
from utils.util import opt_get, ceil_multiple, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import AttentionBlock, ResBlock
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, opt_get, print_network


class UpperEncoder(nn.Module):
@ -19,22 +19,30 @@ class UpperEncoder(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
attn = []
|
||||
|
||||
def edim(m):
|
||||
dd = min(spec_dim + m * 128, hidden_dim)
|
||||
return ceil_multiple(dd, 8)
|
||||
self.downsampler = nn.Sequential(
|
||||
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1,
|
||||
down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1,
|
||||
down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1,
|
||||
down=True, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(3), out_channels=edim(4), use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
|
||||
self.encoder = nn.Sequential(
|
||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
|
||||
dims=1, checkpointing_enabled=checkpointing_enabled),
|
||||
nn.GroupNorm(8, hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Conv1d(hidden_dim, embedding_dim, 1),
|
||||
|
@ -47,8 +55,6 @@ class UpperEncoder(nn.Module):
|
|||
return h
|
||||
|
||||
|
||||
|
||||
|
||||
class GptMusicLower(nn.Module):
|
||||
def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}):
|
||||
super().__init__()
|
||||
|
@ -58,7 +64,8 @@ class GptMusicLower(nn.Module):
|
|||
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
|
||||
use_cache=False)
|
||||
|
||||
self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
|
||||
self.target_quantizers = nn.ModuleList(
|
||||
[DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
|
||||
self.upper_encoder = UpperEncoder(256, dim, encoder_out_dim)
|
||||
self.encoder_projector = nn.Conv1d(encoder_out_dim, dim, 1)
|
||||
self.fp16 = fp16
|
||||
|
@ -75,8 +82,10 @@ class GptMusicLower(nn.Module):
|
|||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
# nn.Embedding
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
|
||||
self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
|
||||
self.embeddings = nn.ModuleList(
|
||||
[ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
|
||||
self.heads = nn.ModuleList(
|
||||
[ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
|
||||
|
||||
def forward(self, mel, return_latent=False):
|
||||
unused_params = []
|
||||
|
@ -92,20 +101,23 @@ class GptMusicLower(nn.Module):
|
|||
upper_vector = self.upper_encoder(mel)
|
||||
upper_vector = self.encoder_projector(upper_vector)
|
||||
# WTB slerp
|
||||
upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
|
||||
upper_vector = upper_vector.permute(0,2,1)
|
||||
upper_vector = F.interpolate(
|
||||
upper_vector, size=codes.shape[1], mode='linear')
|
||||
upper_vector = upper_vector.permute(0, 2, 1)
|
||||
|
||||
inputs = codes[:, :-1]
|
||||
targets = codes
|
||||
upper_vector = upper_vector[:, :-1]
|
||||
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
h = [embedding(inputs[:, :, i])
|
||||
for i, embedding in enumerate(self.embeddings)]
|
||||
h = torch.cat(h, dim=-1) + upper_vector
|
||||
|
||||
with torch.autocast(mel.device.type, enabled=self.fp16):
|
||||
# Stick the conditioning embedding on the front of the input sequence.
|
||||
# The transformer will learn how to integrate it.
|
||||
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
|
||||
h = torch.cat([self.start_token.repeat(h.shape[0], 1, 1), h], dim=1)
|
||||
h = torch.cat(
|
||||
[self.start_token.repeat(h.shape[0], 1, 1), h], dim=1)
|
||||
|
||||
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
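# Illustrative sketch (not part of this diff): the concatenation above prepends a
# learned START embedding to the (already truncated) inputs, so hidden state i lines
# up with target code i for next-token prediction. `codes`, `start_token` and `embed`
# are assumed names for this sketch only.
import torch

def shift_for_next_token(codes, start_token, embed):
    # codes: (B, S) integer targets; inputs drop the last code, START restores length S.
    h = embed(codes[:, :-1])                      # (B, S-1, D)
    start = start_token.repeat(h.shape[0], 1, 1)  # (B, 1, D)
    return torch.cat([start, h], dim=1)           # (B, S, D): position i predicts codes[:, i]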
|
||||
|
||||
|
@ -114,8 +126,8 @@ class GptMusicLower(nn.Module):
|
|||
|
||||
losses = 0
|
||||
for i, head in enumerate(self.heads):
|
||||
logits = head(h).permute(0,2,1)
|
||||
loss = F.cross_entropy(logits, targets[:,:,i])
|
||||
logits = head(h).permute(0, 2, 1)
|
||||
loss = F.cross_entropy(logits, targets[:, :, i])
|
||||
losses = losses + loss
|
||||
|
||||
unused_adder = 0
|
||||
|
@ -143,10 +155,10 @@ def register_music_gpt_lower2(opt_net, opt):
|
|||
|
||||
def test_lower():
|
||||
model = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
|
||||
vqargs= {'positional_dims': 1, 'channels': 64,
|
||||
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
|
||||
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
|
||||
})
|
||||
vqargs={'positional_dims': 1, 'channels': 64,
|
||||
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
|
||||
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
|
||||
})
|
||||
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
|
||||
|
@ -157,7 +169,7 @@ def test_lower():
|
|||
torch.save(model.state_dict(), 'sample.pth')
|
||||
print_network(model)
|
||||
|
||||
mel = torch.randn(2,256,400)
|
||||
mel = torch.randn(2, 256, 400)
|
||||
model(mel)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
|
|
@@ -3,13 +3,15 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml

from models.diffusion.nn import timestep_embedding
from models.lucidrains.vq import VectorQuantize
from models.lucidrains.x_transformers import FeedForward, Attention, Decoder, RMSScaleShiftNorm
from trainer.networks import register_model
from utils.util import checkpoint
import dlas.torch_intermediary as ml
from dlas.models.diffusion.nn import timestep_embedding
from dlas.models.lucidrains.vq import VectorQuantize
from dlas.models.lucidrains.x_transformers import (Attention, Decoder,
FeedForward,
RMSScaleShiftNorm)
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint


class SelfClassifyingHead(nn.Module):
@ -19,7 +21,7 @@ class SelfClassifyingHead(nn.Module):
|
|||
self.num_classes = classes
|
||||
self.temperature = init_temperature
|
||||
self.dec = Decoder(dim=dim, depth=head_depth, heads=4, ff_dropout=dropout, ff_mult=2, attn_dropout=dropout,
|
||||
use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
|
||||
use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
|
||||
self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2,
|
||||
sample_codebook_temp=init_temperature)
|
||||
self.to_output = ml.Linear(dim, out_dim)
|
||||
|
@ -39,7 +41,8 @@ class SelfClassifyingHead(nn.Module):
|
|||
codes = []
|
||||
q_reg = 0
|
||||
for i in range(self.seq_len):
|
||||
q, c = checkpoint(functools.partial(self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1))
|
||||
q, c = checkpoint(functools.partial(
|
||||
self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1))
|
||||
q_reg = q_reg + (q ** 2).mean()
|
||||
s = torch.sigmoid(q)
|
||||
|
||||
|
@ -48,10 +51,13 @@ class SelfClassifyingHead(nn.Module):
|
|||
|
||||
# If the addition would strictly make the result worse, set it to 0. Sometimes.
|
||||
if len(results) > 0:
|
||||
worsen = (F.mse_loss(outputs[-1], target, reduction='none').sum(-1) < F.mse_loss(output, target, reduction='none').sum(-1)).float()
|
||||
worsen = (F.mse_loss(outputs[-1], target, reduction='none').sum(-1) < F.mse_loss(
|
||||
output, target, reduction='none').sum(-1)).float()
|
||||
probabilistic_worsen = torch.rand_like(worsen) * worsen > .5
|
||||
output = output * probabilistic_worsen.unsqueeze(-1) # This is non-differentiable, but still deterministic.
|
||||
c[probabilistic_worsen] = -1 # Code of -1 means the code was unused.
|
||||
# This is non-differentiable, but still deterministic.
|
||||
output = output * probabilistic_worsen.unsqueeze(-1)
|
||||
# Code of -1 means the code was unused.
|
||||
c[probabilistic_worsen] = -1
|
||||
s = s * probabilistic_worsen.unsqueeze(-1)
|
||||
outputs[-1] = s
|
||||
|
||||
|
@ -65,7 +71,8 @@ class VectorResBlock(nn.Module):
|
|||
def __init__(self, dim, dropout):
|
||||
super().__init__()
|
||||
self.norm = nn.BatchNorm1d(dim)
|
||||
self.ff = FeedForward(dim, mult=2, glu=True, dropout=dropout, zero_init_output=True)
|
||||
self.ff = FeedForward(dim, mult=2, glu=True,
|
||||
dropout=dropout, zero_init_output=True)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.norm(x.unsqueeze(-1)).squeeze(-1)
|
||||
|
@ -92,8 +99,10 @@ class InstrumentQuantizer(nn.Module):
|
|||
super().__init__()
|
||||
self.op_dim = op_dim
|
||||
self.proj = ml.Linear(op_dim, dim)
|
||||
self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
|
||||
self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
|
||||
self.encoder = nn.ModuleList(
|
||||
[VectorResBlock(dim, dropout) for _ in range(enc_depth)])
|
||||
self.heads = SelfClassifyingHead(
|
||||
dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
|
||||
self.min_gumbel_temperature = min_temp
|
||||
self.max_gumbel_temperature = max_temp
|
||||
self.gumbel_temperature_decay = temp_decay
|
||||
|
@ -109,16 +118,19 @@ class InstrumentQuantizer(nn.Module):
|
|||
x = (x + 1) / 2
|
||||
|
||||
b, c, s = x.shape
|
||||
px = x.permute(0,2,1) # B,S,C shape
|
||||
px = x.permute(0, 2, 1) # B,S,C shape
|
||||
f = px.reshape(-1, self.op_dim)
|
||||
h = self.proj(f)
|
||||
for lyr in self.encoder:
|
||||
h = lyr(h)
|
||||
|
||||
reconstructions, codes, q_reg = self.heads(h, f)
|
||||
reconstruction_losses = torch.stack([F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions])
|
||||
r_follow = torch.arange(1, reconstruction_losses.shape[0]+1, device=x.device)
|
||||
reconstruction_losses = (reconstruction_losses * r_follow / r_follow.shape[0])
|
||||
reconstruction_losses = torch.stack(
|
||||
[F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions])
|
||||
r_follow = torch.arange(
|
||||
1, reconstruction_losses.shape[0]+1, device=x.device)
|
||||
reconstruction_losses = (
|
||||
reconstruction_losses * r_follow / r_follow.shape[0])
|
||||
self.log_codes(codes)
|
||||
|
||||
return reconstruction_losses, q_reg
|
||||
|
@ -126,7 +138,8 @@ class InstrumentQuantizer(nn.Module):
|
|||
def log_codes(self, codes):
|
||||
if self.internal_step % 5 == 0:
|
||||
l = codes.shape[0]
|
||||
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
i = self.code_ind if (
|
||||
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
self.codes[i:i+l] = codes.cpu()
|
||||
self.code_ind = self.code_ind + l
|
||||
if self.code_ind >= self.codes.shape[0]:
|
||||
|
@ -143,9 +156,9 @@ class InstrumentQuantizer(nn.Module):
|
|||
def update_for_step(self, step, *args):
|
||||
self.internal_step = step
|
||||
self.heads.quantizer._codebook.sample_codebook_temp = max(
|
||||
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
|
||||
self.min_gumbel_temperature,
|
||||
)
|
||||
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
|
||||
self.min_gumbel_temperature,
|
||||
)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
groups = {
|
||||
|
@ -162,6 +175,6 @@ def register_instrument_quantizer(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
inp = torch.randn((4,256,200)).clamp(-1,1)
|
||||
inp = torch.randn((4, 256, 200)).clamp(-1, 1)
|
||||
model = InstrumentQuantizer(256, 512, 4096, 8, 3)
|
||||
model(inp)
|
||||
|
|
|
@@ -2,10 +2,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from models.arch_util import ResBlock, AttentionBlock
from models.audio.music.flat_diffusion import MultiGroupEmbedding
from trainer.networks import register_model
from utils.util import checkpoint
from dlas.models.arch_util import AttentionBlock, ResBlock
from dlas.models.audio.music.flat_diffusion import MultiGroupEmbedding
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint


class Code2Mel(nn.Module):
@ -13,27 +13,33 @@ class Code2Mel(nn.Module):
|
|||
super().__init__()
|
||||
self.emb = MultiGroupEmbedding(num_tokens, num_groups, base_dim)
|
||||
self.base_blocks = nn.Sequential(ResBlock(base_dim, dropout, dims=1),
|
||||
AttentionBlock(base_dim, num_heads=base_dim//64),
|
||||
AttentionBlock(
|
||||
base_dim, num_heads=base_dim//64),
|
||||
ResBlock(base_dim, dropout, dims=1))
|
||||
l2dim = base_dim-256
|
||||
self.l2_up_block = nn.Conv1d(base_dim, l2dim, kernel_size=5, padding=2)
|
||||
self.l2_blocks = nn.Sequential(ResBlock(l2dim, dropout, kernel_size=5, dims=1),
|
||||
AttentionBlock(l2dim, num_heads=base_dim//64),
|
||||
ResBlock(l2dim, dropout, kernel_size=5, dims=1),
|
||||
AttentionBlock(l2dim, num_heads=base_dim//64),
|
||||
ResBlock(l2dim, dropout, dims=1),
|
||||
ResBlock(l2dim, dropout, dims=1))
|
||||
AttentionBlock(
|
||||
l2dim, num_heads=base_dim//64),
|
||||
ResBlock(l2dim, dropout,
|
||||
kernel_size=5, dims=1),
|
||||
AttentionBlock(
|
||||
l2dim, num_heads=base_dim//64),
|
||||
ResBlock(l2dim, dropout, dims=1),
|
||||
ResBlock(l2dim, dropout, dims=1))
|
||||
l3dim = l2dim-256
|
||||
self.l3_up_block = nn.Conv1d(l2dim, l3dim, kernel_size=5, padding=2)
|
||||
self.l3_blocks = nn.Sequential(ResBlock(l3dim, dropout, kernel_size=5, dims=1),
|
||||
AttentionBlock(l3dim, num_heads=base_dim//64),
|
||||
ResBlock(l3dim, dropout, kernel_size=5, dims=1),
|
||||
AttentionBlock(
|
||||
l3dim, num_heads=base_dim//64),
|
||||
ResBlock(l3dim, dropout,
|
||||
kernel_size=5, dims=1),
|
||||
ResBlock(l3dim, dropout, dims=1))
|
||||
self.final_block = nn.Conv1d(l3dim, out_dim, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, codes, target):
|
||||
with torch.autocast(codes.device.type):
|
||||
h = self.emb(codes).permute(0,2,1)
|
||||
h = self.emb(codes).permute(0, 2, 1)
|
||||
h = checkpoint(self.base_blocks, h)
|
||||
h = F.interpolate(h, scale_factor=2, mode='linear')
|
||||
h = self.l2_up_block(h)
|
||||
|
@ -52,6 +58,6 @@ def register_code2mel(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
model = Code2Mel()
|
||||
codes = torch.randint(0,16, (2,200,4))
|
||||
target = torch.randn(2,256,804)
|
||||
model(codes, target)
|
||||
codes = torch.randint(0, 16, (2, 200, 4))
|
||||
target = torch.randn(2, 256, 804)
|
||||
model(codes, target)
|
||||
|
|
|
@@ -1,11 +1,11 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch_intermediary as ml
from torch import nn
from transformers import GPT2Config, GPT2Model

from trainer.networks import register_model
from utils.util import opt_get
import dlas.torch_intermediary as ml
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get


class Mel2VecCodesGpt(nn.Module):
@ -19,8 +19,10 @@ class Mel2VecCodesGpt(nn.Module):
|
|||
self.gpt = GPT2Model(self.config)
|
||||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
# nn.Embedding
|
||||
self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
|
||||
self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)])
|
||||
self.embeddings = nn.ModuleList(
|
||||
[ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
|
||||
self.heads = nn.ModuleList(
|
||||
[ml.Linear(dim, num_vectors) for _ in range(num_groups)])
|
||||
|
||||
def forward(self, codes):
|
||||
assert codes.shape[-1] == self.num_groups
|
||||
|
@ -28,14 +30,15 @@ class Mel2VecCodesGpt(nn.Module):
|
|||
inputs = codes[:, :-1]
|
||||
targets = codes[:, 1:]
|
||||
|
||||
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
h = [embedding(inputs[:, :, i])
|
||||
for i, embedding in enumerate(self.embeddings)]
|
||||
h = torch.cat(h, dim=-1)
|
||||
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
|
||||
|
||||
losses = 0
|
||||
for i, head in enumerate(self.heads):
|
||||
logits = head(h).permute(0,2,1)
|
||||
loss = F.cross_entropy(logits, targets[:,:,i])
|
||||
logits = head(h).permute(0, 2, 1)
|
||||
loss = F.cross_entropy(logits, targets[:, :, i])
|
||||
losses = losses + loss
|
||||
|
||||
return losses / self.num_groups
|
||||
|
@ -48,5 +51,5 @@ def register_music_gpt(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
model = Mel2VecCodesGpt(512, 8)
|
||||
codes = torch.randint(0,8, (2,300,8))
|
||||
model(codes)
|
||||
codes = torch.randint(0, 8, (2, 300, 8))
|
||||
model(codes)
|
||||
|
|
|
@@ -1,14 +1,14 @@
import functools

import torch
from torch import nn
import torch.nn.functional as F
import torch_intermediary as ml
from torch import nn

from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import checkpoint, ceil_multiple, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import zero_module
from dlas.models.vqvae.vqvae import Quantize
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, checkpoint, print_network


class Downsample(nn.Module):
@ -37,13 +37,13 @@ class ResBlock(nn.Module):
|
|||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv1d(chan, chan, 3, padding = 1),
|
||||
nn.Conv1d(chan, chan, 3, padding=1),
|
||||
nn.GroupNorm(8, chan),
|
||||
nn.SiLU(),
|
||||
nn.Conv1d(chan, chan, 3, padding = 1),
|
||||
nn.Conv1d(chan, chan, 3, padding=1),
|
||||
nn.GroupNorm(8, chan),
|
||||
nn.SiLU(),
|
||||
zero_module(nn.Conv1d(chan, chan, 3, padding = 1)),
|
||||
zero_module(nn.Conv1d(chan, chan, 3, padding=1)),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -74,7 +74,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
|
||||
# storage for codebook variables (codewords)
|
||||
self.codevectors = nn.Parameter(
|
||||
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
|
||||
torch.FloatTensor(1, self.num_groups * self.num_vars,
|
||||
codevector_dim // self.num_groups)
|
||||
)
|
||||
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
|
||||
|
||||
|
@ -95,7 +96,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
else:
|
||||
marginal_probs = probs.mean(dim=0)
|
||||
|
||||
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
||||
perplexity = torch.exp(-torch.sum(marginal_probs *
|
||||
torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
||||
return perplexity
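# Illustrative sketch (not part of this diff): the perplexity computed above is
# exp(entropy) of the mean soft-assignment distribution, summed over codebook
# groups; MusicQuantizer later turns it into a diversity term as
# (num_codevectors - perplexity) / num_codevectors. `probs` with shape
# (N, groups, vars) is an assumed input for this sketch only.
import torch

def codebook_perplexity(probs):
    marginal = probs.mean(dim=0)                                         # (groups, vars)
    entropy = -torch.sum(marginal * torch.log(marginal + 1e-7), dim=-1)  # (groups,)
    return torch.exp(entropy).sum()                                      # scalar, max = groups * vars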
|
||||
|
||||
def get_codes(self, hidden_states):
|
||||
|
@ -103,9 +105,11 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
|
||||
# project to codevector dim
|
||||
hidden_states = self.weight_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size * sequence_length * self.num_groups, -1)
|
||||
codevector_idx = hidden_states.argmax(dim=-1)
|
||||
idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups)
|
||||
idxs = codevector_idx.view(
|
||||
batch_size, sequence_length, self.num_groups)
|
||||
return idxs
|
||||
|
||||
def forward(self, hidden_states, mask_time_indices=None, return_probs=False):
|
||||
|
@ -113,7 +117,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
|
||||
# project to codevector dim
|
||||
hidden_states = self.weight_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size * sequence_length * self.num_groups, -1)
|
||||
|
||||
if self.training:
|
||||
# sample code vector probs via gumbel in differentiable way
|
||||
|
@ -125,7 +130,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
codevector_soft_dist = torch.softmax(
|
||||
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
|
||||
)
|
||||
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
||||
perplexity = self._compute_perplexity(
|
||||
codevector_soft_dist, mask_time_indices)
|
||||
else:
|
||||
# take argmax in non-differentiable way
|
||||
# compute hard codevector distribution (one hot)
|
||||
|
@ -133,15 +139,20 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
|
||||
-1, codevector_idx.view(-1, 1), 1.0
|
||||
)
|
||||
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
||||
codevector_probs = codevector_probs.view(
|
||||
batch_size * sequence_length, self.num_groups, -1)
|
||||
|
||||
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
|
||||
perplexity = self._compute_perplexity(
|
||||
codevector_probs, mask_time_indices)
|
||||
|
||||
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
|
||||
codevector_probs = codevector_probs.view(
|
||||
batch_size * sequence_length, -1)
|
||||
# use probs to retrieve codevectors
|
||||
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
|
||||
codevectors_per_group = codevector_probs.unsqueeze(
|
||||
-1) * self.codevectors
|
||||
codevectors = (
|
||||
codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
||||
codevectors_per_group.view(
|
||||
batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
||||
.sum(-2)
|
||||
.view(batch_size, sequence_length, -1)
|
||||
)
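# Illustrative sketch (not part of this diff): the reshape/sum above turns the
# per-group (gumbel or one-hot) probabilities into one codeword per group and
# concatenates the groups back into a (B, S, codevector_dim) tensor. All argument
# names are assumptions for this sketch only.
def select_codevectors(codevector_probs, codevectors, batch_size, sequence_length, num_groups, num_vars):
    # codevector_probs: (B*S, G*V); codevectors: (1, G*V, codevector_dim // G)
    weighted = codevector_probs.unsqueeze(-1) * codevectors                       # (B*S, G*V, d)
    per_group = weighted.view(batch_size * sequence_length, num_groups, num_vars, -1).sum(-2)
    return per_group.view(batch_size, sequence_length, -1)                        # (B, S, G*d)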
|
||||
|
@ -164,7 +175,8 @@ class MusicQuantizer(nn.Module):
|
|||
self.use_vqvae_quantizer = use_vqvae_quantizer
|
||||
if use_vqvae_quantizer:
|
||||
self.quantizer = Quantize(inner_dim[0], codebook_size)
|
||||
assert codevector_dim == inner_dim[0] # Because this quantizer doesn't support different sizes.
|
||||
# Because this quantizer doesn't support different sizes.
|
||||
assert codevector_dim == inner_dim[0]
|
||||
else:
|
||||
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim[0], codevector_dim=codevector_dim,
|
||||
num_codevector_groups=codebook_groups,
|
||||
|
@ -174,8 +186,10 @@ class MusicQuantizer(nn.Module):
|
|||
self.num_losses_record = []
|
||||
|
||||
if down_steps == 0:
|
||||
self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1)
|
||||
self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1)
|
||||
self.down = nn.Conv1d(
|
||||
inp_channels, inner_dim[0], kernel_size=3, padding=1)
|
||||
self.up = nn.Conv1d(
|
||||
inner_dim[0], inp_channels, kernel_size=3, padding=1)
|
||||
elif down_steps == 2:
|
||||
self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1),
|
||||
Downsample(inner_dim[-1], inner_dim[-2]),
|
||||
|
@ -201,25 +215,27 @@ class MusicQuantizer(nn.Module):
|
|||
def get_codes(self, mel):
|
||||
h = self.down(mel)
|
||||
h = self.encoder(h)
|
||||
h = self.enc_norm(h.permute(0,2,1))
|
||||
h = self.enc_norm(h.permute(0, 2, 1))
|
||||
return self.quantizer.get_codes(h)
|
||||
|
||||
def forward(self, mel, return_decoder_latent=False):
|
||||
orig_mel = mel
|
||||
cm = ceil_multiple(mel.shape[-1], 4)
|
||||
if cm != 0:
|
||||
mel = F.pad(mel, (0,cm-mel.shape[-1]))
|
||||
mel = F.pad(mel, (0, cm-mel.shape[-1]))
|
||||
|
||||
h = self.down(mel)
|
||||
h = self.encoder(h)
|
||||
h = self.enc_norm(h.permute(0,2,1))
|
||||
h = self.enc_norm(h.permute(0, 2, 1))
|
||||
if self.use_vqvae_quantizer:
|
||||
codevectors, diversity, codes = self.quantizer(h)
|
||||
else:
|
||||
codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
|
||||
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
|
||||
codevectors, perplexity, codes = self.quantizer(
|
||||
h, return_probs=True)
|
||||
diversity = (self.quantizer.num_codevectors -
|
||||
perplexity) / self.quantizer.num_codevectors
|
||||
self.log_codes(codes)
|
||||
h = self.decoder(codevectors.permute(0,2,1))
|
||||
h = self.decoder(codevectors.permute(0, 2, 1))
|
||||
if return_decoder_latent:
|
||||
return h, diversity
|
||||
|
||||
|
@ -233,13 +249,14 @@ class MusicQuantizer(nn.Module):
|
|||
if self.internal_step % 5 == 0:
|
||||
if not self.use_vqvae_quantizer:
|
||||
codes = torch.argmax(codes, dim=-1)
|
||||
ccodes = codes[:,:,0]
|
||||
for j in range(1,codes.shape[-1]):
|
||||
ccodes += codes[:,:,j] * self.codebook_size ** j
|
||||
ccodes = codes[:, :, 0]
|
||||
for j in range(1, codes.shape[-1]):
|
||||
ccodes += codes[:, :, j] * self.codebook_size ** j
|
||||
codes = ccodes
|
||||
codes = codes.flatten()
|
||||
l = codes.shape[0]
|
||||
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
i = self.code_ind if (
|
||||
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
self.codes[i:i+l] = codes.cpu()
|
||||
self.code_ind = self.code_ind + l
|
||||
if self.code_ind >= self.codes.shape[0]:
|
||||
|
@ -259,7 +276,8 @@ def register_music_quantizer(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = MusicQuantizer(inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=8192, codebook_groups=0, use_vqvae_quantizer=True)
|
||||
model = MusicQuantizer(inner_dim=[1024, 1024, 512], codevector_dim=1024,
|
||||
codebook_size=8192, codebook_groups=0, use_vqvae_quantizer=True)
|
||||
print_network(model)
|
||||
mel = torch.randn((2,256,782))
|
||||
model(mel)
|
||||
mel = torch.randn((2, 256, 782))
|
||||
model(mel)
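As an aside on the log_codes hunks above: the small loop over groups packs one code index per quantizer group into a single integer for logging by treating the groups as digits in base codebook_size. A standalone sketch (values illustrative):

import torch

codebook_size, groups = 8192, 2
codes = torch.randint(0, codebook_size, (4, 100, groups))   # (batch, time, group indices)

packed = codes[:, :, 0].clone()
for j in range(1, groups):
    packed += codes[:, :, j] * codebook_size ** j            # group j contributes size**j

# Every combination of group indices maps to a unique value in [0, codebook_size**groups).
assert packed.max() < codebook_size ** groups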
@ -1,14 +1,14 @@
|
|||
import functools
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch_intermediary as ml
|
||||
from torch import nn
|
||||
|
||||
from models.arch_util import zero_module
|
||||
from models.vqvae.vqvae import Quantize
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, ceil_multiple, print_network
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import zero_module
|
||||
from dlas.models.vqvae.vqvae import Quantize
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import ceil_multiple, checkpoint, print_network
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
@ -16,7 +16,8 @@ class Downsample(nn.Module):
|
|||
super().__init__()
|
||||
self.interpolate = not stride_down
|
||||
if stride_down:
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1, stride=2)
|
||||
self.conv = nn.Conv1d(
|
||||
chan_in, chan_out, kernel_size=3, padding=1, stride=2)
|
||||
else:
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1)
|
||||
if norm:
|
||||
|
@ -49,13 +50,13 @@ class ResBlock(nn.Module):
|
|||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv1d(chan, chan, 3, padding = 1),
|
||||
nn.Conv1d(chan, chan, 3, padding=1),
|
||||
nn.GroupNorm(8, chan),
|
||||
nn.SiLU(),
|
||||
nn.Conv1d(chan, chan, 3, padding = 1),
|
||||
nn.Conv1d(chan, chan, 3, padding=1),
|
||||
nn.GroupNorm(8, chan),
|
||||
nn.SiLU(),
|
||||
zero_module(nn.Conv1d(chan, chan, 3, padding = 1)),
|
||||
zero_module(nn.Conv1d(chan, chan, 3, padding=1)),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -86,7 +87,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
|
||||
# storage for codebook variables (codewords)
|
||||
self.codevectors = nn.Parameter(
|
||||
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
|
||||
torch.FloatTensor(1, self.num_groups * self.num_vars,
|
||||
codevector_dim // self.num_groups)
|
||||
)
|
||||
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
|
||||
|
||||
|
@ -107,7 +109,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
else:
|
||||
marginal_probs = probs.mean(dim=0)
|
||||
|
||||
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
||||
perplexity = torch.exp(-torch.sum(marginal_probs *
|
||||
torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
||||
return perplexity
|
||||
|
||||
def get_codes(self, hidden_states):
|
||||
|
@ -115,9 +118,11 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
|
||||
# project to codevector dim
|
||||
hidden_states = self.weight_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size * sequence_length * self.num_groups, -1)
|
||||
codevector_idx = hidden_states.argmax(dim=-1)
|
||||
idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups)
|
||||
idxs = codevector_idx.view(
|
||||
batch_size, sequence_length, self.num_groups)
|
||||
return idxs
|
||||
|
||||
def forward(self, hidden_states, mask_time_indices=None, return_probs=False):
|
||||
|
@ -125,7 +130,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
|
||||
# project to codevector dim
|
||||
hidden_states = self.weight_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size * sequence_length * self.num_groups, -1)
|
||||
|
||||
if self.training:
|
||||
# sample code vector probs via gumbel in differentiable way
|
||||
|
@ -137,7 +143,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
codevector_soft_dist = torch.softmax(
|
||||
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
|
||||
)
|
||||
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
||||
perplexity = self._compute_perplexity(
|
||||
codevector_soft_dist, mask_time_indices)
|
||||
else:
|
||||
# take argmax in non-differentiable way
|
||||
# compute hard codevector distribution (one hot)
|
||||
|
@ -145,15 +152,20 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
|
||||
-1, codevector_idx.view(-1, 1), 1.0
|
||||
)
|
||||
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
||||
codevector_probs = codevector_probs.view(
|
||||
batch_size * sequence_length, self.num_groups, -1)
|
||||
|
||||
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
|
||||
perplexity = self._compute_perplexity(
|
||||
codevector_probs, mask_time_indices)
|
||||
|
||||
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
|
||||
codevector_probs = codevector_probs.view(
|
||||
batch_size * sequence_length, -1)
|
||||
# use probs to retrieve codevectors
|
||||
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
|
||||
codevectors_per_group = codevector_probs.unsqueeze(
|
||||
-1) * self.codevectors
|
||||
codevectors = (
|
||||
codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
||||
codevectors_per_group.view(
|
||||
batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
||||
.sum(-2)
|
||||
.view(batch_size, sequence_length, -1)
|
||||
)
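The training/eval split a few lines above is the usual Gumbel-softmax trick: during training the code selection is a differentiable (straight-through) sample, while at eval it is a hard argmax. A minimal sketch assuming only stock PyTorch; the temperature is the value annealed by update_for_step:

import torch
import torch.nn.functional as F

def select_codes(logits: torch.Tensor, temperature: float, training: bool) -> torch.Tensor:
    # logits: (batch*time*groups, num_vars) unnormalized code scores.
    if training:
        # hard=True returns a one-hot sample whose gradient flows through the soft sample.
        return F.gumbel_softmax(logits.float(), tau=temperature, hard=True).type_as(logits)
    one_hot = torch.zeros_like(logits)
    return one_hot.scatter_(-1, logits.argmax(dim=-1, keepdim=True), 1.0)

logits = torch.randn(6, 320, requires_grad=True)
print(select_codes(logits, temperature=2.0, training=True).sum(-1))   # ones: one-hot rows
print(select_codes(logits, temperature=2.0, training=False).sum(-1))  # ones: one-hot rows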
|
||||
|
@ -183,12 +195,14 @@ class MusicQuantizer2(nn.Module):
|
|||
self.num_losses_record = []
|
||||
|
||||
if down_steps == 0:
|
||||
self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1)
|
||||
self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1)
|
||||
self.down = nn.Conv1d(
|
||||
inp_channels, inner_dim[0], kernel_size=3, padding=1)
|
||||
self.up = nn.Conv1d(
|
||||
inner_dim[0], inp_channels, kernel_size=3, padding=1)
|
||||
elif down_steps == 2:
|
||||
self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1),
|
||||
*[Downsample(inner_dim[-i], inner_dim[-i-1], norm=expressive_downsamples, act=expressive_downsamples,
|
||||
stride_down=expressive_downsamples) for i in range(1,len(inner_dim))])
|
||||
stride_down=expressive_downsamples) for i in range(1, len(inner_dim))])
|
||||
self.up = nn.Sequential(*[Upsample(inner_dim[i], inner_dim[i+1]) for i in range(len(inner_dim)-1)] +
|
||||
[nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1)])
|
||||
|
||||
|
@ -209,22 +223,23 @@ class MusicQuantizer2(nn.Module):
|
|||
def get_codes(self, mel):
|
||||
h = self.down(mel)
|
||||
h = self.encoder(h)
|
||||
h = self.enc_norm(h.permute(0,2,1))
|
||||
h = self.enc_norm(h.permute(0, 2, 1))
|
||||
return self.quantizer.get_codes(h)
|
||||
|
||||
def forward(self, mel, return_decoder_latent=False):
|
||||
orig_mel = mel
|
||||
cm = ceil_multiple(mel.shape[-1], 2 ** (len(self.down)-1))
|
||||
if cm != 0:
|
||||
mel = F.pad(mel, (0,cm-mel.shape[-1]))
|
||||
mel = F.pad(mel, (0, cm-mel.shape[-1]))
|
||||
|
||||
h = self.down(mel)
|
||||
h = self.encoder(h)
|
||||
h = self.enc_norm(h.permute(0,2,1))
|
||||
h = self.enc_norm(h.permute(0, 2, 1))
|
||||
codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
|
||||
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
|
||||
diversity = (self.quantizer.num_codevectors -
|
||||
perplexity) / self.quantizer.num_codevectors
|
||||
self.log_codes(codes)
|
||||
h = self.decoder(codevectors.permute(0,2,1))
|
||||
h = self.decoder(codevectors.permute(0, 2, 1))
|
||||
if return_decoder_latent:
|
||||
return h, diversity
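The ceil_multiple/F.pad lines above right-pad the mel clip so its length divides evenly by the encoder's total downsampling factor, while orig_mel keeps the unpadded reference. A sketch under the assumption that ceil_multiple rounds up to the next multiple; the factor of 4 is used purely as an example:

import torch
import torch.nn.functional as F

def ceil_multiple(x: int, base: int) -> int:
    # Assumed behaviour: smallest multiple of `base` that is >= x.
    return x if x % base == 0 else x + (base - x % base)

mel = torch.randn(2, 256, 782)
target_len = ceil_multiple(mel.shape[-1], 4)              # 784
mel_padded = F.pad(mel, (0, target_len - mel.shape[-1]))  # pad on the right of the time axis
print(mel_padded.shape)                                   # torch.Size([2, 256, 784])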
|
||||
|
||||
|
@ -237,13 +252,14 @@ class MusicQuantizer2(nn.Module):
|
|||
def log_codes(self, codes):
|
||||
if self.internal_step % 5 == 0:
|
||||
codes = torch.argmax(codes, dim=-1)
|
||||
ccodes = codes[:,:,0]
|
||||
for j in range(1,codes.shape[-1]):
|
||||
ccodes += codes[:,:,j] * self.codebook_size ** j
|
||||
ccodes = codes[:, :, 0]
|
||||
for j in range(1, codes.shape[-1]):
|
||||
ccodes += codes[:, :, j] * self.codebook_size ** j
|
||||
codes = ccodes
|
||||
codes = codes.flatten()
|
||||
l = codes.shape[0]
|
||||
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
i = self.code_ind if (
|
||||
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
self.codes[i:i+l] = codes.cpu()
|
||||
self.code_ind = self.code_ind + l
|
||||
if self.code_ind >= self.codes.shape[0]:
|
||||
|
@ -258,9 +274,9 @@ class MusicQuantizer2(nn.Module):
|
|||
|
||||
def update_for_step(self, step, *args):
|
||||
self.quantizer.temperature = max(
|
||||
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
|
||||
self.min_gumbel_temperature,
|
||||
)
|
||||
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
|
||||
self.min_gumbel_temperature,
|
||||
)
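update_for_step above anneals the Gumbel temperature geometrically and clamps it at a floor. A worked example with hypothetical values; the real max/min/decay come from the constructor arguments:

max_temperature, min_temperature, decay = 2.0, 0.5, 0.999995

def gumbel_temperature(step: int) -> float:
    return max(max_temperature * decay ** step, min_temperature)

for step in (0, 50_000, 200_000, 1_000_000):
    print(step, round(gumbel_temperature(step), 4))
# 0 -> 2.0, 50000 -> ~1.5576, 200000 -> ~0.7358, 1000000 -> 0.5 (clamped at the floor)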
|
||||
|
||||
|
||||
@register_model
|
||||
|
@ -269,7 +285,8 @@ def register_music_quantizer2(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = MusicQuantizer2(inner_dim=[1024], codevector_dim=1024, codebook_size=256, codebook_groups=2)
|
||||
model = MusicQuantizer2(
|
||||
inner_dim=[1024], codevector_dim=1024, codebook_size=256, codebook_groups=2)
|
||||
print_network(model)
|
||||
mel = torch.randn((2,256,782))
|
||||
model(mel)
|
||||
mel = torch.randn((2, 256, 782))
|
||||
model(mel)
@ -8,14 +8,17 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import torchvision
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \
|
||||
FeedForward
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, print_network, load_audio
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import TimestepBlock
|
||||
from dlas.models.lucidrains.x_transformers import (Attention, Encoder,
|
||||
FeedForward,
|
||||
RMSScaleShiftNorm,
|
||||
RotaryEmbedding)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint, load_audio, print_network
|
||||
|
||||
|
||||
class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
|
@ -389,17 +392,20 @@ def inference_tfdpc5_with_cheater():
|
|||
use_fp16=False, unconditioned_percentage=0).eval().cuda()
|
||||
model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/59000_generator_ema.pth'))
|
||||
|
||||
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
|
||||
from trainer.injectors.audio_injectors import \
|
||||
TorchMelSpectrogramInjector
|
||||
spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'true_normalization': True,
|
||||
'normalize': True, 'in': 'in', 'out': 'out'}, {}).cuda()
|
||||
ref_mel = spec_fn({'in': sample.unsqueeze(0)})['out']
|
||||
from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
|
||||
from trainer.injectors.audio_injectors import \
|
||||
MusicCheaterLatentInjector
|
||||
cheater_encoder = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {}).cuda()
|
||||
ref_cheater = cheater_encoder({'in': ref_mel})['out']
|
||||
|
||||
from models.diffusion.respace import SpacedDiffusion
|
||||
from models.diffusion.respace import space_timesteps
|
||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||
from models.diffusion.gaussian_diffusion import \
|
||||
get_named_beta_schedule
|
||||
from models.diffusion.respace import (SpacedDiffusion,
|
||||
space_timesteps)
|
||||
diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [128]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
|
||||
conditioning_free=True, conditioning_free_k=1)
|
||||
|
@ -415,7 +421,8 @@ def inference_tfdpc5_with_cheater():
|
|||
# Just decode the ref.
|
||||
#gen_cheater = ref_cheater
|
||||
|
||||
from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent
|
||||
from models.audio.music.transformer_diffusion12 import \
|
||||
TransformerDiffusionWithCheaterLatent
|
||||
diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
|
||||
conditioning_free=True, conditioning_free_k=1)
|
||||
|
|
|
@ -4,18 +4,21 @@ from time import time
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import ResBlock
|
||||
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
|
||||
from models.audio.music.music_quantizer2 import MusicQuantizer2
|
||||
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \
|
||||
FeedForward
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, print_network
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import ResBlock
|
||||
from dlas.models.audio.music.gpt_music2 import GptMusicLower, UpperEncoder
|
||||
from dlas.models.audio.music.music_quantizer2 import MusicQuantizer2
|
||||
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import TimestepBlock
|
||||
from dlas.models.lucidrains.x_transformers import (Attention, Encoder,
|
||||
FeedForward,
|
||||
RMSScaleShiftNorm,
|
||||
RotaryEmbedding)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint, print_network
|
||||
|
||||
|
||||
def is_latent(t):
@ -1,26 +1,29 @@
|
|||
import itertools
|
||||
import random
|
||||
from random import randrange
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
|
||||
RelativeQKBias
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import (AttentionBlock, RelativeQKBias, ResBlock,
|
||||
TimestepEmbedSequential,
|
||||
build_local_attention_mask, cGLU)
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import TimestepBlock
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
class SubBlock(nn.Module):
|
||||
def __init__(self, inp_dim, contraction_dim, heads, dropout):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
|
||||
self.register_buffer('mask', build_local_attention_mask(n=6000, l=64), persistent=False)
|
||||
self.attn = AttentionBlock(
|
||||
inp_dim, out_channels=contraction_dim, num_heads=heads)
|
||||
self.register_buffer('mask', build_local_attention_mask(
|
||||
n=6000, l=64), persistent=False)
|
||||
self.pos_bias = RelativeQKBias(l=64, max_positions=6000)
|
||||
ff_contract = contraction_dim//2
|
||||
self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
|
||||
|
@ -31,7 +34,8 @@ class SubBlock(nn.Module):
|
|||
cGLU(ff_contract))
|
||||
|
||||
def forward(self, x):
|
||||
ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
|
||||
ah = self.dropout(self.attn(x, mask=self.mask,
|
||||
qk_bias=self.pos_bias(x.shape[-1])))
|
||||
h = torch.cat([ah, x], dim=1)
|
||||
hf = self.dropout(checkpoint(self.ff1, h))
|
||||
h = torch.cat([h, hf], dim=1)
|
||||
|
@ -44,17 +48,21 @@ class ConcatAttentionBlock(TimestepBlock):
|
|||
super().__init__()
|
||||
self.contraction_dim = contraction_dim
|
||||
self.prenorm = nn.GroupNorm(8, trunk_dim)
|
||||
self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout)
|
||||
self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
|
||||
self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
|
||||
self.block1 = SubBlock(
|
||||
trunk_dim+blk_dim, contraction_dim, heads, dropout)
|
||||
self.block2 = SubBlock(
|
||||
trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
|
||||
self.out = nn.Conv1d(contraction_dim*4, trunk_dim,
|
||||
kernel_size=1, bias=False)
|
||||
self.out.weight.data.zero_()
|
||||
|
||||
def forward(self, x, blk_emb):
|
||||
h = self.prenorm(x)
|
||||
h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
|
||||
h = torch.cat(
|
||||
[h, blk_emb.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
|
||||
h = self.block1(h)
|
||||
h = self.block2(h)
|
||||
h = self.out(h[:,-self.contraction_dim*4:])
|
||||
h = self.out(h[:, -self.contraction_dim*4:])
|
||||
return h + x
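The `self.out.weight.data.zero_()` call above (and the zero_module(...) wrappers used elsewhere in this file) follow the same zero-init pattern: the last projection of a residual branch starts at zero, so every block acts as the identity at initialization and the deep stack trains stably from the start. A small sketch, assuming zero_module simply zeroes a module's parameters in place:

import torch
from torch import nn

def zero_module(module: nn.Module) -> nn.Module:
    # Assumed to mirror dlas.models.arch_util.zero_module.
    for p in module.parameters():
        p.detach().zero_()
    return module

branch_out = zero_module(nn.Conv1d(64, 64, 3, padding=1))
x = torch.randn(2, 64, 100)
print(torch.equal(x + branch_out(x), x))   # True: the residual branch adds nothing at init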
|
||||
|
||||
|
||||
|
@ -72,10 +80,13 @@ class ConditioningEncoder(nn.Module):
|
|||
self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
|
||||
# nn.Embedding
|
||||
self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim)
|
||||
self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
|
||||
# Reduces the relative influence of this embedding from the start.
|
||||
self.resolution_embedding.weight.data.mul(.1)
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
|
||||
attn.append(AttentionBlock(hidden_dim, num_attn_heads,
|
||||
do_checkpoint=do_checkpointing))
|
||||
attn.append(ResBlock(hidden_dim, dims=1,
|
||||
checkpointing_enabled=do_checkpointing))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.out = ml.Linear(hidden_dim, out_dim, bias=False)
|
||||
self.dim = hidden_dim
|
||||
|
@ -91,6 +102,7 @@ class TransformerDiffusion(nn.Module):
|
|||
"""
|
||||
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
resolution_steps=8,
|
||||
|
@ -108,7 +120,8 @@ class TransformerDiffusion(nn.Module):
|
|||
dropout=0,
|
||||
use_fp16=False,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -135,17 +148,21 @@ class TransformerDiffusion(nn.Module):
|
|||
)
|
||||
# nn.Embedding
|
||||
self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim)
|
||||
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
|
||||
self.conditioning_encoder = ConditioningEncoder(
|
||||
in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, cond_proj_dim))
|
||||
|
||||
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
|
||||
self.inp_block = conv_nd(
|
||||
1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
|
||||
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim,
|
||||
num_heads, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(conv_nd(1, model_channels,
|
||||
out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
self.debug_codes = {}
|
||||
|
@ -163,17 +180,20 @@ class TransformerDiffusion(nn.Module):
|
|||
s_diff = s.shape[-1] - self.max_window
|
||||
if s_diff > 1:
|
||||
start = randrange(0, s_diff)
|
||||
s = s[:,:,start:start+self.max_window]
|
||||
s = s[:, :, start:start+self.max_window]
|
||||
s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
|
||||
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
|
||||
s_prior = F.interpolate(s_prior, size=(
|
||||
s.shape[-1],), mode='linear', align_corners=True)
|
||||
|
||||
# Now diffuse the prior randomly between the x timestep and 0.
|
||||
adv = torch.rand_like(ts.float())
|
||||
t_prior = (adv * ts).long() - 1
|
||||
# The t_prior-1 below is an important detail: it forces s_prior to be unmodified for ts=0. It also means that t_prior is not on the same timescale as ts (instead it is shifted by 1).
|
||||
s_prior_diffused = diffuser.q_sample(s_prior, t_prior-1, torch.randn_like(s_prior), allow_negatives=True)
|
||||
s_prior_diffused = diffuser.q_sample(
|
||||
s_prior, t_prior-1, torch.randn_like(s_prior), allow_negatives=True)
|
||||
|
||||
self.preprocessed = (s_prior_diffused, t_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
|
||||
self.preprocessed = (s_prior_diffused, t_prior, torch.tensor(
|
||||
[resolution] * x.shape[0], dtype=torch.long, device=x.device))
|
||||
return s
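The block above builds the low-resolution prior: downsample the target 4x, stretch it back as a blurry reference, then partially re-noise it to a random timestep below the main one (with the extra -1 shift noted in the comment so a fully-resolved prior at ts=0 stays untouched). A simplified sketch; `toy_q_sample` is a stand-in for the real diffuser's forward process, not its implementation:

import torch
import torch.nn.functional as F

def make_diffused_prior(s, ts, q_sample):
    # Blurry low-resolution view of the target, stretched back to full length.
    s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
    s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
    # Re-noise it to a random timestep no higher than the main timestep.
    t_prior = (torch.rand_like(ts.float()) * ts).long()
    return q_sample(s_prior, t_prior, torch.randn_like(s_prior)), t_prior

def toy_q_sample(x0, t, noise):          # placeholder: more noise at larger t
    return x0 + 0.01 * t.view(-1, 1, 1).float() * noise

s = torch.randn(2, 256, 400)
ts = torch.LongTensor([500, 0])
prior, t_prior = make_diffused_prior(s, ts, toy_q_sample)
print(prior.shape, t_prior)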
|
||||
|
||||
def forward(self, x, timesteps, prior_timesteps=None, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
|
||||
|
@ -202,12 +222,16 @@ class TransformerDiffusion(nn.Module):
|
|||
x_prior, prior_timesteps, resolution = self.preprocessed
|
||||
self.preprocessed = None
|
||||
else:
|
||||
assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}'
|
||||
assert x.shape[-1] > x_prior.shape[-1] * \
|
||||
3.9, f'{x.shape} {x_prior.shape}'
|
||||
if prior_timesteps is None:
|
||||
# This is taken to mean a fully diffused prior was given.
|
||||
prior_timesteps = torch.tensor([0], device=x.device) # Assuming batch_size=1 for inference.
|
||||
x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
|
||||
assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
|
||||
# Assuming batch_size=1 for inference.
|
||||
prior_timesteps = torch.tensor([0], device=x.device)
|
||||
x_prior = F.interpolate(x_prior, size=(
|
||||
x.shape[-1],), mode='linear', align_corners=True)
|
||||
assert torch.all(timesteps - prior_timesteps >=
|
||||
0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
|
||||
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1)
|
||||
|
@ -218,19 +242,26 @@ class TransformerDiffusion(nn.Module):
|
|||
clen = randrange(MIN_COND_LEN, MAX_COND_LEN)
|
||||
gap = conditioning_input.shape[-1] - clen
|
||||
cstart = randrange(0, gap)
|
||||
conditioning_input = conditioning_input[:,:,cstart:cstart+clen]
|
||||
code_emb = self.conditioning_encoder(conditioning_input, resolution)
|
||||
conditioning_input = conditioning_input[:,
|
||||
:, cstart:cstart+clen]
|
||||
code_emb = self.conditioning_encoder(
|
||||
conditioning_input, resolution)
|
||||
|
||||
# Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((x.shape[0], 1), device=x.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1), code_emb)
|
||||
unconditioned_batches = torch.rand(
|
||||
(x.shape[0], 1), device=x.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(
|
||||
code_emb.shape[0], 1), code_emb)
|
||||
|
||||
with torch.autocast(x.device.type, enabled=self.enable_fp16):
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
|
||||
prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim))
|
||||
time_emb = self.time_embed(
|
||||
timestep_embedding(timesteps, self.time_embed_dim))
|
||||
prior_time_emb = self.prior_time_embed(
|
||||
timestep_embedding(prior_timesteps, self.time_embed_dim))
|
||||
res_emb = self.resolution_embed(resolution)
|
||||
blk_emb = torch.cat([time_emb, prior_time_emb, res_emb, code_emb], dim=1)
|
||||
blk_emb = torch.cat(
|
||||
[time_emb, prior_time_emb, res_emb, code_emb], dim=1)
|
||||
|
||||
h = torch.cat([x, x_prior], dim=1)
|
||||
h = self.inp_block(h)
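The unconditioned_percentage branch above implements the training side of classifier-free guidance: a random subset of batch elements has its conditioning swapped for a single learned embedding, so the same network can later be run with and without conditioning and the two predictions blended. A minimal sketch of just the swap (names illustrative):

import torch

batch, dim = 8, 16
unconditioned_percentage = 0.1
code_emb = torch.randn(batch, dim)              # per-element conditioning embeddings
unconditioned_embedding = torch.randn(1, dim)   # learned, shared across the batch

drop = torch.rand(batch, 1) < unconditioned_percentage
code_emb = torch.where(drop, unconditioned_embedding.expand(batch, -1), code_emb)

At sampling time, the conditioning_free path above reuses this same learned embedding for the unconditioned prediction.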
|
||||
|
@ -250,13 +281,16 @@ class TransformerDiffusion(nn.Module):
|
|||
return out
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
|
||||
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
|
||||
attn1 = list(itertools.chain.from_iterable(
|
||||
[lyr.block1.attn.parameters() for lyr in self.layers]))
|
||||
attn2 = list(itertools.chain.from_iterable(
|
||||
[lyr.block2.attn.parameters() for lyr in self.layers]))
|
||||
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
|
||||
[lyr.block1.ff2.parameters() for lyr in self.layers]))
|
||||
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
|
||||
[lyr.block2.ff2.parameters() for lyr in self.layers]))
|
||||
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
|
||||
blkout_layers = list(itertools.chain.from_iterable(
|
||||
[lyr.out.parameters() for lyr in self.layers]))
|
||||
groups = {
|
||||
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
|
||||
'blk1_attention_layers': attn1,
|
||||
|
@ -276,7 +310,8 @@ class TransformerDiffusion(nn.Module):
|
|||
return groups
|
||||
|
||||
def before_step(self, step):
|
||||
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
|
||||
scaled_grad_parameters = list(itertools.chain.from_iterable(
|
||||
[lyr.out.parameters() for lyr in self.layers]))
|
||||
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
|
||||
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
|
||||
# directly fiddling with the gradients.
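A sketch of what that kind of pre-step gradient rescaling looks like in isolation (the 0.01 factor and the toy model are assumptions, not values from the repo):

import itertools
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
scaled_grad_parameters = list(itertools.chain.from_iterable(
    layer.parameters() for layer in model))

loss = model(torch.randn(4, 8)).sum()
loss.backward()
with torch.no_grad():
    for p in scaled_grad_parameters:
        if p.grad is not None:
            p.grad.mul_(0.01)   # shrink the outsized gradients in place
# The optimizer step runs after this hook, so it sees the rescaled gradients.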
|
||||
|
@ -291,14 +326,13 @@ def register_transformer_diffusion13(opt_net, opt):
|
|||
|
||||
|
||||
def test_tfd():
|
||||
from models.diffusion.respace import SpacedDiffusion
|
||||
from models.diffusion.respace import space_timesteps
|
||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||
from models.diffusion.respace import SpacedDiffusion, space_timesteps
|
||||
diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse',
|
||||
betas=get_named_beta_schedule('linear', 4000))
|
||||
clip = torch.randn(2,256,10336)
|
||||
cond = torch.randn(2,256,10336)
|
||||
model_var_type='learned_range', loss_type='mse',
|
||||
betas=get_named_beta_schedule('linear', 4000))
|
||||
clip = torch.randn(2, 256, 10336)
|
||||
cond = torch.randn(2, 256, 10336)
|
||||
ts = torch.LongTensor([0, 0])
|
||||
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
||||
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
|
||||
|
@ -316,5 +350,5 @@ def remove_conditioning(sd_path):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
|
||||
# remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
|
||||
test_tfd()
@ -4,18 +4,21 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.arch_util import TimestepEmbedSequential
|
||||
from models.audio.music.encoders import ResEncoder16x
|
||||
from models.audio.music.transformer_diffusion13 import ConcatAttentionBlock
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, print_network
|
||||
from dlas.models.arch_util import TimestepEmbedSequential
|
||||
from dlas.models.audio.music.encoders import ResEncoder16x
|
||||
from dlas.models.audio.music.transformer_diffusion13 import \
|
||||
ConcatAttentionBlock
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint, print_network
|
||||
|
||||
|
||||
class TransformerDiffusion(nn.Module):
|
||||
"""
|
||||
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
time_embed_dim=256,
|
||||
|
@ -27,11 +30,13 @@ class TransformerDiffusion(nn.Module):
|
|||
out_channels=512, # mean and variance
|
||||
num_heads=4,
|
||||
dropout=0,
|
||||
use_corner_alignment=False, # This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE.
|
||||
# This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE.
|
||||
use_corner_alignment=False,
|
||||
use_fp16=False,
|
||||
new_code_expansion=False,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
# Parameters for re-training head
|
||||
freeze_except_code_converters=False,
|
||||
):
|
||||
|
@ -55,7 +60,8 @@ class TransformerDiffusion(nn.Module):
|
|||
)
|
||||
|
||||
self.input_converter = nn.Conv1d(input_vec_dim, model_channels, 1)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, model_channels, 1))
|
||||
self.intg = nn.Conv1d(model_channels*2, model_channels, 1)
|
||||
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim//4,
|
||||
num_heads, dropout) for _ in range(num_layers)])
|
||||
|
@ -63,7 +69,8 @@ class TransformerDiffusion(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(conv_nd(1, model_channels,
|
||||
out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
if freeze_except_code_converters:
|
||||
|
@ -76,13 +83,16 @@ class TransformerDiffusion(nn.Module):
|
|||
p.requires_grad = True
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
|
||||
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
|
||||
attn1 = list(itertools.chain.from_iterable(
|
||||
[lyr.block1.attn.parameters() for lyr in self.layers]))
|
||||
attn2 = list(itertools.chain.from_iterable(
|
||||
[lyr.block2.attn.parameters() for lyr in self.layers]))
|
||||
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
|
||||
[lyr.block1.ff2.parameters() for lyr in self.layers]))
|
||||
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
|
||||
[lyr.block2.ff2.parameters() for lyr in self.layers]))
|
||||
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
|
||||
blkout_layers = list(itertools.chain.from_iterable(
|
||||
[lyr.out.parameters() for lyr in self.layers]))
|
||||
groups = {
|
||||
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
|
||||
'blk1_attention_layers': attn1,
|
||||
|
@ -101,7 +111,8 @@ class TransformerDiffusion(nn.Module):
|
|||
|
||||
def forward(self, x, timesteps, prior=None, conditioning_free=False):
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
code_emb = self.unconditioned_embedding.repeat(
|
||||
x.shape[0], 1, x.shape[-1])
|
||||
else:
|
||||
code_emb = self.input_converter(prior)
|
||||
|
||||
|
@ -112,10 +123,12 @@ class TransformerDiffusion(nn.Module):
|
|||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
|
||||
code_emb)
|
||||
|
||||
code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
|
||||
code_emb = F.interpolate(
|
||||
code_emb, size=x.shape[-1], mode='nearest')
|
||||
|
||||
with torch.autocast(x.device.type, enabled=self.enable_fp16):
|
||||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
|
||||
blk_emb = self.time_embed(
|
||||
timestep_embedding(timesteps, self.time_embed_dim))
|
||||
x = self.inp_block(x)
|
||||
|
||||
x = self.intg(torch.cat([x, code_emb], dim=1))
|
||||
|
@ -141,7 +154,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
|
|||
self.internal_step = 0
|
||||
self.freeze_encoder_until = freeze_encoder_until
|
||||
self.diff = TransformerDiffusion(**kwargs)
|
||||
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
|
||||
self.encoder = ResEncoder16x(
|
||||
256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
|
||||
|
||||
def forward(self, x, timesteps, truth_mel, conditioning_free=False, cheater=None):
|
||||
unused_parameters = []
|
||||
|
@ -158,7 +172,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
|
|||
for p in unused_parameters:
|
||||
proj = proj + p.mean() * 0
|
||||
|
||||
diff = self.diff(x, timesteps, prior=proj, conditioning_free=conditioning_free)
|
||||
diff = self.diff(x, timesteps, prior=proj,
|
||||
conditioning_free=conditioning_free)
|
||||
return diff
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
|
@ -172,7 +187,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
|
|||
|
||||
def before_step(self, step):
|
||||
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
|
||||
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
|
||||
list(itertools.chain.from_iterable(
|
||||
[lyr.prenorm.parameters() for lyr in self.diff.layers]))
|
||||
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
|
||||
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
|
||||
# directly fiddling with the gradients.
|
||||
|
@ -196,10 +212,10 @@ def register_transformer_diffusion_14_with_cheater_latent(opt_net, opt):
|
|||
|
||||
|
||||
def test_tfd():
|
||||
clip = torch.randn(2,256,400)
|
||||
clip = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
||||
num_heads=3, input_vec_dim=256, num_layers=12, dropout=.1)
|
||||
num_heads=3, input_vec_dim=256, num_layers=12, dropout=.1)
|
||||
model(clip, ts, clip)
|
||||
|
||||
|
||||
|
@ -209,14 +225,14 @@ def test_cheater_model():
|
|||
|
||||
# For music:
|
||||
model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
|
||||
model_channels=1024, contraction_dim=512, num_heads=8,
|
||||
input_vec_dim=256, num_layers=16,
|
||||
dropout=.1, new_code_expansion=True,
|
||||
)
|
||||
#diff_weights = torch.load('extracted_diff.pth')
|
||||
#model.diff.load_state_dict(diff_weights, strict=False)
|
||||
#model.encoder.load_state_dict(torch.load('../experiments/music_cheater_encoder_256.pth', map_location=torch.device('cpu')), strict=True)
|
||||
#torch.save(model.state_dict(), 'sample.pth')
|
||||
model_channels=1024, contraction_dim=512, num_heads=8,
|
||||
input_vec_dim=256, num_layers=16,
|
||||
dropout=.1, new_code_expansion=True,
|
||||
)
|
||||
# diff_weights = torch.load('extracted_diff.pth')
|
||||
# model.diff.load_state_dict(diff_weights, strict=False)
|
||||
# model.encoder.load_state_dict(torch.load('../experiments/music_cheater_encoder_256.pth', map_location=torch.device('cpu')), strict=True)
|
||||
# torch.save(model.state_dict(), 'sample.pth')
|
||||
|
||||
print_network(model)
|
||||
o = model(clip, ts, clip)
|
||||
|
@ -234,7 +250,8 @@ def extract_cheater_encoder(in_f, out_f):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#test_local_attention_mask()
|
||||
extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth', 'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth')
|
||||
#test_cheater_model()
|
||||
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
|
||||
# test_local_attention_mask()
|
||||
extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth',
|
||||
'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth')
|
||||
# test_cheater_model()
|
||||
# extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
@ -1,29 +1,23 @@
|
|||
from abc import abstractmethod
|
||||
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision # For debugging, not actually used.
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.audio.music.gpt_music import GptMusicLower
|
||||
from models.audio.music.music_quantizer import MusicQuantizer
|
||||
from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32
|
||||
from models.diffusion.nn import (
|
||||
conv_nd,
|
||||
linear,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from models.lucidrains.x_transformers import Encoder
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, print_network, ceil_multiple
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.audio.music.gpt_music import GptMusicLower
|
||||
from dlas.models.audio.music.music_quantizer import MusicQuantizer
|
||||
from dlas.models.diffusion.fp16_util import (convert_module_to_f16,
|
||||
convert_module_to_f32)
|
||||
from dlas.models.diffusion.nn import (avg_pool_nd, conv_nd, linear,
|
||||
normalization, timestep_embedding,
|
||||
zero_module)
|
||||
from dlas.models.lucidrains.x_transformers import Encoder
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import ceil_multiple, checkpoint, print_network
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
|
@ -80,7 +74,8 @@ class Upsample(nn.Module):
|
|||
if dims == 1:
|
||||
ksize = 5
|
||||
pad = 2
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)
|
||||
self.conv = conv_nd(dims, self.channels,
|
||||
self.out_channels, ksize, padding=pad)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
|
@ -123,7 +118,7 @@ class Downsample(nn.Module):
|
|||
elif dims == 2:
|
||||
stride = 2
|
||||
else:
|
||||
stride = (1,2,2)
|
||||
stride = (1, 2, 2)
|
||||
if factor is not None:
|
||||
stride = factor
|
||||
if use_conv:
|
||||
|
@ -180,7 +175,8 @@ class ResBlock(TimestepBlock):
|
|||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
|
||||
conv_nd(dims, channels, self.out_channels,
|
||||
kernel_size, padding=padding),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
@ -206,7 +202,8 @@ class ResBlock(TimestepBlock):
|
|||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, self.out_channels, self.out_channels,
|
||||
kernel_size, padding=padding)
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -217,7 +214,8 @@ class ResBlock(TimestepBlock):
|
|||
dims, channels, self.out_channels, kernel_size, padding=padding
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
|
@ -346,13 +344,15 @@ class QKVAttentionLegacy(nn.Module):
|
|||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3,
|
||||
length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts", q * scale, k * scale
|
||||
) # More stable with f16 than dividing afterwards
|
||||
if rel_pos is not None:
|
||||
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
|
||||
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
|
||||
bs * self.n_heads, weight.shape[-2], weight.shape[-1])
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
if mask is not None:
|
||||
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
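On the scale factor used above: multiplying q and k each by 1/sqrt(sqrt(ch)) before the einsum is mathematically the same as dividing the attention logits by sqrt(ch) afterwards, but it keeps the intermediate products smaller, which is what makes it friendlier to fp16. A quick check:

import math
import torch

bs_heads, ch, length = 4, 64, 128
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
pre_scaled = torch.einsum("bct,bcs->bts", q * scale, k * scale)
post_scaled = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
print(torch.allclose(pre_scaled, post_scaled, atol=1e-5))   # True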
|
||||
|
@ -400,7 +400,8 @@ class QKVAttention(nn.Module):
|
|||
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
|
||||
weight = weight * mask
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
a = th.einsum("bts,bcs->bct", weight,
|
||||
v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
|
@ -460,7 +461,8 @@ class UNetMusicModel(nn.Module):
|
|||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_raw_y_as_embedding=False,
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -493,44 +495,48 @@ class UNetMusicModel(nn.Module):
|
|||
if self.ar_prior:
|
||||
self.ar_input = ml.Linear(input_vec_dim, model_channels)
|
||||
self.ar_prior_intg = Encoder(
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
heads=num_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
heads=num_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
else:
|
||||
self.input_converter = ml.Linear(input_vec_dim, model_channels)
|
||||
self.code_converter = Encoder(
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
heads=num_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
||||
self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
heads=num_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, model_channels))
|
||||
self.x_processor = conv_nd(
|
||||
dims, in_channels, model_channels, 3, padding=1)
|
||||
|
||||
if self.num_classes is not None:
|
||||
# nn.Embedding
|
||||
self.label_emb = ml.Embedding(num_classes, time_embed_dim)
|
||||
self.use_raw_y_as_embedding = use_raw_y_as_embedding
|
||||
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
|
||||
# These are mutually-exclusive.
|
||||
assert not ((self.num_classes is not None) and use_raw_y_as_embedding)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, model_channels*2, model_channels, 1, padding=0)
|
||||
conv_nd(dims, model_channels*2,
|
||||
model_channels, 1, padding=0)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -657,7 +663,8 @@ class UNetMusicModel(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(conv_nd(dims, model_channels,
|
||||
out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, y, conditioning_free=False):
|
||||
|
@ -666,27 +673,35 @@ class UNetMusicModel(nn.Module):
|
|||
if cm != 0:
|
||||
pc = (cm - x.shape[-1]) / x.shape[-1]
|
||||
x = F.pad(x, (0, cm - x.shape[-1]))
|
||||
y = F.pad(y.permute(0,2,1), (0, int(pc * y.shape[-1]))).permute(0,2,1)
|
||||
y = F.pad(y.permute(0, 2, 1), (0, int(
|
||||
pc * y.shape[-1]))).permute(0, 2, 1)
|
||||
|
||||
unused_params = []
|
||||
hs = []
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
emb = self.time_embed(timestep_embedding(
|
||||
timesteps, self.model_channels))
|
||||
|
||||
if conditioning_free:
|
||||
expanded_code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1).permute(0,2,1)
|
||||
expanded_code_emb = self.unconditioned_embedding.repeat(
|
||||
x.shape[0], x.shape[-1], 1).permute(0, 2, 1)
|
||||
if self.ar_prior:
|
||||
unused_params.extend(list(self.ar_input.parameters()) + list(self.ar_prior_intg.parameters()))
|
||||
unused_params.extend(
|
||||
list(self.ar_input.parameters()) + list(self.ar_prior_intg.parameters()))
|
||||
else:
|
||||
unused_params.extend(list(self.input_converter.parameters()) + list(self.code_converter.parameters()))
|
||||
unused_params.extend(
|
||||
list(self.input_converter.parameters()) + list(self.code_converter.parameters()))
|
||||
else:
|
||||
code_emb = self.ar_input(y) if self.ar_prior else self.input_converter(y)
|
||||
code_emb = self.ar_input(
|
||||
y) if self.ar_prior else self.input_converter(y)
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(y.shape[0], 1, 1),
|
||||
code_emb)
|
||||
code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=x.shape[-1], mode='nearest')
|
||||
code_emb = self.ar_prior_intg(
|
||||
code_emb) if self.ar_prior else self.code_converter(code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(
|
||||
0, 2, 1), size=x.shape[-1], mode='nearest')
|
||||
|
||||
h = x.type(self.dtype)
|
||||
expanded_code_emb = expanded_code_emb.type(self.dtype)
|
||||
|
@ -720,7 +735,8 @@ class UNetMusicModelWithQuantizer(nn.Module):
|
|||
self.internal_step = 0
|
||||
self.freeze_quantizer_until = freeze_quantizer_until
|
||||
self.diff = UNetMusicModel(**kwargs)
|
||||
self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=512, codebook_groups=2)
|
||||
self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[
|
||||
1024, 1024, 512], codevector_dim=1024, codebook_size=512, codebook_groups=2)
|
||||
self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
|
||||
del self.m2v.up
|
||||
|
||||
|
@ -728,15 +744,16 @@ class UNetMusicModelWithQuantizer(nn.Module):
|
|||
self.internal_step = step
|
||||
qstep = max(0, self.internal_step - self.freeze_quantizer_until)
|
||||
self.m2v.quantizer.temperature = max(
|
||||
self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**qstep,
|
||||
self.m2v.min_gumbel_temperature,
|
||||
)
|
||||
self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**qstep,
|
||||
self.m2v.min_gumbel_temperature,
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
|
||||
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
||||
with torch.set_grad_enabled(quant_grad_enabled):
|
||||
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
|
||||
proj = proj.permute(0,2,1)
|
||||
proj, diversity_loss = self.m2v(
|
||||
truth_mel, return_decoder_latent=True)
|
||||
proj = proj.permute(0, 2, 1)
|
||||
|
||||
# Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
|
||||
if not quant_grad_enabled:
|
||||
|
@ -746,7 +763,8 @@ class UNetMusicModelWithQuantizer(nn.Module):
|
|||
proj = proj + unused
|
||||
diversity_loss = diversity_loss * 0
|
||||
|
||||
diff = self.diff(x, timesteps, proj, conditioning_free=conditioning_free)
|
||||
diff = self.diff(x, timesteps, proj,
|
||||
conditioning_free=conditioning_free)
|
||||
if disable_diversity:
|
||||
return diff
|
||||
return diff, diversity_loss
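The `unused` term folded into `proj` above (see the DDP comment earlier in this file) exists because DistributedDataParallel expects every parameter to receive a gradient each step, even while the quantizer is frozen. A standalone sketch of the same trick with a toy module (names illustrative):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(8, 8)
        self.sometimes_unused = nn.Linear(8, 8)

    def forward(self, x, use_extra: bool):
        out = self.used(x)
        if use_extra:
            out = out + self.sometimes_unused(x)
        else:
            # Touch the skipped parameters with a zero-weighted term so DDP still
            # gets a (zero) gradient for them instead of flagging unused parameters.
            out = out + sum(p.mean() for p in self.sometimes_unused.parameters()) * 0
        return out

toy = Toy()
toy(torch.randn(2, 8), use_extra=False).sum().backward()
print(toy.sometimes_unused.weight.grad.abs().sum())   # tensor(0.), but not None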
|
||||
|
@ -790,7 +808,8 @@ class UNetMusicModelARPrior(nn.Module):
|
|||
with torch.no_grad():
|
||||
prior = self.ar(truth_mel, conditioning_input, return_latent=True)
|
||||
|
||||
diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
|
||||
diff = self.diff(x, timesteps, prior,
|
||||
conditioning_free=conditioning_free)
|
||||
return diff
|
||||
|
||||
|
||||
|
@ -798,6 +817,7 @@ class UNetMusicModelARPrior(nn.Module):
|
|||
def register_unet_diffusion_music_codes(opt_net, opt):
|
||||
return UNetMusicModelWithQuantizer(**opt_net['args'])
|
||||
|
||||
|
||||
@register_model
|
||||
def register_unet_diffusion_music_ar_prior(opt_net, opt):
|
||||
return UNetMusicModelARPrior(**opt_net['args'])
|
||||
|
@ -808,20 +828,21 @@ if __name__ == '__main__':
|
|||
cond = torch.randn(2, 256, 300)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = UNetMusicModelARPrior(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=512,
|
||||
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
|
||||
attention_resolutions=(2, 4), channel_mult=(1, 2, 3), dims=1,
|
||||
use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4, freeze_unet=True)
|
||||
print_network(model)
|
||||
model.get_grad_norm_parameter_groups()
|
||||
|
||||
ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
|
||||
ar_weights = torch.load(
|
||||
'D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
|
||||
model.ar.load_state_dict(ar_weights, strict=True)
|
||||
diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\55500_generator_ema.pth')
|
||||
diff_weights = torch.load(
|
||||
'X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\55500_generator_ema.pth')
|
||||
pruned_diff_weights = {}
|
||||
for k,v in diff_weights.items():
|
||||
for k, v in diff_weights.items():
|
||||
if k.startswith('diff.'):
|
||||
pruned_diff_weights[k.replace('diff.', '')] = v
|
||||
model.diff.load_state_dict(pruned_diff_weights, strict=False)
|
||||
torch.save(model.state_dict(), 'sample.pth')
|
||||
|
||||
model(clip, ts, cond, cond)
@ -6,14 +6,19 @@ import torch.nn.functional as F
|
|||
from torch import autocast
|
||||
from x_transformers import Encoder
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
|
||||
Downsample, Upsample, TimestepBlock
|
||||
from models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
|
||||
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from dlas.models.audio.tts.unet_diffusion_tts7 import \
|
||||
CheckpointedXTransformerEncoder
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
|
||||
TimestepBlock,
|
||||
TimestepEmbedSequential,
|
||||
Upsample)
|
||||
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
def is_sequence(t):
|
||||
return t.dtype == torch.long
|
||||
|
@ -44,7 +49,8 @@ class ResBlock(TimestepBlock):
|
|||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
|
||||
conv_nd(dims, channels, self.out_channels,
|
||||
eff_kernel, padding=eff_padding),
|
||||
)
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
|
@ -59,14 +65,16 @@ class ResBlock(TimestepBlock):
|
|||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, self.out_channels, self.out_channels,
|
||||
kernel_size, padding=padding)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
|
@ -95,6 +103,7 @@ class ResBlock(TimestepBlock):
|
|||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class DiffusionWaveformGen(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
@ -138,12 +147,12 @@ class DiffusionWaveformGen(nn.Module):
|
|||
out_channels=2, # mean and variance
|
||||
dropout=0,
|
||||
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
||||
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
||||
channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
||||
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
|
||||
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
|
||||
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
|
||||
token_conditioning_resolutions=(1,16,),
|
||||
attention_resolutions=(512,1024,2048),
|
||||
token_conditioning_resolutions=(1, 16,),
|
||||
attention_resolutions=(512, 1024, 2048),
|
||||
conv_resample=True,
|
||||
dims=1,
|
||||
use_fp16=False,
|
||||
|
@ -154,10 +163,12 @@ class DiffusionWaveformGen(nn.Module):
|
|||
scale_factor=2,
|
||||
time_embed_dim_multiplier=4,
|
||||
freeze_main_net=False,
|
||||
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
|
||||
# Uses kernels with width of 1 in several places rather than 3.
|
||||
efficient_convs=True,
|
||||
use_scale_shift_norm=True,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
# Parameters for super-sampling.
|
||||
super_sampling=False,
|
||||
super_sampling_max_noising_factor=.1,
|
||||
|
@ -168,7 +179,8 @@ class DiffusionWaveformGen(nn.Module):
|
|||
num_heads_upsample = num_heads
|
||||
|
||||
if super_sampling:
|
||||
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
|
||||
# In super-sampling mode, the LR input is concatenated directly onto the input.
|
||||
in_channels *= 2
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
|
@ -218,22 +230,31 @@ class DiffusionWaveformGen(nn.Module):
|
|||
rotary_pos_emb=True,
|
||||
)
|
||||
))
|
||||
self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
|
||||
self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
|
||||
self.latent_converter = nn.Conv1d(
|
||||
in_latent_channels, conditioning_dim, 1)
|
||||
self.aligned_latent_padding_embedding = nn.Parameter(
|
||||
torch.randn(1, in_latent_channels, 1))
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, conditioning_dim, 1))
|
||||
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
AttentionBlock(conditioning_dim, num_heads=num_heads,
|
||||
num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
AttentionBlock(conditioning_dim, num_heads=num_heads,
|
||||
num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
)
|
||||
self.conditioning_expansion = conditioning_expansion
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, in_channels, model_channels,
|
||||
kernel_size, padding=padding)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -344,7 +365,8 @@ class DiffusionWaveformGen(nn.Module):
|
|||
if level and i == num_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
|
||||
Upsample(ch, conv_resample, dims=dims,
|
||||
out_channels=out_ch, factor=scale_factor)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
|
@ -353,7 +375,8 @@ class DiffusionWaveformGen(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels,
|
||||
kernel_size, padding=padding)),
|
||||
)
|
||||
|
||||
if self.freeze_main_net:
|
||||
|
@@ -385,13 +408,14 @@ class DiffusionWaveformGen(nn.Module):
        cm = ceil_multiple(x.shape[-1], self.alignment_size)
        if cm != 0:
            pc = (cm-x.shape[-1])/x.shape[-1]
            x = F.pad(x, (0,cm-x.shape[-1]))
            x = F.pad(x, (0, cm-x.shape[-1]))
            # Also fix aligned_latent, which is aligned to x.
            if self.is_latent(aligned_conditioning):
                aligned_conditioning = torch.cat([aligned_conditioning,
                                                  self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
            else:
                aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
                aligned_conditioning = F.pad(
                    aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
        return x, aligned_conditioning

    def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
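The alignment-fixing hunk above pads the waveform out to a multiple of the model's alignment size and pads the conditioning by the same proportion, so the two stay time-aligned through the UNet's down/up-sampling path. A self-contained sketch of the idea; the local ceil_multiple below is an illustrative re-implementation of the helper imported from dlas.scripts.audio.gen.use_diffuse_tts, not its actual source:

import torch
import torch.nn.functional as F

def ceil_multiple(x, base):
    # Smallest multiple of `base` that is >= x (assumed behaviour of the imported helper).
    return ((x + base - 1) // base) * base

alignment_size = 256
x = torch.randn(2, 1, 1000)        # waveform batch
cond = torch.randn(2, 100, 125)    # conditioning aligned to x

cm = ceil_multiple(x.shape[-1], alignment_size)    # 1024
if cm != x.shape[-1]:
    pc = (cm - x.shape[-1]) / x.shape[-1]          # proportional amount of padding
    x = F.pad(x, (0, cm - x.shape[-1]))
    cond = F.pad(cond, (0, int(pc * cond.shape[-1])))
print(x.shape, cond.shape)  # torch.Size([2, 1, 1024]) torch.Size([2, 100, 128])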
@@ -416,11 +440,13 @@ class DiffusionWaveformGen(nn.Module):
        with autocast(x.device.type, enabled=self.enable_fp16):

            hs = []
            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
            time_emb = self.time_embed(
                timestep_embedding(timesteps, self.model_channels))

            # Note: this block does not need to repeated on inference, since it is not timestep-dependent.
            if conditioning_free:
                code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
                code_emb = self.unconditioned_embedding.repeat(
                    x.shape[0], 1, 1)
            else:
                if self.is_latent(aligned_conditioning):
                    code_emb = self.latent_converter(aligned_conditioning)
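The conditioning_free branch above is the inference half of the classifier-free mechanism referenced by the unconditioned_percentage parameter: when the flag is set, the learned unconditioned_embedding stands in for the real conditioning, and during training a random fraction of each batch is given the same treatment. A rough sketch of that swap, with made-up shapes and module names rather than the ones in this file:

import torch
import torch.nn as nn

class TinyConditionedModel(nn.Module):
    def __init__(self, cond_dim=64, unconditioned_percentage=0.1):
        super().__init__()
        self.unconditioned_percentage = unconditioned_percentage
        self.unconditioned_embedding = nn.Parameter(torch.randn(1, cond_dim, 1))

    def build_code_emb(self, cond, conditioning_free=False):
        b = cond.shape[0]
        if conditioning_free:
            # Inference: replace conditioning entirely for the "unconditional" forward pass.
            return self.unconditioned_embedding.repeat(b, 1, 1).expand(-1, -1, cond.shape[-1])
        if self.training:
            # Training: randomly swap some samples' conditioning for the learned null embedding.
            drop = torch.rand(b, 1, 1, device=cond.device) < self.unconditioned_percentage
            cond = torch.where(drop, self.unconditioned_embedding.expand_as(cond), cond)
        return cond

m = TinyConditionedModel()
cond = torch.randn(4, 64, 50)
uncond = m.build_code_emb(cond, conditioning_free=True)  # (4, 64, 50), all from the null embedding
mixed = m.build_code_emb(cond)                           # training-style random replacement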
@ -428,15 +454,18 @@ class DiffusionWaveformGen(nn.Module):
|
|||
code_emb = self.mel_converter(aligned_conditioning)
|
||||
|
||||
# Everything after this comment is timestep dependent.
|
||||
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
code_emb = torch.repeat_interleave(
|
||||
code_emb, self.conditioning_expansion, dim=-1)
|
||||
code_emb = self.conditioning_timestep_integrator(
|
||||
code_emb, time_emb)
|
||||
|
||||
first = True
|
||||
time_emb = time_emb.float()
|
||||
h = x
|
||||
for k, module in enumerate(self.input_blocks):
|
||||
if isinstance(module, nn.Conv1d):
|
||||
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
||||
h_tok = F.interpolate(module(code_emb), size=(
|
||||
h.shape[-1]), mode='nearest')
|
||||
h = h + h_tok
|
||||
else:
|
||||
with autocast(x.device.type, enabled=self.enable_fp16 and not first):
|
||||
|
@ -455,7 +484,8 @@ class DiffusionWaveformGen(nn.Module):
|
|||
|
||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||
extraneous_addition = 0
|
||||
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
|
||||
params = [self.aligned_latent_padding_embedding,
|
||||
self.unconditioned_embedding] + list(self.latent_converter.parameters())
|
||||
for p in params:
|
||||
extraneous_addition = extraneous_addition + p.mean()
|
||||
out = out + extraneous_addition * 0
|
||||
|
@ -470,13 +500,13 @@ def register_unet_diffusion_waveform_gen(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 1, 32868)
|
||||
aligned_latent = torch.randn(2,388,1024)
|
||||
aligned_sequence = torch.randn(2,120,220)
|
||||
aligned_latent = torch.randn(2, 388, 1024)
|
||||
aligned_sequence = torch.randn(2, 120, 220)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = DiffusionWaveformGen(128,
|
||||
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
||||
channel_mult=[1, 1.5, 2, 3, 4, 6, 8],
|
||||
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
||||
token_conditioning_resolutions=[1,4,16,64],
|
||||
token_conditioning_resolutions=[1, 4, 16, 64],
|
||||
attention_resolutions=[],
|
||||
num_heads=8,
|
||||
kernel_size=3,
|
||||
|
@ -488,4 +518,3 @@ if __name__ == '__main__':
|
|||
o = model(clip, ts, aligned_latent)
|
||||
# Test with sequence aligned conditioning
|
||||
o = model(clip, ts, aligned_sequence)
@ -4,12 +4,14 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
|
||||
Downsample, Upsample, TimestepBlock
|
||||
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, print_network
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (Downsample, TimestepBlock,
|
||||
TimestepEmbedSequential,
|
||||
Upsample)
|
||||
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint, print_network
|
||||
|
||||
|
||||
def is_sequence(t):
|
||||
|
@ -368,4 +370,3 @@ if __name__ == '__main__':
|
|||
# Test with sequence aligned conditioning
|
||||
o = model(clip, ts, aligned_sequence)
|
||||
print_network(model)
|
||||
|
||||
|
|
|
@ -3,12 +3,14 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
|
||||
Downsample, Upsample, TimestepBlock
|
||||
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (Downsample, TimestepBlock,
|
||||
TimestepEmbedSequential,
|
||||
Upsample)
|
||||
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
def is_sequence(t):
|
||||
|
@ -40,7 +42,8 @@ class ResBlock(TimestepBlock):
|
|||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
|
||||
conv_nd(dims, channels, self.out_channels,
|
||||
eff_kernel, padding=eff_padding),
|
||||
)
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
|
@ -55,14 +58,16 @@ class ResBlock(TimestepBlock):
|
|||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, self.out_channels, self.out_channels,
|
||||
kernel_size, padding=padding)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
|
@ -91,6 +96,7 @@ class ResBlock(TimestepBlock):
|
|||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class DiffusionWaveformGen(nn.Module):
|
||||
"""
|
||||
The full UNet model with residual blocks and timestep embedding.
|
||||
|
@ -122,11 +128,11 @@ class DiffusionWaveformGen(nn.Module):
|
|||
out_channels=2, # mean and variance
|
||||
dropout=0,
|
||||
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
||||
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
||||
channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
||||
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
|
||||
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
|
||||
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
|
||||
token_conditioning_resolutions=(1,16,),
|
||||
token_conditioning_resolutions=(1, 16,),
|
||||
conv_resample=True,
|
||||
dims=1,
|
||||
use_fp16=False,
|
||||
|
@ -134,10 +140,12 @@ class DiffusionWaveformGen(nn.Module):
|
|||
scale_factor=2,
|
||||
time_embed_dim_multiplier=4,
|
||||
freeze_main_net=False,
|
||||
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
|
||||
# Uses kernels with width of 1 in several places rather than 3.
|
||||
efficient_convs=True,
|
||||
use_scale_shift_norm=True,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
# Parameters for super-sampling.
|
||||
super_sampling=False,
|
||||
super_sampling_max_noising_factor=.1,
|
||||
|
@ -145,7 +153,8 @@ class DiffusionWaveformGen(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
if super_sampling:
|
||||
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
|
||||
# In super-sampling mode, the LR input is concatenated directly onto the input.
|
||||
in_channels *= 2
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
|
@ -175,19 +184,25 @@ class DiffusionWaveformGen(nn.Module):
|
|||
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
||||
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
||||
# transformer network.
|
||||
self.mel_converter = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
|
||||
self.mel_converter = nn.Conv1d(
|
||||
in_mel_channels, conditioning_dim, 3, padding=1)
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, conditioning_dim, 1))
|
||||
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
)
|
||||
self.conditioning_expansion = conditioning_expansion
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, in_channels, model_channels,
|
||||
kernel_size, padding=padding)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -268,7 +283,8 @@ class DiffusionWaveformGen(nn.Module):
|
|||
if level and i == num_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
|
||||
Upsample(ch, conv_resample, dims=dims,
|
||||
out_channels=out_ch, factor=scale_factor)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
|
@ -277,7 +293,8 @@ class DiffusionWaveformGen(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels,
|
||||
kernel_size, padding=padding)),
|
||||
)
|
||||
|
||||
if self.freeze_main_net:
|
||||
|
@ -306,8 +323,9 @@ class DiffusionWaveformGen(nn.Module):
|
|||
cm = ceil_multiple(x.shape[-1], self.alignment_size)
|
||||
if cm != 0:
|
||||
pc = (cm-x.shape[-1])/x.shape[-1]
|
||||
x = F.pad(x, (0,cm-x.shape[-1]))
|
||||
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
|
||||
x = F.pad(x, (0, cm-x.shape[-1]))
|
||||
aligned_conditioning = F.pad(
|
||||
aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
|
||||
return x, aligned_conditioning
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
|
||||
|
@ -327,24 +345,29 @@ class DiffusionWaveformGen(nn.Module):
|
|||
with autocast(x.device.type, enabled=self.enable_fp16):
|
||||
|
||||
hs = []
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
time_emb = self.time_embed(
|
||||
timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
# Note: this block does not need to repeated on inference, since it is not timestep-dependent.
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
||||
code_emb = self.unconditioned_embedding.repeat(
|
||||
x.shape[0], 1, 1)
|
||||
else:
|
||||
code_emb = self.mel_converter(aligned_conditioning)
|
||||
|
||||
# Everything after this comment is timestep dependent.
|
||||
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
code_emb = torch.repeat_interleave(
|
||||
code_emb, self.conditioning_expansion, dim=-1)
|
||||
code_emb = self.conditioning_timestep_integrator(
|
||||
code_emb, time_emb)
|
||||
|
||||
first = True
|
||||
time_emb = time_emb.float()
|
||||
h = x
|
||||
for k, module in enumerate(self.input_blocks):
|
||||
if isinstance(module, nn.Conv1d):
|
||||
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
||||
h_tok = F.interpolate(module(code_emb), size=(
|
||||
h.shape[-1]), mode='nearest')
|
||||
h = h + h_tok
|
||||
else:
|
||||
with autocast(x.device.type, enabled=self.enable_fp16 and not first):
|
||||
|
@ -378,12 +401,12 @@ def register_unet_diffusion_waveform_gen2(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 1, 32868)
|
||||
aligned_sequence = torch.randn(2,120,220)
|
||||
aligned_sequence = torch.randn(2, 120, 220)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = DiffusionWaveformGen(128,
|
||||
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
||||
channel_mult=[1, 1.5, 2, 3, 4, 6, 8],
|
||||
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
||||
token_conditioning_resolutions=[1,4,16,64],
|
||||
token_conditioning_resolutions=[1, 4, 16, 64],
|
||||
kernel_size=3,
|
||||
scale_factor=2,
|
||||
time_embed_dim_multiplier=4,
|
||||
|
@ -391,4 +414,3 @@ if __name__ == '__main__':
|
|||
efficient_convs=False)
|
||||
# Test with sequence aligned conditioning
|
||||
o = model(clip, ts, aligned_sequence)
@@ -1,12 +1,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2PreTrainedModel, GPT2Config
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import TransformerWrapper, Decoder, Encoder
from trainer.networks import register_model
from dlas.models.arch_util import AttentionBlock
from dlas.models.lucidrains.x_transformers import (Decoder, Encoder,
                                                   TransformerWrapper)
from dlas.trainer.networks import register_model

class InferenceModel(GPT2PreTrainedModel):
|
||||
|
@ -14,6 +15,7 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with
|
||||
this transformer.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__(GPT2Config())
|
||||
self.transformer = model
|
||||
|
@ -83,7 +85,8 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
):
|
||||
assert self.context is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
# Training not supported by this inference model.
|
||||
assert labels is None
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
|
||||
|
@ -115,7 +118,8 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
|
||||
for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
@ -124,6 +128,7 @@ class ResBlock(nn.Module):
|
|||
"""
|
||||
Basic residual convolutional block that uses GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
|
@ -148,11 +153,13 @@ class ConditioningEncoder(nn.Module):
|
|||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2),
|
||||
nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2),
|
||||
nn.Conv1d(embedding_dim//4, embedding_dim //
|
||||
2, kernel_size=3, padding=1, stride=2),
|
||||
ResBlock(embedding_dim//2),
|
||||
nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2))
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(AttentionBlock(embedding_dim,
|
||||
num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
@ -165,50 +172,53 @@ class ConditioningEncoder(nn.Module):
|
|||
class AutoregressiveCodegen(nn.Module):
|
||||
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
|
||||
super().__init__()
|
||||
assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later.
|
||||
# This is the minimum bound to support the context interleaving that happens later.
|
||||
assert depth >= 8
|
||||
|
||||
self.START_TOKEN=8192
|
||||
self.STOP_TOKEN=8193
|
||||
self.START_TOKEN = 8192
|
||||
self.STOP_TOKEN = 8193
|
||||
self.START_TEXT_TOKEN = 255
|
||||
self.STOP_TEXT_TOKEN = 0
|
||||
self.max_text_token_id = num_text_tokens
|
||||
self.max_mel_token_id = num_mel_tokens
|
||||
self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
|
||||
self.mel_embedding = ConditioningEncoder(
|
||||
80, model_dim, do_checkpointing=False)
|
||||
self.encoder = TransformerWrapper(
|
||||
num_tokens=num_text_tokens,
|
||||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers = Encoder(
|
||||
depth=depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=1,
|
||||
rotary_pos_emb=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
self.encoder.norm = nn.Identity() # This layer and the next are unused.
|
||||
num_tokens=num_text_tokens,
|
||||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers=Encoder(
|
||||
depth=depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=1,
|
||||
rotary_pos_emb=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
# This layer and the next are unused.
|
||||
self.encoder.norm = nn.Identity()
|
||||
self.encoder.to_logits = nn.Identity()
|
||||
self.decoder = TransformerWrapper(
|
||||
num_tokens=num_mel_tokens,
|
||||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers=Decoder(
|
||||
depth=depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=1,
|
||||
rotary_pos_emb=True,
|
||||
cross_attend=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
num_tokens=num_mel_tokens,
|
||||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers=Decoder(
|
||||
depth=depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=1,
|
||||
rotary_pos_emb=True,
|
||||
cross_attend=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
|
@ -218,8 +228,10 @@ class AutoregressiveCodegen(nn.Module):
|
|||
}
|
||||
|
||||
def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
|
||||
assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
|
||||
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
|
||||
assert text_codes.max() < self.max_text_token_id and text_codes.min(
|
||||
) >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
|
||||
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min(
|
||||
) >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
|
||||
|
||||
# Format mel_codes with a stop token on the end.
|
||||
mel_lengths = wav_lengths // 1024 + 1
|
||||
|
@ -235,8 +247,8 @@ class AutoregressiveCodegen(nn.Module):
|
|||
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
|
||||
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
|
||||
# Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings.
|
||||
text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
|
||||
text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
|
||||
text_codes = F.pad(text_codes, (1, 0), value=self.START_TEXT_TOKEN)
|
||||
text_codes = F.pad(text_codes, (0, 1), value=self.STOP_TEXT_TOKEN)
|
||||
_, enc_text = self.encoder(text_codes, return_hiddens=True)
|
||||
# Interleave cond_emb into the first few contexts.
|
||||
full_context = enc_text
|
||||
|
@ -245,11 +257,11 @@ class AutoregressiveCodegen(nn.Module):
|
|||
full_context[6] = cond_emb
|
||||
|
||||
# Execute the decoder
|
||||
dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
|
||||
dec_inputs = F.pad(mel_codes, (1, 0), value=self.START_TOKEN)[:, :-1]
|
||||
dec = self.decoder(dec_inputs, full_context=full_context)
|
||||
if not return_loss:
|
||||
return dec
|
||||
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
|
||||
loss_mel = F.cross_entropy(dec.permute(0, 2, 1), mel_codes)
|
||||
return loss_mel
|
||||
|
||||
def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs):
|
||||
|
@ -261,8 +273,8 @@ class AutoregressiveCodegen(nn.Module):
|
|||
for i in range(conditioning_signal.shape[1]):
|
||||
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
|
||||
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
|
||||
text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
|
||||
text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
|
||||
text_codes = F.pad(text_codes, (1, 0), value=self.START_TEXT_TOKEN)
|
||||
text_codes = F.pad(text_codes, (0, 1), value=self.STOP_TEXT_TOKEN)
|
||||
_, enc_text = self.encoder(text_codes, return_hiddens=True)
|
||||
# Interleave cond_emb into the first few contexts.
|
||||
full_context = enc_text
|
||||
|
@ -273,7 +285,7 @@ class AutoregressiveCodegen(nn.Module):
|
|||
|
||||
gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
|
||||
max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=True,
|
||||
**hf_generate_kwargs)
|
||||
**hf_generate_kwargs)
|
||||
return gen.sequences
|
||||
|
||||
|
||||
|
@ -285,8 +297,8 @@ def register_autoregressive_codegen(opt_net, opt):
|
|||
if __name__ == '__main__':
|
||||
codegen = AutoregressiveCodegen(256, 10)
|
||||
torch.save(codegen.state_dict(), 'sample.pth')
|
||||
#codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
|
||||
codegen(torch.randint(0,256, (2,200)),
|
||||
torch.randn(2,80,120),
|
||||
torch.randint(0,8192, (2,350)),
|
||||
torch.tensor([192,350]))
|
||||
# codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
|
||||
codegen(torch.randint(0, 256, (2, 200)),
|
||||
torch.randn(2, 80, 120),
|
||||
torch.randint(0, 8192, (2, 350)),
|
||||
torch.tensor([192, 350]))
@ -1,12 +1,13 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import GPT2PreTrainedModel, GPT2Config
|
||||
from transformers import GPT2Config, GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
|
||||
from models.arch_util import AttentionBlock
|
||||
from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
|
||||
from trainer.networks import register_model
|
||||
from dlas.models.arch_util import AttentionBlock
|
||||
from dlas.models.lucidrains.x_transformers import (Decoder, Encoder,
|
||||
TransformerWrapper)
|
||||
from dlas.trainer.networks import register_model
|
||||
|
||||
|
||||
class InferenceModel(GPT2PreTrainedModel):
|
||||
|
@ -14,6 +15,7 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with
|
||||
this transformer.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__(GPT2Config())
|
||||
self.transformer = model
|
||||
|
@ -83,10 +85,12 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
):
|
||||
assert self.context is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
# Training not supported by this inference model.
|
||||
assert labels is None
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
|
||||
hidden_states = self.transformer.decoder(
|
||||
input_ids, context=self.context, return_embeddings=True)
|
||||
logits = self.transformer.decoder.to_logits(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
|
@ -109,7 +113,8 @@ class InferenceModel(GPT2PreTrainedModel):
|
|||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
|
||||
for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
@ -118,6 +123,7 @@ class ResBlock(nn.Module):
|
|||
"""
|
||||
Basic residual convolutional block that uses GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
|
@ -142,11 +148,13 @@ class ConditioningEncoder(nn.Module):
|
|||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2),
|
||||
nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2),
|
||||
nn.Conv1d(embedding_dim//4, embedding_dim //
|
||||
2, kernel_size=3, padding=1, stride=2),
|
||||
ResBlock(embedding_dim//2),
|
||||
nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2))
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(AttentionBlock(embedding_dim,
|
||||
num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
@ -160,45 +168,46 @@ class AutoregressiveCodegen(nn.Module):
|
|||
def __init__(self, model_dim, encoder_depth, decoder_depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1, ff_mult=1):
|
||||
super().__init__()
|
||||
|
||||
self.START_TOKEN=8192
|
||||
self.STOP_TOKEN=8193
|
||||
self.START_TOKEN = 8192
|
||||
self.STOP_TOKEN = 8193
|
||||
self.max_text_token_id = num_text_tokens
|
||||
self.max_mel_token_id = num_mel_tokens
|
||||
self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
|
||||
self.mel_embedding = ConditioningEncoder(
|
||||
80, model_dim, do_checkpointing=False)
|
||||
self.encoder = TransformerWrapper(
|
||||
num_tokens=num_text_tokens,
|
||||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers = Encoder(
|
||||
depth=encoder_depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=ff_mult,
|
||||
rotary_pos_emb=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
num_tokens=num_text_tokens,
|
||||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers=Encoder(
|
||||
depth=encoder_depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=ff_mult,
|
||||
rotary_pos_emb=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
self.encoder.to_logits = nn.Identity() # This is unused.
|
||||
self.decoder = TransformerWrapper(
|
||||
num_tokens=num_mel_tokens,
|
||||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers=Decoder(
|
||||
depth=decoder_depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=ff_mult,
|
||||
rotary_pos_emb=True,
|
||||
cross_attend=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
num_tokens=num_mel_tokens,
|
||||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers=Decoder(
|
||||
depth=decoder_depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
ff_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
ff_mult=ff_mult,
|
||||
rotary_pos_emb=True,
|
||||
cross_attend=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
|
@ -208,8 +217,10 @@ class AutoregressiveCodegen(nn.Module):
|
|||
}
|
||||
|
||||
def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
|
||||
assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
|
||||
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
|
||||
assert text_codes.max() < self.max_text_token_id and text_codes.min(
|
||||
) >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
|
||||
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min(
|
||||
) >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
|
||||
|
||||
# Format mel_codes with a stop token on the end.
|
||||
mel_lengths = wav_lengths // 1024 + 1
|
||||
|
@ -228,11 +239,11 @@ class AutoregressiveCodegen(nn.Module):
|
|||
context = torch.cat([cond_emb, enc_text], dim=1)
|
||||
|
||||
# Execute the decoder
|
||||
dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
|
||||
dec_inputs = F.pad(mel_codes, (1, 0), value=self.START_TOKEN)[:, :-1]
|
||||
dec = self.decoder(dec_inputs, context=context)
|
||||
if not return_loss:
|
||||
return dec
|
||||
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
|
||||
loss_mel = F.cross_entropy(dec.permute(0, 2, 1), mel_codes)
|
||||
return loss_mel
|
||||
|
||||
def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
|
||||
|
@ -263,8 +274,9 @@ def register_autoregressive_codegen2(opt_net, opt):
|
|||
if __name__ == '__main__':
|
||||
codegen = AutoregressiveCodegen(512, 20)
|
||||
torch.save(codegen.state_dict(), 'sample.pth')
|
||||
codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
|
||||
codegen(torch.randint(0,256, (2,200)),
|
||||
torch.randn(2,80,120),
|
||||
torch.randint(0,8192, (2,350)),
|
||||
torch.tensor([192,350]))
|
||||
codegen.generate(torch.randn((1, 80, 120)),
|
||||
torch.randint(0, 256, (1, 200)))
|
||||
codegen(torch.randint(0, 256, (2, 200)),
|
||||
torch.randn(2, 80, 120),
|
||||
torch.randint(0, 8192, (2, 350)),
|
||||
torch.tensor([192, 350]))
@ -3,12 +3,12 @@ from random import random
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
|
||||
from models.lucidrains.x_transformers import Encoder
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
|
||||
from dlas.models.lucidrains.x_transformers import Encoder
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
class CheckpointedXTransformerEncoder(nn.Module):
|
||||
|
@ -16,6 +16,7 @@ class CheckpointedXTransformerEncoder(nn.Module):
|
|||
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
||||
to channels-last that XTransformer expects.
|
||||
"""
|
||||
|
||||
def __init__(self, **xtransformer_kwargs):
|
||||
super().__init__()
|
||||
self.transformer = XTransformer(**xtransformer_kwargs)
|
||||
|
@ -23,7 +24,8 @@ class CheckpointedXTransformerEncoder(nn.Module):
|
|||
for xform in [self.transformer.encoder, self.transformer.decoder.net]:
|
||||
for i in range(len(xform.attn_layers.layers)):
|
||||
n, b, r = xform.attn_layers.layers[i]
|
||||
xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
||||
xform.attn_layers.layers[i] = nn.ModuleList(
|
||||
[n, CheckpointedLayer(b), r])
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.transformer(*args, **kwargs)
|
||||
|
@ -45,20 +47,21 @@ class CtcCodeGenerator(nn.Module):
|
|||
self.recursive_embedding = ml.Embedding(pred_codes, model_dim)
|
||||
self.mask_embedding = nn.Parameter(torch.randn(model_dim))
|
||||
self.encoder = Encoder(
|
||||
dim=model_dim,
|
||||
depth=layers,
|
||||
heads=model_dim//64,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
dim=model_dim,
|
||||
depth=layers,
|
||||
heads=model_dim//64,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
self.pred_head = ml.Linear(model_dim, pred_codes)
|
||||
self.confidence_head = ml.Linear(model_dim, 1)
|
||||
|
||||
def inference(self, codes, pads, repeats):
|
||||
position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))
|
||||
position_h = self.position_embedding(
|
||||
torch.arange(0, codes.shape[-1], device=codes.device))
|
||||
codes_h = self.codes_embedding(codes)
|
||||
|
||||
labels = pads + repeats * self.max_pad
|
||||
|
@ -83,13 +86,15 @@ class CtcCodeGenerator(nn.Module):
|
|||
print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
|
||||
pads = torch.clip(pads, 0, self.max_pad)
|
||||
if repeats.max() > self.max_repeat:
|
||||
print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
|
||||
print(
|
||||
f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
|
||||
repeats = torch.clip(repeats, 0, self.max_repeat)
|
||||
assert codes.max() < self.ctc_codes, codes.max()
|
||||
|
||||
labels = pads + repeats * self.max_pad
|
||||
|
||||
position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))
|
||||
position_h = self.position_embedding(
|
||||
torch.arange(0, codes.shape[-1], device=codes.device))
|
||||
codes_h = self.codes_embedding(codes)
|
||||
recursive_h = self.recursive_embedding(labels)
|
||||
|
||||
|
@ -101,12 +106,14 @@ class CtcCodeGenerator(nn.Module):
|
|||
|
||||
h = self.encoder(position_h + codes_h + recursive_h)
|
||||
pred_logits = self.pred_head(h)
|
||||
loss = F.cross_entropy(pred_logits.permute(0,2,1), labels, reduce=False)
|
||||
loss = F.cross_entropy(pred_logits.permute(
|
||||
0, 2, 1), labels, reduce=False)
|
||||
|
||||
confidences = self.confidence_head(h).squeeze(-1)
|
||||
confidences = F.softmax(confidences * mask, dim=-1)
|
||||
confidence_loss = loss * confidences
|
||||
loss = loss / loss.shape[-1] # This balances the confidence_loss and loss.
|
||||
# This balances the confidence_loss and loss.
|
||||
loss = loss / loss.shape[-1]
|
||||
|
||||
return loss.mean(), confidence_loss.mean()
|
||||
|
||||
|
@ -118,8 +125,8 @@ def register_ctc_code_generator(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
model = CtcCodeGenerator()
|
||||
inps = torch.randint(0,36, (4, 300))
|
||||
pads = torch.randint(0,100, (4,300))
|
||||
repeats = torch.randint(0,20, (4,300))
|
||||
inps = torch.randint(0, 36, (4, 300))
|
||||
pads = torch.randint(0, 100, (4, 300))
|
||||
repeats = torch.randint(0, 20, (4, 300))
|
||||
loss = model(inps, pads, repeats)
|
||||
print(loss.shape)
|
||||
print(loss.shape)
@@ -1,16 +1,22 @@
import functools
import math
import random
from functools import partial

import torch
import torch.nn as nn
import torch_intermediary as ml
from x_transformers.x_transformers import (DEFAULT_DIM_HEAD,
                                           AlibiPositionalBias, Attention,
                                           AttentionLayers, FeedForward,
                                           FixedPositionalEmbedding, GRUGating,
                                           LayerIntermediates,
                                           LearnedAlibiPositionalBias,
                                           RelativePositionBias, Residual,
                                           RMSNorm, RotaryEmbedding, Scale,
                                           ScaleNorm, ShiftTokens, cast_tuple,
                                           default, equals, exists,
                                           groupby_prefix_and_trim, not_equals)

from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
    DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
    exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
    AttentionLayers, not_equals
import dlas.torch_intermediary as ml

class TimeIntegrationBlock(nn.Module):
|
||||
|
@ -36,40 +42,41 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
"""
|
||||
Modification of x-transformers.AttentionLayers that performs timestep embeddings and layerdrop.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
timestep_dim,
|
||||
depth,
|
||||
heads = 8,
|
||||
causal = False,
|
||||
cross_attend = False,
|
||||
only_cross = False,
|
||||
use_scalenorm = False,
|
||||
use_rmsnorm = False,
|
||||
use_rezero = False,
|
||||
alibi_pos_bias = False,
|
||||
alibi_num_heads = None,
|
||||
alibi_learned = False,
|
||||
rel_pos_bias = False,
|
||||
rel_pos_num_buckets = 32,
|
||||
rel_pos_max_distance = 128,
|
||||
position_infused_attn = False,
|
||||
rotary_pos_emb = False,
|
||||
rotary_emb_dim = None,
|
||||
custom_layers = None,
|
||||
sandwich_coef = None,
|
||||
par_ratio = None,
|
||||
residual_attn = False,
|
||||
cross_residual_attn = False,
|
||||
macaron = False,
|
||||
gate_residual = False,
|
||||
scale_residual = False,
|
||||
shift_tokens = 0,
|
||||
use_qk_norm_attn = False,
|
||||
qk_norm_attn_seq_len = None,
|
||||
zero_init_branch_output = False,
|
||||
layerdrop_percent = .1,
|
||||
heads=8,
|
||||
causal=False,
|
||||
cross_attend=False,
|
||||
only_cross=False,
|
||||
use_scalenorm=False,
|
||||
use_rmsnorm=False,
|
||||
use_rezero=False,
|
||||
alibi_pos_bias=False,
|
||||
alibi_num_heads=None,
|
||||
alibi_learned=False,
|
||||
rel_pos_bias=False,
|
||||
rel_pos_num_buckets=32,
|
||||
rel_pos_max_distance=128,
|
||||
position_infused_attn=False,
|
||||
rotary_pos_emb=False,
|
||||
rotary_emb_dim=None,
|
||||
custom_layers=None,
|
||||
sandwich_coef=None,
|
||||
par_ratio=None,
|
||||
residual_attn=False,
|
||||
cross_residual_attn=False,
|
||||
macaron=False,
|
||||
gate_residual=False,
|
||||
scale_residual=False,
|
||||
shift_tokens=0,
|
||||
use_qk_norm_attn=False,
|
||||
qk_norm_attn_seq_len=None,
|
||||
zero_init_branch_output=False,
|
||||
layerdrop_percent=.1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(dim, depth)
|
||||
|
@ -84,21 +91,26 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
self.layerdrop_percent = layerdrop_percent
|
||||
|
||||
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
|
||||
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
|
||||
self.pia_pos_emb = FixedPositionalEmbedding(
|
||||
dim) if position_infused_attn else None
|
||||
|
||||
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
|
||||
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
|
||||
self.rotary_pos_emb = RotaryEmbedding(
|
||||
rotary_emb_dim) if rotary_pos_emb else None
|
||||
|
||||
assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
|
||||
assert not (
|
||||
alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
|
||||
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
||||
|
||||
if rel_pos_bias:
|
||||
self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
|
||||
self.rel_pos = RelativePositionBias(
|
||||
scale=dim_head ** 0.5, causal=causal, heads=heads, num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
|
||||
elif alibi_pos_bias:
|
||||
alibi_num_heads = default(alibi_num_heads, heads)
|
||||
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
|
||||
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
|
||||
self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, bidirectional = not causal)
|
||||
self.rel_pos = alibi_pos_klass(
|
||||
heads=alibi_num_heads, bidirectional=not causal)
|
||||
else:
|
||||
self.rel_pos = None
|
||||
|
||||
|
@ -125,8 +137,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
|
||||
# qk normalization
|
||||
if use_qk_norm_attn:
|
||||
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
|
||||
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
|
||||
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len **
|
||||
2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
|
||||
attn_kwargs = {**attn_kwargs, 'qk_norm': True,
|
||||
'scale_init_value': attn_scale_init_value}
|
||||
|
||||
# zero init
|
||||
if zero_init_branch_output:
|
||||
|
@ -140,16 +154,20 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
par_depth = depth * len(default_block)
|
||||
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
||||
default_block = tuple(filter(not_equals('f'), default_block))
|
||||
par_attn = par_depth // par_ratio
|
||||
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
||||
par_attn = par_depth // par_ratio
|
||||
# 2 / 3 attention layer cutoff suggested by PAR paper
|
||||
depth_cut = par_depth * 2 // 3
|
||||
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
||||
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
|
||||
par_block = default_block + ('f',) * (par_width - len(default_block))
|
||||
assert len(
|
||||
default_block) <= par_width, 'default block is too large for par_ratio'
|
||||
par_block = default_block + \
|
||||
('f',) * (par_width - len(default_block))
|
||||
par_head = par_block * par_attn
|
||||
layer_types = par_head + ('f',) * (par_depth - len(par_head))
|
||||
elif exists(sandwich_coef):
|
||||
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
||||
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
|
||||
layer_types = ('a',) * sandwich_coef + default_block * \
|
||||
(depth - sandwich_coef) + ('f',) * sandwich_coef
|
||||
else:
|
||||
layer_types = default_block * depth
|
||||
|
||||
|
@ -163,9 +181,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
# iterate and construct layers
|
||||
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
|
||||
if layer_type == 'a':
|
||||
layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
|
||||
layer = Attention(dim, heads=heads,
|
||||
causal=causal, **attn_kwargs)
|
||||
elif layer_type == 'c':
|
||||
layer = Attention(dim, heads = heads, **attn_kwargs)
|
||||
layer = Attention(dim, heads=heads, **attn_kwargs)
|
||||
elif layer_type == 'f':
|
||||
layer = FeedForward(dim, **ff_kwargs)
|
||||
layer = layer if not macaron else Scale(0.5, layer)
|
||||
|
@ -175,19 +194,22 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
if layer_shift_tokens > 0:
|
||||
shift_range_upper = layer_shift_tokens + 1
|
||||
shift_range_lower = -layer_shift_tokens if not causal else 0
|
||||
layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
|
||||
layer = ShiftTokens(
|
||||
range(shift_range_lower, shift_range_upper), layer)
|
||||
|
||||
if exists(branch_fn):
|
||||
layer = branch_fn(layer)
|
||||
|
||||
residual_fn = GRUGating if gate_residual else Residual
|
||||
residual = residual_fn(dim, scale_residual = scale_residual)
|
||||
residual = residual_fn(dim, scale_residual=scale_residual)
|
||||
|
||||
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
|
||||
|
||||
pre_branch_norm = TimeIntegrationBlock(timestep_dim, dim, norm_fn())
|
||||
pre_branch_norm = TimeIntegrationBlock(
|
||||
timestep_dim, dim, norm_fn())
|
||||
post_branch_norm = norm_fn() if layer_uses_qk_norm else None
|
||||
post_main_norm = None # Always do prenorm for timestep integration.
|
||||
# Always do prenorm for timestep integration.
|
||||
post_main_norm = None
|
||||
|
||||
norms = nn.ModuleList([
|
||||
pre_branch_norm,
|
||||
|
@ -204,15 +226,16 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
def forward(
|
||||
self,
|
||||
x,
|
||||
time_emb = None,
|
||||
context = None,
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
attn_mask = None,
|
||||
mems = None,
|
||||
return_hiddens = False
|
||||
time_emb=None,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
attn_mask=None,
|
||||
mems=None,
|
||||
return_hiddens=False
|
||||
):
|
||||
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
|
||||
assert not (self.cross_attend ^ exists(
|
||||
context)), 'context must be passed in if cross_attend is set to True'
|
||||
assert time_emb is not None, 'must specify a timestep embedding.'
|
||||
|
||||
hiddens = []
|
||||
|
@ -224,8 +247,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
|
||||
rotary_pos_emb = None
|
||||
if exists(self.rotary_pos_emb):
|
||||
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
|
||||
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
|
||||
max_rotary_emb_length = max(
|
||||
list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
|
||||
rotary_pos_emb = self.rotary_pos_emb(
|
||||
max_rotary_emb_length, x.device)
|
||||
|
||||
unused_params = []
|
||||
to_drop = 0
|
||||
|
@ -253,9 +278,11 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
x = pre_branch_norm(x, time_emb)
|
||||
|
||||
if layer_type == 'a':
|
||||
out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
|
||||
out, inter = block(x, mask=mask, attn_mask=attn_mask, sinusoidal_emb=self.pia_pos_emb,
|
||||
rel_pos=self.rel_pos, rotary_pos_emb=rotary_pos_emb, prev_attn=prev_attn, mem=layer_mem)
|
||||
elif layer_type == 'c':
|
||||
out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
|
||||
out, inter = block(
|
||||
x, context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
|
||||
elif layer_type == 'f':
|
||||
out = block(x)
|
||||
|
||||
|
@ -283,10 +310,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
|
||||
if return_hiddens:
|
||||
intermediates = LayerIntermediates(
|
||||
hiddens = hiddens,
|
||||
attn_intermediates = intermediates
|
||||
hiddens=hiddens,
|
||||
attn_intermediates=intermediates
|
||||
)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
return x
return x
@ -1,17 +1,15 @@
import functools
import math
from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from vector_quantize_pytorch import VectorQuantize

from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import opt_get
from dlas.models.vqvae.vqvae import Quantize
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get


def default(val, d):

@ -32,9 +30,9 @@ class ResBlock(nn.Module):
def __init__(self, chan, conv, activation):
super().__init__()
self.net = nn.Sequential(
conv(chan, chan, 3, padding = 1),
conv(chan, chan, 3, padding=1),
activation(),
conv(chan, chan, 3, padding = 1),
conv(chan, chan, 3, padding=1),
activation(),
conv(chan, chan, 1)
)

@ -52,7 +50,8 @@ class UpsampledConv(nn.Module):
self.conv = conv(*args, **kwargs)

def forward(self, x):
up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
up = nn.functional.interpolate(
x, scale_factor=self.stride, mode='nearest')
return self.conv(up)

@ -60,23 +59,23 @@ class DiscreteVAE(nn.Module):
def __init__(
self,
positional_dims=2,
num_tokens = 512,
codebook_dim = 512,
num_layers = 3,
num_resnet_blocks = 0,
hidden_dim = 64,
channels = 3,
stride = 2,
kernel_size = 4,
use_transposed_convs = True,
encoder_norm = False,
activation = 'relu',
smooth_l1_loss = False,
straight_through = False,
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
record_codes = False,
use_lr_quantizer = False,
lr_quantizer_args = {},
num_tokens=512,
codebook_dim=512,
num_layers=3,
num_resnet_blocks=0,
hidden_dim=64,
channels=3,
stride=2,
kernel_size=4,
use_transposed_convs=True,
encoder_norm=False,
activation='relu',
smooth_l1_loss=False,
straight_through=False,
normalization=None, # ((0.5,) * 3, (0.5,) * 3),
record_codes=False,
use_lr_quantizer=False,
lr_quantizer_args={},
):
super().__init__()
has_resblocks = num_resnet_blocks > 0

@ -86,7 +85,8 @@ class DiscreteVAE(nn.Module):
self.straight_through = straight_through
self.positional_dims = positional_dims

assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
# This VAE only supports 1d and 2d inputs for now.
assert positional_dims > 0 and positional_dims < 3
if positional_dims == 2:
conv = nn.Conv2d
conv_transpose = nn.ConvTranspose2d

@ -103,7 +103,6 @@ class DiscreteVAE(nn.Module):
else:
assert NotImplementedError()


enc_layers = []
dec_layers = []

@ -116,18 +115,22 @@ class DiscreteVAE(nn.Module):
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
dec_chans = [dec_init_chan, *dec_chans]

enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
enc_chans_io, dec_chans_io = map(lambda t: list(
zip(t[:-1], t[1:])), (enc_chans, dec_chans))

pad = (kernel_size - 1) // 2
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
enc_layers.append(nn.Sequential(
conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act()))
if encoder_norm:
enc_layers.append(nn.GroupNorm(8, enc_out))
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
dec_layers.append(nn.Sequential(conv_transpose(
dec_in, dec_out, kernel_size, stride=stride, padding=pad), act()))
dec_out_chans = dec_chans[-1]
innermost_dim = dec_chans[0]
else:
enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
enc_layers.append(nn.Sequential(
conv(channels, hidden_dim, 1), act()))
dec_out_chans = hidden_dim
innermost_dim = hidden_dim

@ -138,7 +141,6 @@ class DiscreteVAE(nn.Module):
if num_resnet_blocks > 0:
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))


enc_layers.append(conv(innermost_dim, codebook_dim, 1))
dec_layers.append(conv(dec_out_chans, channels, 1))

@ -148,9 +150,11 @@ class DiscreteVAE(nn.Module):
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss

if use_lr_quantizer:
self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
self.codebook = VectorQuantize(
dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
else:
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
self.codebook = Quantize(
codebook_dim, num_tokens, new_return_order=True)

# take care of normalization within class
self.normalization = normalization

@ -165,7 +169,8 @@ class DiscreteVAE(nn.Module):
if not self.normalization is not None:
return images

means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
means, stds = map(lambda t: torch.as_tensor(
t).to(images), self.normalization)
arrange = 'c -> () c () ()' if self.positional_dims == 2 else 'c -> () c ()'
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
images = images.clone()

@ -183,7 +188,8 @@ class DiscreteVAE(nn.Module):
@eval_decorator
def get_codebook_indices(self, images):
img = self.norm(images)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
logits = self.encoder(img).permute(
(0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, _ = self.codebook(logits)
self.log_codes(codes)
return codes

@ -214,7 +220,8 @@ class DiscreteVAE(nn.Module):

def infer(self, img):
img = self.norm(img)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
logits = self.encoder(img).permute(
(0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, commitment_loss = self.codebook(logits)
return self.decode(codes)

@ -226,9 +233,11 @@ class DiscreteVAE(nn.Module):
img
):
img = self.norm(img)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
logits = self.encoder(img).permute(
(0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, commitment_loss = self.codebook(logits)
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
sampled = sampled.permute(
(0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1))

if self.training:
out = sampled

@ -249,7 +258,8 @@ class DiscreteVAE(nn.Module):
if self.record_codes and self.internal_step % 10 == 0:
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
i = self.code_ind if (
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:

@ -264,14 +274,14 @@ def register_lucidrains_dvae(opt_net, opt):


if __name__ == '__main__':
#v = DiscreteVAE()
#o=v(torch.randn(1,3,256,256))
#print(o.shape)
# v = DiscreteVAE()
# o=v(torch.randn(1,3,256,256))
# print(o.shape)
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False,
use_lr_quantizer=True)
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
#v.eval()
r,l,o=v(torch.randn(1,80,256))
v.decode(torch.randint(0,8192,(1,256)))
# v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
# v.eval()
r, l, o = v(torch.randn(1, 80, 256))
v.decode(torch.randint(0, 8192, (1, 256)))
print(o.shape, l.shape)

@ -1,14 +1,14 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_intermediary as ml
|
||||
|
||||
|
||||
from models.diffusion.nn import normalization, conv_nd, zero_module
|
||||
from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy, Upsample
|
||||
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.diffusion.nn import conv_nd, normalization, zero_module
|
||||
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
|
||||
QKVAttention,
|
||||
QKVAttentionLegacy, Upsample)
|
||||
# Combined resnet & full-attention encoder for converting an audio clip into an embedding.
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, opt_get, sequential_checkpoint
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint, opt_get, sequential_checkpoint
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
|
@ -37,7 +37,8 @@ class ResBlock(nn.Module):
|
|||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
|
||||
conv_nd(dims, channels, self.out_channels,
|
||||
kernel_size, padding=padding),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
@ -56,7 +57,8 @@ class ResBlock(nn.Module):
|
|||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, self.out_channels, self.out_channels,
|
||||
kernel_size, padding=padding)
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -67,7 +69,8 @@ class ResBlock(nn.Module):
|
|||
dims, channels, self.out_channels, kernel_size, padding=padding
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
if self.do_checkpoint:
|
||||
|
@ -111,8 +114,10 @@ class AudioMiniEncoder(nn.Module):
|
|||
self.layers = depth
|
||||
for l in range(depth):
|
||||
for r in range(resnet_blocks):
|
||||
res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size))
|
||||
res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor))
|
||||
res.append(ResBlock(ch, dropout, dims=1,
|
||||
do_checkpoint=False, kernel_size=kernel_size))
|
||||
res.append(Downsample(ch, use_conv=True, dims=1,
|
||||
out_channels=ch*2, factor=downsample_factor))
|
||||
ch *= 2
|
||||
self.res = nn.Sequential(*res)
|
||||
self.final = nn.Sequential(
|
||||
|
@ -122,7 +127,8 @@ class AudioMiniEncoder(nn.Module):
|
|||
)
|
||||
attn = []
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
|
||||
attn.append(AttentionBlock(embedding_dim,
|
||||
num_attn_heads, do_checkpoint=False))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
@ -150,10 +156,12 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
|
|||
return logits
|
||||
else:
|
||||
if self.distribute_zero_label:
|
||||
oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
|
||||
oh_labels = nn.functional.one_hot(
|
||||
labels, num_classes=self.num_classes)
|
||||
zeros_indices = (labels == 0).unsqueeze(-1)
|
||||
# Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
|
||||
zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
|
||||
zero_extra_mass = torch.full_like(
|
||||
oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
|
||||
zero_extra_mass[:, 0] = -.2
|
||||
zero_extra_mass = zero_extra_mass * zeros_indices
|
||||
oh_labels = oh_labels + zero_extra_mass
|
||||
|
@ -167,6 +175,7 @@ class QueryProvidedAttentionBlock(nn.Module):
|
|||
"""
|
||||
An attention block that provides a separate signal for the query vs the keys/parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
|
@ -200,12 +209,13 @@ class QueryProvidedAttentionBlock(nn.Module):
|
|||
return checkpoint(self._forward, qx, kvx, mask)
|
||||
|
||||
def _forward(self, qx, kvx, mask=None):
|
||||
q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(1, kvx.shape[1], 1).permute(0,2,1)
|
||||
kv = self.kv(self.norm(kvx.permute(0,2,1)))
|
||||
q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(
|
||||
1, kvx.shape[1], 1).permute(0, 2, 1)
|
||||
kv = self.kv(self.norm(kvx.permute(0, 2, 1)))
|
||||
qkv = torch.cat([q, kv], dim=1)
|
||||
h = self.attention(qkv, mask)
|
||||
h = self.proj_out(h)
|
||||
return kvx + h.permute(0,2,1)
|
||||
return kvx + h.permute(0, 2, 1)
|
||||
|
||||
|
||||
# Next up: combine multiple embeddings given a conditioning signal into a single embedding.
|
||||
|
@ -213,7 +223,8 @@ class EmbeddingCombiner(nn.Module):
|
|||
def __init__(self, embedding_dim, attn_blocks=3, num_attn_heads=2, cond_provided=True):
|
||||
super().__init__()
|
||||
block = QueryProvidedAttentionBlock if cond_provided else AttentionBlock
|
||||
self.attn = nn.ModuleList([block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
|
||||
self.attn = nn.ModuleList(
|
||||
[block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
|
||||
self.cond_provided = cond_provided
|
||||
|
||||
# x_s: (b,n,d); b=batch_sz, n=number of embeddings, d=embedding_dim
|
||||
|
|
|
@ -3,10 +3,10 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml

from trainer.networks import register_model
from utils.util import opt_get
import dlas.torch_intermediary as ml
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):

@ -61,4 +61,4 @@ def register_random_latent_converter(opt_net, opt):

if __name__ == '__main__':
model = RandomLatentConverter(512)
model(torch.randn(5,512))
model(torch.randn(5, 512))

@ -1,6 +1,6 @@
from models.audio.tts.tacotron2.taco_utils import *
from models.audio.tts.tacotron2.text import *
from models.audio.tts.tacotron2.tacotron2 import *
from models.audio.tts.tacotron2.stft import *
from models.audio.tts.tacotron2.layers import *
from models.audio.tts.tacotron2.loss import *
from dlas.models.audio.tts.tacotron2.layers import *
from dlas.models.audio.tts.tacotron2.loss import *
from dlas.models.audio.tts.tacotron2.stft import *
from dlas.models.audio.tts.tacotron2.taco_utils import *
from dlas.models.audio.tts.tacotron2.tacotron2 import *
from dlas.models.audio.tts.tacotron2.text import *

@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from scipy.signal import get_window
|
||||
import librosa.util as librosa_util
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.signal import get_window
|
||||
|
||||
|
||||
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
|
||||
|
@ -52,7 +52,8 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
|
|||
# Fill the envelope
|
||||
for i in range(n_frames):
|
||||
sample = i * hop_length
|
||||
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
||||
x[sample:min(n, sample + n_fft)
|
||||
] += win_sq[:max(0, min(n_fft, n - sample))]
|
||||
return x
|
||||
|
||||
|
||||
|
@ -90,4 +91,4 @@ def dynamic_range_decompression(x, C=1):
|
|||
------
|
||||
C: compression factor used to compress
|
||||
"""
|
||||
return torch.exp(x) / C
|
||||
return torch.exp(x) / C
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#import tensorflow as tf
|
||||
from models.audio.tts.tacotron2.text import symbols
|
||||
# import tensorflow as tf
|
||||
from dlas.models.audio.tts.tacotron2.text import symbols
|
||||
|
||||
|
||||
def create_hparams(hparams_string=None, verbose=False):
|
||||
|
@ -33,7 +33,8 @@ def create_hparams(hparams_string=None, verbose=False):
|
|||
# Audio Parameters #
|
||||
################################
|
||||
max_wav_value=32768.0,
|
||||
input_sample_rate=22050, # When different from sampling_rate, dataset automatically interpolates to sampling_rate
|
||||
# When different from sampling_rate, dataset automatically interpolates to sampling_rate
|
||||
input_sample_rate=22050,
|
||||
sampling_rate=22050,
|
||||
filter_length=1024,
|
||||
hop_length=256, # This means a MEL is 1/256th the equivalent audio.
|
||||
|
@ -86,4 +87,4 @@ def create_hparams(hparams_string=None, verbose=False):
|
|||
mask_padding=True # set model's padded outputs to padded values
|
||||
)
|
||||
|
||||
return hparams
|
||||
return hparams
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import torch
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
|
||||
from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
|
||||
from models.audio.tts.tacotron2.stft import STFT
|
||||
import torch_intermediary as ml
|
||||
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.audio.tts.tacotron2.audio_processing import (
|
||||
dynamic_range_compression, dynamic_range_decompression)
|
||||
from dlas.models.audio.tts.tacotron2.stft import STFT
|
||||
|
||||
|
||||
class LinearNorm(torch.nn.Module):
|
||||
|
@ -24,7 +25,7 @@ class ConvNorm(torch.nn.Module):
|
|||
padding=None, dilation=1, bias=True, w_init_gain='linear'):
|
||||
super(ConvNorm, self).__init__()
|
||||
if padding is None:
|
||||
assert(kernel_size % 2 == 1)
|
||||
assert (kernel_size % 2 == 1)
|
||||
padding = int(dilation * (kernel_size - 1) / 2)
|
||||
|
||||
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
||||
|
@ -71,8 +72,8 @@ class TacotronSTFT(torch.nn.Module):
|
|||
-------
|
||||
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
||||
"""
|
||||
assert(torch.min(y.data) >= -10)
|
||||
assert(torch.max(y.data) <= 10)
|
||||
assert (torch.min(y.data) >= -10)
|
||||
assert (torch.max(y.data) <= 10)
|
||||
y = torch.clip(y, min=-1, max=1)
|
||||
|
||||
magnitudes, phases = self.stft_fn.transform(y)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from torch import nn
|
||||
|
||||
from trainer.losses import ConfigurableLoss
|
||||
from dlas.trainer.losses import ConfigurableLoss
|
||||
|
||||
|
||||
class Tacotron2Loss(ConfigurableLoss):
|
||||
|
@ -20,7 +20,8 @@ class Tacotron2Loss(ConfigurableLoss):
|
|||
gate_target.requires_grad = False
|
||||
gate_target = gate_target.view(-1, 1)
|
||||
|
||||
mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[self.mel_output_postnet_key], state[self.gate_output_key]
|
||||
mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[
|
||||
self.mel_output_postnet_key], state[self.gate_output_key]
|
||||
gate_out = gate_out.view(-1, 1)
|
||||
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
||||
nn.MSELoss()(mel_out_postnet, mel_target)
|
||||
|
@ -61,4 +62,4 @@ class Tacotron2LossRaw(nn.Module):
|
|||
return {
|
||||
'mel_loss': self.last_mel_loss,
|
||||
'gate_loss': self.last_gate_loss
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,17 +30,19 @@ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from scipy.signal import get_window
|
||||
from librosa.util import pad_center, tiny
|
||||
from models.audio.tts.tacotron2.audio_processing import window_sumsquare
|
||||
from scipy.signal import get_window
|
||||
from torch.autograd import Variable
|
||||
|
||||
from dlas.models.audio.tts.tacotron2.audio_processing import window_sumsquare
|
||||
|
||||
|
||||
class STFT(torch.nn.Module):
|
||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||
|
||||
def __init__(self, filter_length=800, hop_length=200, win_length=800,
|
||||
window='hann'):
|
||||
super(STFT, self).__init__()
|
||||
|
@ -61,7 +63,7 @@ class STFT(torch.nn.Module):
|
|||
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
|
||||
|
||||
if window is not None:
|
||||
assert(filter_length >= win_length)
|
||||
assert (filter_length >= win_length)
|
||||
# get window and zero center pad it to filter_length
|
||||
fft_window = get_window(window, win_length, fftbins=True)
|
||||
fft_window = pad_center(fft_window, filter_length)
|
||||
|
@ -125,17 +127,19 @@ class STFT(torch.nn.Module):
|
|||
window_sum = torch.autograd.Variable(
|
||||
torch.from_numpy(window_sum), requires_grad=False)
|
||||
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
|
||||
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
||||
inverse_transform[:, :,
|
||||
approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
||||
|
||||
# scale by hop ratio
|
||||
inverse_transform *= float(self.filter_length) / self.hop_length
|
||||
|
||||
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
|
||||
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
|
||||
inverse_transform = inverse_transform[:,
|
||||
:, :-int(self.filter_length/2):]
|
||||
|
||||
return inverse_transform
|
||||
|
||||
def forward(self, input_data):
|
||||
self.magnitude, self.phase = self.transform(input_data)
|
||||
reconstruction = self.inverse(self.magnitude, self.phase)
|
||||
return reconstruction
|
||||
return reconstruction
|
||||
|
|
|
@ -8,7 +8,8 @@ from scipy.io.wavfile import read
|
|||
def get_mask_from_lengths(lengths, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths).item()
|
||||
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
|
||||
ids = torch.arange(0, max_len, out=torch.LongTensor(
|
||||
max_len)).to(lengths.device)
|
||||
mask = (ids < lengths.unsqueeze(1)).bool()
|
||||
return mask
|
||||
|
||||
|
@ -22,24 +23,29 @@ def load_wav_to_torch(full_path):
|
|||
elif data.dtype == np.float16 or data.dtype == np.float32:
|
||||
norm_fix = 1.
|
||||
else:
|
||||
raise NotImplemented(f"Provided data dtype not supported: {data.dtype}")
|
||||
raise NotImplemented(
|
||||
f"Provided data dtype not supported: {data.dtype}")
|
||||
return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
|
||||
|
||||
|
||||
def load_filepaths_and_text_type(filename, type, split="|"):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
filepaths_and_text = [list(line.strip().split(split)) + [type] for line in f]
|
||||
filepaths_and_text = [
|
||||
list(line.strip().split(split)) + [type] for line in f]
|
||||
base = os.path.dirname(filename)
|
||||
for j in range(len(filepaths_and_text)):
|
||||
filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
|
||||
filepaths_and_text[j][0] = os.path.join(
|
||||
base, filepaths_and_text[j][0])
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def load_filepaths_and_text(filename, split="|"):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
filepaths_and_text = [line.strip().split(split) for line in f]
|
||||
base = os.path.dirname(filename)
|
||||
for j in range(len(filepaths_and_text)):
|
||||
filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
|
||||
filepaths_and_text[j][0] = os.path.join(
|
||||
base, filepaths_and_text[j][0])
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
|
@ -48,4 +54,4 @@ def to_gpu(x):
|
|||
|
||||
if torch.cuda.is_available():
|
||||
x = x.cuda(non_blocking=True)
|
||||
return torch.autograd.Variable(x)
|
||||
return torch.autograd.Variable(x)
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
from math import sqrt
|
||||
|
||||
import torch
|
||||
from munch import munchify
|
||||
from torch.autograd import Variable
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import functional as F
|
||||
from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
|
||||
from models.audio.tts.tacotron2.hparams import create_hparams
|
||||
from trainer.networks import register_model
|
||||
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
|
||||
import torch_intermediary as ml
|
||||
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.audio.tts.tacotron2.hparams import create_hparams
|
||||
from dlas.models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
|
||||
from dlas.models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from dlas.trainer.networks import register_model
|
||||
|
||||
|
||||
class LocationLayer(nn.Module):
|
||||
|
@ -59,7 +61,8 @@ class Attention(nn.Module):
|
|||
"""
|
||||
|
||||
processed_query = self.query_layer(query.unsqueeze(1))
|
||||
processed_attention_weights = self.location_layer(attention_weights_cat)
|
||||
processed_attention_weights = self.location_layer(
|
||||
attention_weights_cat)
|
||||
energies = self.v(torch.tanh(
|
||||
processed_query + processed_attention_weights + processed_memory))
|
||||
|
||||
|
@ -128,7 +131,8 @@ class Postnet(nn.Module):
|
|||
ConvNorm(hparams.postnet_embedding_dim,
|
||||
hparams.postnet_embedding_dim,
|
||||
kernel_size=hparams.postnet_kernel_size, stride=1,
|
||||
padding=int((hparams.postnet_kernel_size - 1) / 2),
|
||||
padding=int(
|
||||
(hparams.postnet_kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='tanh'),
|
||||
nn.BatchNorm1d(hparams.postnet_embedding_dim))
|
||||
)
|
||||
|
@ -140,11 +144,12 @@ class Postnet(nn.Module):
|
|||
padding=int((hparams.postnet_kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='linear'),
|
||||
nn.BatchNorm1d(hparams.n_mel_channels))
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(len(self.convolutions) - 1):
|
||||
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
|
||||
x = F.dropout(torch.tanh(
|
||||
self.convolutions[i](x)), 0.5, self.training)
|
||||
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
|
||||
|
||||
return x
|
||||
|
@ -155,6 +160,7 @@ class Encoder(nn.Module):
|
|||
- Three 1-d convolution banks
|
||||
- Bidirectional LSTM
|
||||
"""
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
|
@ -408,7 +414,8 @@ class Decoder(nn.Module):
|
|||
mel_outputs, gate_outputs, alignments = [], [], []
|
||||
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
||||
decoder_input = decoder_inputs[len(mel_outputs)]
|
||||
mel_output, gate_output, attention_weights = self.decode(decoder_input)
|
||||
mel_output, gate_output, attention_weights = self.decode(
|
||||
decoder_input)
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
gate_outputs += [gate_output.squeeze(1)]
|
||||
alignments += [attention_weights]
|
||||
|
@ -529,9 +536,9 @@ def register_nv_tacotron2(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
tron = register_nv_tacotron2({}, {})
|
||||
inputs = torch.randint(high=24, size=(1,12)), \
|
||||
torch.tensor([12]), \
|
||||
torch.randn((1,80,749)), \
|
||||
torch.tensor([749])
|
||||
inputs = torch.randint(high=24, size=(1, 12)), \
|
||||
torch.tensor([12]), \
|
||||
torch.randn((1, 80, 749)), \
|
||||
torch.tensor([749])
|
||||
out = tron(*inputs)
|
||||
print(out)
|
||||
print(out)
|
||||
|
|
|
@ -3,9 +3,8 @@ import re
|
|||
|
||||
import torch
|
||||
|
||||
from models.audio.tts.tacotron2.text import cleaners
|
||||
from models.audio.tts.tacotron2.text.symbols import symbols
|
||||
|
||||
from dlas.models.audio.tts.tacotron2.text import cleaners
|
||||
from dlas.models.audio.tts.tacotron2.text.symbols import symbols
|
||||
|
||||
# Mappings from symbol to numeric ID and vice versa:
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
|
@ -16,72 +15,73 @@ _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
|||
|
||||
|
||||
def text_to_sequence(text, cleaner_names=['english_cleaners']):
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
|
||||
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
||||
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
||||
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
||||
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
||||
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
cleaner_names: names of the cleaner functions to run the text through
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
cleaner_names: names of the cleaner functions to run the text through
|
||||
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
sequence = []
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
sequence = []
|
||||
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while len(text):
|
||||
m = _curly_re.match(text)
|
||||
if not m:
|
||||
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
|
||||
break
|
||||
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
|
||||
sequence += _arpabet_to_sequence(m.group(2))
|
||||
text = m.group(3)
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while len(text):
|
||||
m = _curly_re.match(text)
|
||||
if not m:
|
||||
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
|
||||
break
|
||||
sequence += _symbols_to_sequence(
|
||||
_clean_text(m.group(1), cleaner_names))
|
||||
sequence += _arpabet_to_sequence(m.group(2))
|
||||
text = m.group(3)
|
||||
|
||||
return sequence
|
||||
return sequence
|
||||
|
||||
|
||||
def sequence_to_text(sequence):
|
||||
'''Converts a sequence of IDs back to a string'''
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
if isinstance(symbol_id, torch.Tensor):
|
||||
symbol_id = symbol_id.item()
|
||||
if symbol_id in _id_to_symbol:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
if len(s) > 1 and s[0] == '@':
|
||||
s = '{%s}' % s[1:]
|
||||
result += s
|
||||
return result.replace('}{', ' ')
|
||||
'''Converts a sequence of IDs back to a string'''
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
if isinstance(symbol_id, torch.Tensor):
|
||||
symbol_id = symbol_id.item()
|
||||
if symbol_id in _id_to_symbol:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
if len(s) > 1 and s[0] == '@':
|
||||
s = '{%s}' % s[1:]
|
||||
result += s
|
||||
return result.replace('}{', ' ')
|
||||
|
||||
|
||||
def tacotron_symbols():
|
||||
return list(_symbol_to_id.keys())
|
||||
return list(_symbol_to_id.keys())
|
||||
|
||||
|
||||
def tacotron_symbol_mapping():
|
||||
return _symbol_to_id.copy()
|
||||
return _symbol_to_id.copy()
|
||||
|
||||
|
||||
def _clean_text(text, cleaner_names):
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
return text
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
return text
|
||||
|
||||
|
||||
def _symbols_to_sequence(symbols):
|
||||
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
|
||||
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
|
||||
|
||||
|
||||
def _arpabet_to_sequence(text):
|
||||
return _symbols_to_sequence(['@' + s for s in text.split()])
|
||||
return _symbols_to_sequence(['@' + s for s in text.split()])
|
||||
|
||||
|
||||
def _should_keep_symbol(s):
|
||||
return s in _symbol_to_id and s != '_' and s != '~'
|
||||
return s in _symbol_to_id and s != '_' and s != '~'
|
||||
|
|
|
@ -12,80 +12,79 @@ hyperparameter. Some cleaners are English-specific. You'll typically want to use
|
|||
the symbols in symbols.py to match your data).
|
||||
'''
|
||||
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
import re
|
||||
from unidecode import unidecode
|
||||
from .numbers import normalize_numbers
|
||||
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r'\s+')
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('mrs', 'misess'),
|
||||
('mr', 'mister'),
|
||||
('dr', 'doctor'),
|
||||
('st', 'saint'),
|
||||
('co', 'company'),
|
||||
('jr', 'junior'),
|
||||
('maj', 'major'),
|
||||
('gen', 'general'),
|
||||
('drs', 'doctors'),
|
||||
('rev', 'reverend'),
|
||||
('lt', 'lieutenant'),
|
||||
('hon', 'honorable'),
|
||||
('sgt', 'sergeant'),
|
||||
('capt', 'captain'),
|
||||
('esq', 'esquire'),
|
||||
('ltd', 'limited'),
|
||||
('col', 'colonel'),
|
||||
('ft', 'fort'),
|
||||
('mrs', 'misess'),
|
||||
('mr', 'mister'),
|
||||
('dr', 'doctor'),
|
||||
('st', 'saint'),
|
||||
('co', 'company'),
|
||||
('jr', 'junior'),
|
||||
('maj', 'major'),
|
||||
('gen', 'general'),
|
||||
('drs', 'doctors'),
|
||||
('rev', 'reverend'),
|
||||
('lt', 'lieutenant'),
|
||||
('hon', 'honorable'),
|
||||
('sgt', 'sergeant'),
|
||||
('capt', 'captain'),
|
||||
('esq', 'esquire'),
|
||||
('ltd', 'limited'),
|
||||
('col', 'colonel'),
|
||||
('ft', 'fort'),
|
||||
]]
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def english_cleaners(text):
|
||||
'''Pipeline for English text, including number and abbreviation expansion.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', '')
|
||||
return text
|
||||
'''Pipeline for English text, including number and abbreviation expansion.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', '')
|
||||
return text
|
||||
|
|
|
@ -2,64 +2,62 @@
|
|||
|
||||
import re
|
||||
|
||||
|
||||
valid_symbols = [
|
||||
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
|
||||
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
|
||||
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
|
||||
'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
|
||||
'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
|
||||
'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
|
||||
'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
|
||||
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
|
||||
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
|
||||
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
|
||||
'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
|
||||
'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
|
||||
'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
|
||||
'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
|
||||
]
|
||||
|
||||
_valid_symbol_set = set(valid_symbols)
|
||||
|
||||
|
||||
class CMUDict:
|
||||
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
|
||||
def __init__(self, file_or_path, keep_ambiguous=True):
|
||||
if isinstance(file_or_path, str):
|
||||
with open(file_or_path, encoding='latin-1') as f:
|
||||
entries = _parse_cmudict(f)
|
||||
else:
|
||||
entries = _parse_cmudict(file_or_path)
|
||||
if not keep_ambiguous:
|
||||
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
||||
self._entries = entries
|
||||
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
|
||||
|
||||
def __init__(self, file_or_path, keep_ambiguous=True):
|
||||
if isinstance(file_or_path, str):
|
||||
with open(file_or_path, encoding='latin-1') as f:
|
||||
entries = _parse_cmudict(f)
|
||||
else:
|
||||
entries = _parse_cmudict(file_or_path)
|
||||
if not keep_ambiguous:
|
||||
entries = {word: pron for word,
|
||||
pron in entries.items() if len(pron) == 1}
|
||||
self._entries = entries
|
||||
|
||||
def __len__(self):
|
||||
return len(self._entries)
|
||||
|
||||
|
||||
def lookup(self, word):
|
||||
'''Returns list of ARPAbet pronunciations of the given word.'''
|
||||
return self._entries.get(word.upper())
|
||||
def __len__(self):
|
||||
return len(self._entries)
|
||||
|
||||
def lookup(self, word):
|
||||
'''Returns list of ARPAbet pronunciations of the given word.'''
|
||||
return self._entries.get(word.upper())
|
||||
|
||||
|
||||
_alt_re = re.compile(r'\([0-9]+\)')
|
||||
|
||||
|
||||
def _parse_cmudict(file):
|
||||
cmudict = {}
|
||||
for line in file:
|
||||
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
|
||||
parts = line.split(' ')
|
||||
word = re.sub(_alt_re, '', parts[0])
|
||||
pronunciation = _get_pronunciation(parts[1])
|
||||
if pronunciation:
|
||||
if word in cmudict:
|
||||
cmudict[word].append(pronunciation)
|
||||
else:
|
||||
cmudict[word] = [pronunciation]
|
||||
return cmudict
|
||||
cmudict = {}
|
||||
for line in file:
|
||||
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
|
||||
parts = line.split(' ')
|
||||
word = re.sub(_alt_re, '', parts[0])
|
||||
pronunciation = _get_pronunciation(parts[1])
|
||||
if pronunciation:
|
||||
if word in cmudict:
|
||||
cmudict[word].append(pronunciation)
|
||||
else:
|
||||
cmudict[word] = [pronunciation]
|
||||
return cmudict
|
||||
|
||||
|
||||
def _get_pronunciation(s):
|
||||
parts = s.strip().split(' ')
|
||||
for part in parts:
|
||||
if part not in _valid_symbol_set:
|
||||
return None
|
||||
return ' '.join(parts)
|
||||
parts = s.strip().split(' ')
|
||||
for part in parts:
|
||||
if part not in _valid_symbol_set:
|
||||
return None
|
||||
return ' '.join(parts)
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
|
||||
import inflect
|
||||
import re
|
||||
|
||||
import inflect
|
||||
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||
|
@ -14,58 +14,58 @@ _number_re = re.compile(r'[0-9]+')
|
|||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(',', '')
|
||||
return m.group(1).replace(',', '')
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace('.', ' point ')
|
||||
return m.group(1).replace('.', ' point ')
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
if len(parts) > 2:
|
||||
return match + ' dollars' # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
return '%s %s' % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s' % (cents, cent_unit)
|
||||
else:
|
||||
return 'zero dollars'
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
if len(parts) > 2:
|
||||
return match + ' dollars' # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
return '%s %s' % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s' % (cents, cent_unit)
|
||||
else:
|
||||
return 'zero dollars'
|
||||
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return 'two thousand'
|
||||
elif num > 2000 and num < 2010:
|
||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return 'two thousand'
|
||||
elif num > 2000 and num < 2010:
|
||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='')
|
||||
return _inflect.number_to_words(num, andword='')
|
||||
|
||||
|
||||
def normalize_numbers(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
||||
|
|
|
@ -4,8 +4,9 @@
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
from models.audio.tts.tacotron2.text import cmudict


from dlas.models.audio.tts.tacotron2.text import cmudict
_pad = '_'
_punctuation = '!\'(),.:;? '
_special = '-'

@ -15,4 +16,5 @@ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_arpabet = ['@' + s for s in cmudict.valid_symbols]

# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
symbols = [_pad] + list(_special) + list(_punctuation) + \
list(_letters) + _arpabet

@ -1,20 +1,20 @@
|
|||
from math import sqrt
|
||||
|
||||
import torch
|
||||
from munch import munchify
|
||||
from torch.autograd import Variable
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.arch_util import ConvGnSilu
|
||||
from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
|
||||
from models.audio.tts.tacotron2.layers import LinearNorm
|
||||
from models.audio.tts.tacotron2.hparams import create_hparams
|
||||
from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
|
||||
from trainer.networks import register_model
|
||||
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from utils.util import checkpoint
|
||||
import torch_intermediary as ml
|
||||
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import ConvGnSilu
|
||||
from dlas.models.audio.tts.tacotron2.hparams import create_hparams
|
||||
from dlas.models.audio.tts.tacotron2.layers import LinearNorm
|
||||
from dlas.models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from dlas.models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
|
||||
from dlas.models.diffusion.unet_diffusion import AttentionPool2d, UNetModel
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
class WavDecoder(nn.Module):
|
||||
|
@ -24,26 +24,33 @@ class WavDecoder(nn.Module):
|
|||
self.K = int(sample_rate * (K_ms/1000))
|
||||
self.clarifier = UNetModel(image_size=self.K,
|
||||
in_channels=1,
|
||||
model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model.
|
||||
# This is a requirement to enable to load the embedding produced by the decoder into the unet model.
|
||||
model_channels=dec_channels // 4,
|
||||
out_channels=2, # 2 channels: eps_pred and variance_pred
|
||||
num_res_blocks=2,
|
||||
attention_resolutions=(8,),
|
||||
dims=1,
|
||||
dropout=.1,
|
||||
channel_mult=(1,2,4,8),
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
use_raw_y_as_embedding=True)
|
||||
assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
|
||||
self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d),
|
||||
ConvGnSilu(32,64,kernel_size=5,stride=4,convnd=nn.Conv1d),
|
||||
ConvGnSilu(64,128,kernel_size=5,stride=4,convnd=nn.Conv1d),
|
||||
ConvGnSilu(128,256,kernel_size=5,stride=4,convnd=nn.Conv1d),
|
||||
ConvGnSilu(256,dec_channels,kernel_size=1,convnd=nn.Conv1d),
|
||||
AttentionPool2d(self.K//64,dec_channels,dec_channels//4))
|
||||
self.pre_rnn = nn.Sequential(ConvGnSilu(1, 32, kernel_size=5, convnd=nn.Conv1d),
|
||||
ConvGnSilu(32, 64, kernel_size=5,
|
||||
stride=4, convnd=nn.Conv1d),
|
||||
ConvGnSilu(64, 128, kernel_size=5,
|
||||
stride=4, convnd=nn.Conv1d),
|
||||
ConvGnSilu(128, 256, kernel_size=5,
|
||||
stride=4, convnd=nn.Conv1d),
|
||||
ConvGnSilu(256, dec_channels,
|
||||
kernel_size=1, convnd=nn.Conv1d),
|
||||
AttentionPool2d(self.K//64, dec_channels, dec_channels//4))
|
||||
self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
|
||||
self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
|
||||
self.attention_layer = Attention(
|
||||
dec_channels, dec_channels, dec_channels)
|
||||
self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
|
||||
self.linear_projection = LinearNorm(dec_channels*2, self.dec_channels)
|
||||
self.gate_layer = LinearNorm(self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid')
|
||||
self.gate_layer = LinearNorm(
|
||||
self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid')
|
||||
self.dropout_probability = dropout_probability
|
||||
|
||||
def chunk_wav(self, wav):
|
||||
|
@ -51,16 +58,18 @@ class WavDecoder(nn.Module):
|
|||
# Pad the last chunk as needed.
|
||||
padding_needed = self.K - wavs[-1].shape[-1]
|
||||
if padding_needed > 0:
|
||||
wavs[-1] = F.pad(wavs[-1], (0,padding_needed))
|
||||
wavs[-1] = F.pad(wavs[-1], (0, padding_needed))
|
||||
|
||||
wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length
|
||||
# wavs.shape = (b,s,K) where s=decoder sequence length
|
||||
wavs = torch.stack(wavs, dim=1)
|
||||
return wavs, padding_needed
|
||||
|
||||
|
||||
def prepare_decoder_inputs(self, inp):
|
||||
# inp.shape = (b,s,K) chunked waveform.
|
||||
b,s,K = inp.shape
|
||||
first_frame = torch.zeros(b,1,K).to(inp.device)
|
||||
x = torch.cat([first_frame, inp[:,:-1]], dim=1) # It is now aligned for teacher forcing.
|
||||
b, s, K = inp.shape
|
||||
first_frame = torch.zeros(b, 1, K).to(inp.device)
|
||||
# It is now aligned for teacher forcing.
|
||||
x = torch.cat([first_frame, inp[:, :-1]], dim=1)
|
||||
return x
|
||||
|
||||
def initialize_decoder_states(self, memory, mask):
|
||||
|
@ -75,18 +84,25 @@ class WavDecoder(nn.Module):
|
|||
B = memory.size(0)
|
||||
MAX_TIME = memory.size(1)
|
||||
|
||||
self.attention_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
self.attention_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
self.attention_hidden = Variable(
|
||||
memory.data.new(B, self.dec_channels).zero_())
|
||||
self.attention_cell = Variable(
|
||||
memory.data.new(B, self.dec_channels).zero_())
|
||||
|
||||
self.decoder_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
self.decoder_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
self.decoder_hidden = Variable(
|
||||
memory.data.new(B, self.dec_channels).zero_())
|
||||
self.decoder_cell = Variable(
|
||||
memory.data.new(B, self.dec_channels).zero_())
|
||||
|
||||
self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
|
||||
self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
|
||||
self.attention_context = Variable(memory.data.new(B, self.dec_channels).zero_())
|
||||
self.attention_weights_cum = Variable(
|
||||
memory.data.new(B, MAX_TIME).zero_())
|
||||
self.attention_context = Variable(
|
||||
memory.data.new(B, self.dec_channels).zero_())
|
||||
|
||||
self.memory = memory
|
||||
self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory)
|
||||
self.processed_memory = checkpoint(
|
||||
self.attention_layer.memory_layer, memory)
|
||||
self.mask = mask
|
||||
|
||||
def teardown_states(self):
|
||||
|
@ -113,20 +129,28 @@ class WavDecoder(nn.Module):
|
|||
attention_weights:
|
||||
"""
|
||||
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
||||
self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training)
|
||||
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
||||
cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_hidden = F.dropout(
|
||||
self.attention_hidden, self.dropout_probability, self.training)
|
||||
|
||||
attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
|
||||
attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(
|
||||
1), self.attention_weights_cum.unsqueeze(1)), dim=1)
|
||||
self.attention_context, self.attention_weights = checkpoint(self.attention_layer, self.attention_hidden, self.memory,
|
||||
self.processed_memory, attention_weights_cat, self.mask)
|
||||
|
||||
self.attention_weights_cum += self.attention_weights
|
||||
decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
|
||||
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell))
|
||||
self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training)
|
||||
decoder_input = torch.cat(
|
||||
(self.attention_hidden, self.attention_context), -1)
|
||||
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
||||
decoder_input, (self.decoder_hidden, self.decoder_cell))
|
||||
self.decoder_hidden = F.dropout(
|
||||
self.decoder_hidden, self.dropout_probability, self.training)
|
||||
|
||||
decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
|
||||
decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context)
|
||||
decoder_hidden_attention_context = torch.cat(
|
||||
(self.decoder_hidden, self.attention_context), dim=1)
|
||||
decoder_output = checkpoint(
|
||||
self.linear_projection, decoder_hidden_attention_context)
|
||||
|
||||
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
|
||||
return decoder_output, gate_prediction, self.attention_weights
|
||||
|
@ -137,11 +161,11 @@ class WavDecoder(nn.Module):
|
|||
# (T_out, B) -> (B, T_out)
|
||||
gate_outputs = torch.stack(gate_outputs, dim=1).repeat(1, self.K)
|
||||
|
||||
b,s,_,K = diffusion_eps.shape
|
||||
b, s, _, K = diffusion_eps.shape
|
||||
# (B, S, 2, K) -> (B, 2, S*K)
|
||||
diffusion_eps = diffusion_eps.permute(0,2,1,3).reshape(b, 2, s*K)
|
||||
diffusion_eps = diffusion_eps.permute(0, 2, 1, 3).reshape(b, 2, s*K)
|
||||
|
||||
return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]
|
||||
return diffusion_eps[:, :, :-padding_added], gate_outputs[:, :-padding_added], alignments[:, :-padding_added]
|
||||
|
||||
def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths):
|
||||
'''
|
||||
|
@ -155,14 +179,17 @@ class WavDecoder(nn.Module):
|
|||
wav_noised, padding_added = self.chunk_wav(wav_noised)
|
||||
wav_real, _ = self.chunk_wav(wav_real)
|
||||
wav_real = self.prepare_decoder_inputs(wav_real)
|
||||
b,s,K = wav_real.shape
|
||||
wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b*s,1,K)).reshape(b,s,self.dec_channels)
|
||||
b, s, K = wav_real.shape
|
||||
wav_real = checkpoint(self.pre_rnn, wav_real.reshape(
|
||||
b*s, 1, K)).reshape(b, s, self.dec_channels)
|
||||
|
||||
self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
|
||||
self.initialize_decoder_states(
|
||||
text_enc, mask=~get_mask_from_lengths(memory_lengths))
|
||||
decoder_contexts, gate_outputs, alignments = [], [], []
|
||||
while len(decoder_contexts) < wav_real.size(1):
|
||||
decoder_input = wav_real[:, len(decoder_contexts)]
|
||||
dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
|
||||
dec_context, gate_output, attention_weights = self.produce_context(
|
||||
decoder_input)
|
||||
decoder_contexts += [dec_context.squeeze(1)]
|
||||
gate_outputs += [gate_output.squeeze(1)]
|
||||
alignments += [attention_weights]
|
||||
|
@ -170,12 +197,14 @@ class WavDecoder(nn.Module):
|
|||
|
||||
# diffusion_inputs and wavs needs to have the sequence and batch dimensions combined, and needs a channel dimension
|
||||
diffusion_emb = torch.stack(decoder_contexts, dim=1)
|
||||
b,s,c = diffusion_emb.shape
|
||||
diffusion_emb = diffusion_emb.reshape(b*s,c)
|
||||
wav_noised = wav_noised.reshape(b*s,1,self.K)
|
||||
diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
|
||||
b, s, c = diffusion_emb.shape
|
||||
diffusion_emb = diffusion_emb.reshape(b*s, c)
|
||||
wav_noised = wav_noised.reshape(b*s, 1, self.K)
|
||||
diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(
|
||||
s), diffusion_emb).reshape(b, s, 2, self.K)
|
||||
# Recombine diffusion outputs across the sequence into a single prediction.
|
||||
diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
|
||||
diffusion_eps, gate_outputs, alignments = self.recombine(
|
||||
diffusion_eps, gate_outputs, alignments, padding_added)
|
||||
return diffusion_eps, gate_outputs, alignments
|
||||
|
||||
|
||||
|
@@ -199,11 +228,11 @@ class WaveTacotron2(nn.Module):
if self.mask_padding and output_lengths is not None:
mask_fill = outputs[0].shape[-1]
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
mask = mask.unsqueeze(1).repeat(1,2,1)
mask = mask.unsqueeze(1).repeat(1, 2, 1)

outputs[0].data.masked_fill_(mask, 0.0)
outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.
outputs[1].data.masked_fill_(mask[:,0], 1e3) # gate energies
outputs[1].data.masked_fill_(mask[:, 0], 1e3) # gate energies

return outputs
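The padding mask above zeroes predictions past each utterance's true length and pins the gate energies high there so the stop token fires in padding. A sketch of a get_mask_from_lengths-style helper (assumed semantics: True at valid positions; the repo's helper may differ) and the fill pattern:

import torch

def get_mask_from_lengths(lengths, max_len):
    # True where the position is inside the sequence, False in the padding.
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

lengths = torch.tensor([4, 2])
pad_mask = ~get_mask_from_lengths(lengths, 6)         # True in padded region
wav = torch.randn(2, 2, 6)
wav.masked_fill_(pad_mask.unsqueeze(1).repeat(1, 2, 1), 0.0)
gates = torch.randn(2, 6)
gates.masked_fill_(pad_mask, 1e3)                     # force "stop" in padding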
|
||||
|
||||
|
@@ -214,7 +243,8 @@ class WaveTacotron2(nn.Module):
|
|||
|
||||
text_lengths, output_lengths = text_lengths.data, output_lengths.data
|
||||
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
|
||||
encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths)
|
||||
encoder_outputs = checkpoint(
|
||||
self.encoder, embedded_inputs, text_lengths)
|
||||
eps_pred, gate_outputs, alignments = self.decoder(
|
||||
wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths)
|
||||
|
||||
|
@@ -234,7 +264,7 @@ if __name__ == '__main__':
|
|||
out = tron(wavs_diffused=torch.randn(2, 1, 22000),
|
||||
wavs_corrected=torch.randn(2, 1, 22000),
|
||||
timesteps=torch.LongTensor([555, 543]),
|
||||
text_inputs=torch.randint(high=24, size=(2,12)),
|
||||
text_inputs=torch.randint(high=24, size=(2, 12)),
|
||||
text_lengths=torch.tensor([12, 12]),
|
||||
output_lengths=torch.tensor([21995]))
|
||||
print([o.shape for o in out])
|
||||
print([o.shape for o in out])
|
||||
|
|
|
@@ -23,11 +23,13 @@ Returns:
|
|||
import functools
|
||||
import random
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_intermediary as ml
|
||||
from tqdm import tqdm
|
||||
|
||||
import dlas.torch_intermediary as ml
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
@@ -61,13 +63,13 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
|||
"""
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
gpt_config = GPT2Config(vocab_size=256, # Unused.
|
||||
n_positions=max_mel_seq_len+max_text_seq_len,
|
||||
n_ctx=max_mel_seq_len+max_text_seq_len,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing)
|
||||
n_positions=max_mel_seq_len+max_text_seq_len,
|
||||
n_ctx=max_mel_seq_len+max_text_seq_len,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing)
|
||||
gpt = GPT2Model(gpt_config)
|
||||
# Override the built in positional embeddings
|
||||
del gpt.wpe
|
||||
|
@@ -75,8 +77,10 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
|||
# Built-in token embeddings are unused.
|
||||
del gpt.wte
|
||||
|
||||
mel_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
|
||||
text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
|
||||
mel_pos_emb = LearnedPositionEmbeddings(
|
||||
max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
|
||||
text_pos_emb = LearnedPositionEmbeddings(
|
||||
max_text_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
|
||||
return gpt, mel_pos_emb, text_pos_emb, None, None
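When a sequence type has no fixed maximum length (the seq_len argument is -1), the builder above substitutes a null positional embedding that returns zeros, so positional information has to come from elsewhere. A minimal sketch reusing the null_position_embeddings helper defined earlier in this file:

import functools
import torch

def null_position_embeddings(range, dim):
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)

emb_fn = functools.partial(null_position_embeddings, dim=512)
tokens = torch.zeros(2, 40, dtype=torch.long)   # (batch, sequence) of ids
pos = emb_fn(tokens)
assert pos.shape == (2, 40, 512) and pos.abs().sum() == 0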
|
||||
|
||||
|
||||
|
@@ -85,7 +89,8 @@ def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_l
|
|||
lucidrains Performer implementation, https://github.com/lucidrains/performer-pytorch
|
||||
"""
|
||||
from models.lucidrains.performer.performer_pytorch import Performer
|
||||
model = Performer(dim=model_dim, depth=layers, heads=heads, dim_head=model_dim, causal=True)
|
||||
model = Performer(dim=model_dim, depth=layers,
|
||||
heads=heads, dim_head=model_dim, causal=True)
|
||||
return model
|
||||
|
||||
|
||||
|
@@ -104,14 +109,14 @@ def build_lr_xformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len
|
|||
|
||||
|
||||
def test_all_performance(**kwargs):
|
||||
transformer_builders = [#build_hf_gpt_transformer,
|
||||
build_lr_performer,]
|
||||
# build_lr_reformer,
|
||||
# build_lr_xformer]
|
||||
transformer_builders = [ # build_hf_gpt_transformer,
|
||||
build_lr_performer,]
|
||||
# build_lr_reformer,
|
||||
# build_lr_xformer]
|
||||
for builder in transformer_builders:
|
||||
model = builder(**kwargs)
|
||||
start = time()
|
||||
args = torch.randint(0, 8192, (16,450))
|
||||
args = torch.randint(0, 8192, (16, 450))
|
||||
for k in tqdm(range(10)):
|
||||
model(args)
|
||||
stop = time()
|
||||
|
@@ -119,4 +124,5 @@ def test_all_performance(**kwargs):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_all_performance(layers=12, model_dim=512, heads=8, num_tokens=8192, max_seq_len=1000, checkpointing=False)
|
||||
test_all_performance(layers=12, model_dim=512, heads=8,
|
||||
num_tokens=8192, max_seq_len=1000, checkpointing=False)
|
||||
|
|
|
@@ -2,17 +2,23 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import torch_intermediary as ml
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (TimestepBlock,
|
||||
TimestepEmbedSequential)
|
||||
from dlas.models.lucidrains.x_transformers import (Attention, Encoder,
|
||||
FeedForward,
|
||||
RMSScaleShiftNorm,
|
||||
RotaryEmbedding)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
return t.dtype == torch.float
|
||||
|
||||
|
||||
def is_sequence(t):
|
||||
return t.dtype == torch.long
|
||||
|
||||
|
@@ -21,7 +27,8 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
# nn.Embedding
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
self.m = nn.ModuleList(
[ml.Embedding(tokens, dim // groups) for _ in range(groups)])

def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
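MultiGroupEmbedding splits the embedding width across `groups` separate vocabularies and looks each group's code up independently; the per-group vectors are then (presumably) concatenated on the channel dimension. A stock nn.Embedding sketch of that idea (the repo uses its ml.Embedding shim instead):

import torch
import torch.nn as nn

class GroupedEmbedding(nn.Module):
    # Hypothetical stand-in for MultiGroupEmbedding using plain nn.Embedding.
    def __init__(self, tokens, groups, dim):
        super().__init__()
        self.m = nn.ModuleList(
            [nn.Embedding(tokens, dim // groups) for _ in range(groups)])

    def forward(self, x):                       # x: (B, T, groups) of token ids
        h = [emb(x[:, :, i]) for i, emb in enumerate(self.m)]
        return torch.cat(h, dim=-1)             # (B, T, dim)

emb = GroupedEmbedding(tokens=8, groups=8, dim=512)
codes = torch.randint(0, 8, (2, 100, 8))
assert emb(codes).shape == (2, 100, 512)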
@@ -41,13 +48,16 @@ class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
|
|||
class AttentionBlock(TimestepBlock):
|
||||
def __init__(self, dim, heads, dropout):
|
||||
super().__init__()
|
||||
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False)
|
||||
self.ff = FeedForward(dim, mult=1, dropout=dropout, zero_init_output=True)
|
||||
self.attn = Attention(dim, heads=heads, causal=False,
|
||||
dropout=dropout, zero_init_output=False)
|
||||
self.ff = FeedForward(
|
||||
dim, mult=1, dropout=dropout, zero_init_output=True)
|
||||
self.rms_scale_norm = RMSScaleShiftNorm(dim)
|
||||
|
||||
def forward(self, x, timestep_emb, rotary_emb):
|
||||
h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb)
|
||||
h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
|
||||
h, _, _, _ = checkpoint(self.attn, h, None, None,
|
||||
None, None, None, rotary_emb)
|
||||
h = checkpoint(self.ff, h)
|
||||
return h + x
|
||||
|
||||
|
@@ -56,6 +66,7 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
"""
|
||||
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_channels=512,
|
||||
|
@@ -71,7 +82,8 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
dropout=0,
|
||||
use_fp16=False,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@@ -91,17 +103,17 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
linear(model_channels, model_channels),
|
||||
)
|
||||
self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels // 2, 3, padding=1, stride=2),
|
||||
nn.Conv1d(model_channels//2, model_channels,3,padding=1,stride=2))
|
||||
nn.Conv1d(model_channels//2, model_channels, 3, padding=1, stride=2))
|
||||
self.conditioning_encoder = Encoder(
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
heads=heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
heads=heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels)
|
||||
# nn.Embedding
|
||||
self.type_embedding = ml.Embedding(types, model_channels)
|
||||
|
@@ -114,43 +126,48 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
# nn.Embedding
|
||||
self.embeddings = ml.Embedding(token_count, model_channels)
|
||||
else:
|
||||
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
|
||||
self.embeddings = MultiGroupEmbedding(
|
||||
token_count, in_groups, model_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
||||
Encoder(
|
||||
dim=model_channels,
|
||||
depth=2,
|
||||
heads=heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
dim=model_channels,
|
||||
depth=2,
|
||||
heads=heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
)
|
||||
self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels))
|
||||
self.latent_fade = nn.Parameter(torch.zeros(1, 1, model_channels))
|
||||
self.code_converter = Encoder(
|
||||
dim=model_channels,
|
||||
depth=3,
|
||||
heads=heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
dim=model_channels,
|
||||
depth=3,
|
||||
heads=heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
||||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, model_channels))
|
||||
self.mel_head = nn.Conv1d(
|
||||
model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
|
||||
self.intg = ml.Linear(model_channels*2, model_channels)
|
||||
self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
|
||||
self.layers = TimestepRotaryEmbedSequential(
|
||||
*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(conv_nd(1, model_channels,
|
||||
out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
self.debug_codes = {}
|
||||
|
@@ -165,7 +182,8 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
return groups
|
||||
|
||||
def timestep_independent(self, codes, conditioning_input, expected_seq_len, prenet_latent=None, return_code_pred=False):
|
||||
cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
|
||||
cond_emb = self.conditioning_embedder(
|
||||
conditioning_input).permute(0, 2, 1)
|
||||
cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
|
||||
|
||||
code_emb = self.embeddings(codes)
|
||||
|
@@ -173,7 +191,8 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
latent_conditioning = self.latent_conditioner(prenet_latent)
|
||||
code_emb = code_emb + latent_conditioning * self.latent_fade
|
||||
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
unconditioned_batches = torch.zeros(
|
||||
(code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
|
@@ -182,57 +201,65 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
code_emb)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(
|
||||
0, 2, 1), size=expected_seq_len, mode='nearest').permute(0, 2, 1)
|
||||
if not return_code_pred:
return expanded_code_emb, cond_emb
else:
# Perform the mel_head computation on the pre-expanded code embeddings, then interpolate it separately.
mel_pred = self.mel_head(code_emb.permute(0,2,1))
mel_pred = F.interpolate(mel_pred, size=expected_seq_len, mode='nearest')
mel_pred = self.mel_head(code_emb.permute(0, 2, 1))
mel_pred = F.interpolate(
mel_pred, size=expected_seq_len, mode='nearest')
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches.
# This is because we don't want that gradient being used to train parameters through the codes_embedder as
# it unbalances contributions to that network from the MSE loss.
mel_pred = mel_pred * unconditioned_batches.logical_not()
return expanded_code_emb, cond_emb, mel_pred
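The logical_not() multiply above zeroes the mel prediction for batch elements whose conditioning was dropped, so their MSE term contributes no gradient back through the code embedder. A toy sketch of that pattern, with made-up shapes:

import torch

mel_pred = torch.randn(4, 100, 80, requires_grad=True)
# True for elements whose conditioning was masked out this step.
unconditioned = torch.tensor([True, False, False, True]).view(4, 1, 1)
masked_pred = mel_pred * unconditioned.logical_not()
loss = (masked_pred ** 2).mean()
loss.backward()
# Gradients for the unconditioned rows are exactly zero.
assert mel_pred.grad[0].abs().sum() == 0 and mel_pred.grad[1].abs().sum() > 0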
|
||||
|
||||
|
||||
def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None,
|
||||
precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
if precomputed_code_embeddings is not None:
|
||||
assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
|
||||
assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
|
||||
assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
|
||||
assert not (
|
||||
return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
|
||||
assert type is not None, "Type is required."
|
||||
|
||||
unused_params = []
|
||||
if not return_code_pred:
|
||||
unused_params.extend(list(self.mel_head.parameters()))
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
code_emb = self.unconditioned_embedding.repeat(
|
||||
x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(
|
||||
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
else:
|
||||
if precomputed_code_embeddings is not None:
|
||||
code_emb = precomputed_code_embeddings
|
||||
cond_emb = precomputed_cond_embeddings
|
||||
else:
|
||||
code_emb, cond_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], prenet_latent, True)
|
||||
code_emb, cond_emb, mel_pred = self.timestep_independent(
|
||||
codes, conditioning_input, x.shape[-1], prenet_latent, True)
|
||||
if prenet_latent is None:
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()) + [self.latent_fade])
|
||||
unused_params.extend(
|
||||
list(self.latent_conditioner.parameters()) + [self.latent_fade])
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
|
||||
clvp_emb = torch.zeros_like(
|
||||
cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
|
||||
type_emb = self.type_embedding(type)
|
||||
if clvp_input is None:
|
||||
unused_params.extend(self.clvp_encoder.parameters())
|
||||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb
|
||||
x = self.inp_block(x).permute(0,2,1)
|
||||
blk_emb = self.time_embed(timestep_embedding(
|
||||
timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb
|
||||
x = self.inp_block(x).permute(0, 2, 1)
|
||||
|
||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
||||
x = self.intg(torch.cat([x, code_emb], dim=-1))
|
||||
x = self.layers(x, blk_emb, rotary_pos_emb)
|
||||
|
||||
x = x.float().permute(0,2,1)
|
||||
x = x.float().permute(0, 2, 1)
|
||||
out = self.out(x)
|
||||
|
||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||
|
@@ -253,13 +280,14 @@ def register_transformer_diffusion_tts(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 256, 400)
|
||||
aligned_latent = torch.randn(2,100,512)
|
||||
aligned_sequence = torch.randint(0,8,(2,100,8))
|
||||
aligned_latent = torch.randn(2, 100, 512)
|
||||
aligned_sequence = torch.randint(0, 8, (2, 100, 8))
|
||||
cond = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
clvp = torch.randn(2,768)
|
||||
type = torch.LongTensor([0,1])
|
||||
model = TransformerDiffusionTTS(512, unconditioned_percentage=.5, in_groups=8)
|
||||
o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type, return_code_pred=True)
|
||||
#o = model(clip, ts, aligned_sequence, cond, aligned_latent)
|
||||
|
||||
clvp = torch.randn(2, 768)
|
||||
type = torch.LongTensor([0, 1])
|
||||
model = TransformerDiffusionTTS(
|
||||
512, unconditioned_percentage=.5, in_groups=8)
|
||||
o = model(clip, ts, aligned_sequence, cond,
|
||||
clvp_input=clvp, type=type, return_code_pred=True)
|
||||
# o = model(clip, ts, aligned_sequence, cond, aligned_latent)
|
||||
|
|
|
@@ -1,18 +1,24 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
|
||||
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint, print_network
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (TimestepBlock,
|
||||
TimestepEmbedSequential)
|
||||
from dlas.models.lucidrains.x_transformers import (Attention, Encoder,
|
||||
FeedForward,
|
||||
RMSScaleShiftNorm,
|
||||
RotaryEmbedding)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint, print_network
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
return t.dtype == torch.float
|
||||
|
||||
|
||||
def is_sequence(t):
|
||||
return t.dtype == torch.long
|
||||
|
||||
|
@@ -21,7 +27,8 @@ class MultiGroupEmbedding(nn.Module):
|
|||
def __init__(self, tokens, groups, dim):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
self.m = nn.ModuleList(
|
||||
[ml.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
|
||||
def forward(self, x):
|
||||
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||
|
@@ -44,12 +51,14 @@ class DietAttentionBlock(TimestepBlock):
|
|||
self.rms_scale_norm = RMSScaleShiftNorm(in_dim)
|
||||
self.proj = ml.Linear(in_dim, dim)
|
||||
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout)
|
||||
self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
|
||||
self.ff = FeedForward(dim, in_dim, mult=1,
|
||||
dropout=dropout, zero_init_output=True)
|
||||
|
||||
def forward(self, x, timestep_emb, rotary_emb):
|
||||
h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb)
|
||||
h = self.proj(h)
|
||||
h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
|
||||
h, _, _, _ = checkpoint(self.attn, h, None, None,
|
||||
None, None, None, rotary_emb)
|
||||
h = checkpoint(self.ff, h)
|
||||
return h + x
|
||||
|
||||
|
@@ -58,6 +67,7 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
"""
|
||||
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prenet_channels=256,
|
||||
|
@@ -75,7 +85,8 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
dropout=0,
|
||||
use_fp16=False,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@@ -96,17 +107,17 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
)
|
||||
prenet_heads = prenet_channels//64
|
||||
self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, prenet_channels // 2, 3, padding=1, stride=2),
|
||||
nn.Conv1d(prenet_channels//2, prenet_channels,3,padding=1,stride=2))
|
||||
nn.Conv1d(prenet_channels//2, prenet_channels, 3, padding=1, stride=2))
|
||||
self.conditioning_encoder = Encoder(
|
||||
dim=prenet_channels,
|
||||
depth=4,
|
||||
heads=prenet_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
dim=prenet_channels,
|
||||
depth=4,
|
||||
heads=prenet_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels)
|
||||
# nn.Embedding
|
||||
self.type_embedding = ml.Embedding(types, prenet_channels)
|
||||
|
@@ -119,45 +130,48 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
# nn.Embedding
|
||||
self.embeddings = ml.Embedding(token_count, prenet_channels)
|
||||
else:
|
||||
self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels)
|
||||
self.embeddings = MultiGroupEmbedding(
|
||||
token_count, in_groups, prenet_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
nn.Conv1d(in_latent_channels, prenet_channels, 3, padding=1),
|
||||
Encoder(
|
||||
dim=prenet_channels,
|
||||
depth=2,
|
||||
heads=prenet_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
dim=prenet_channels,
|
||||
depth=2,
|
||||
heads=prenet_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
)
|
||||
self.latent_fade = nn.Parameter(torch.zeros(1,1,prenet_channels))
|
||||
self.latent_fade = nn.Parameter(torch.zeros(1, 1, prenet_channels))
|
||||
self.code_converter = Encoder(
|
||||
dim=prenet_channels,
|
||||
depth=3,
|
||||
heads=prenet_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
dim=prenet_channels,
|
||||
depth=3,
|
||||
heads=prenet_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, prenet_channels))
|
||||
|
||||
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
|
||||
self.cond_intg = ml.Linear(prenet_channels*4, model_channels)
|
||||
self.intg = ml.Linear(prenet_channels*2, model_channels)
|
||||
|
||||
self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.layers = TimestepRotaryEmbedSequential(
|
||||
*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(conv_nd(1, model_channels,
|
||||
out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
self.debug_codes = {}
|
||||
|
@@ -172,7 +186,8 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
return groups
|
||||
|
||||
def timestep_independent(self, codes, conditioning_input, expected_seq_len, prenet_latent=None):
|
||||
cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
|
||||
cond_emb = self.conditioning_embedder(
|
||||
conditioning_input).permute(0, 2, 1)
|
||||
cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
|
||||
|
||||
code_emb = self.embeddings(codes)
|
||||
|
@@ -188,11 +203,11 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
code_emb)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
expanded_code_emb = F.interpolate(code_emb.permute(
0, 2, 1), size=expected_seq_len, mode='nearest').permute(0, 2, 1)

return expanded_code_emb, cond_emb
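F.interpolate operates on channels-first tensors, hence the permute sandwich above when stretching per-token code embeddings to the diffusion sequence length. A shape-only sketch:

import torch
import torch.nn.functional as F

code_emb = torch.randn(2, 50, 512)            # (B, tokens, C)
expected_seq_len = 400
expanded = F.interpolate(code_emb.permute(0, 2, 1), size=expected_seq_len,
                         mode='nearest').permute(0, 2, 1)
assert expanded.shape == (2, expected_seq_len, 512)
# Each token's embedding is simply repeated to cover its share of the sequence.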
|
||||
|
||||
|
||||
def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None,
|
||||
precomputed_cond_embeddings=None, conditioning_free=False):
|
||||
if precomputed_code_embeddings is not None:
|
||||
|
@@ -202,32 +217,38 @@ class TransformerDiffusionTTS(nn.Module):
|
|||
|
||||
unused_params = []
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
code_emb = self.unconditioned_embedding.repeat(
|
||||
x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(
|
||||
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
else:
|
||||
if precomputed_code_embeddings is not None:
|
||||
code_emb = precomputed_code_embeddings
|
||||
cond_emb = precomputed_cond_embeddings
|
||||
else:
|
||||
code_emb, cond_emb = self.timestep_independent(codes, conditioning_input, x.shape[-1], prenet_latent)
|
||||
code_emb, cond_emb = self.timestep_independent(
|
||||
codes, conditioning_input, x.shape[-1], prenet_latent)
|
||||
if prenet_latent is None:
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()) + [self.latent_fade])
|
||||
unused_params.extend(
|
||||
list(self.latent_conditioner.parameters()) + [self.latent_fade])
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
|
||||
clvp_emb = torch.zeros_like(
|
||||
cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
|
||||
type_emb = self.type_embedding(type)
|
||||
if clvp_input is None:
|
||||
unused_params.extend(self.clvp_encoder.parameters())
|
||||
blk_emb = torch.cat([self.time_embed(timestep_embedding(timesteps, self.prenet_channels)), cond_emb, clvp_emb, type_emb], dim=-1)
|
||||
blk_emb = torch.cat([self.time_embed(timestep_embedding(
|
||||
timesteps, self.prenet_channels)), cond_emb, clvp_emb, type_emb], dim=-1)
|
||||
blk_emb = self.cond_intg(blk_emb)
|
||||
x = self.inp_block(x).permute(0,2,1)
|
||||
x = self.inp_block(x).permute(0, 2, 1)
|
||||
|
||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
||||
x = self.intg(torch.cat([x, code_emb], dim=-1))
|
||||
x = self.layers(x, blk_emb, rotary_pos_emb)
|
||||
|
||||
x = x.float().permute(0,2,1)
|
||||
x = x.float().permute(0, 2, 1)
|
||||
out = self.out(x)
|
||||
|
||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||
|
@@ -246,15 +267,15 @@ def register_transformer_diffusion_tts2(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 256, 400)
|
||||
aligned_latent = torch.randn(2,100,512)
|
||||
aligned_sequence = torch.randint(0,8,(2,100,8))
|
||||
aligned_latent = torch.randn(2, 100, 512)
|
||||
aligned_sequence = torch.randint(0, 8, (2, 100, 8))
|
||||
cond = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
clvp = torch.randn(2,768)
|
||||
type = torch.LongTensor([0,1])
|
||||
model = TransformerDiffusionTTS(model_channels=3072, num_layers=16, unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024)
|
||||
clvp = torch.randn(2, 768)
|
||||
type = torch.LongTensor([0, 1])
|
||||
model = TransformerDiffusionTTS(model_channels=3072, num_layers=16,
|
||||
unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024)
|
||||
print_network(model)
|
||||
o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type)
|
||||
torch.save(model.state_dict(), 'test.pth')
|
||||
#o = model(clip, ts, aligned_sequence, cond, aligned_latent)
|
||||
|
||||
# o = model(clip, ts, aligned_sequence, cond, aligned_latent)
|
||||
|
|
|
@@ -5,16 +5,19 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
import torch_intermediary as ml
|
||||
from x_transformers import ContinuousTransformerWrapper, Encoder
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
|
||||
Downsample, Upsample, TimestepBlock
|
||||
from models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
from x_transformers import Encoder, ContinuousTransformerWrapper
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
|
||||
TimestepBlock,
|
||||
TimestepEmbedSequential,
|
||||
Upsample)
|
||||
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3, inverted=False):
|
||||
|
@@ -567,4 +570,3 @@ if __name__ == '__main__':
|
|||
o.sum().backward()
|
||||
model.before_step(0)
|
||||
torch.save(model.state_dict(), 'test_out.pth')
|
||||
|
||||
|
|
|
@@ -5,21 +5,26 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
from x_transformers import Encoder
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
|
||||
Downsample, Upsample, TimestepBlock
|
||||
from models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
|
||||
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from dlas.models.audio.tts.unet_diffusion_tts7 import \
|
||||
CheckpointedXTransformerEncoder
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
|
||||
TimestepBlock,
|
||||
TimestepEmbedSequential,
|
||||
Upsample)
|
||||
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
return t.dtype == torch.float
|
||||
|
||||
|
||||
def is_sequence(t):
|
||||
return t.dtype == torch.long
|
||||
|
||||
|
@@ -49,7 +54,8 @@ class ResBlock(TimestepBlock):
|
|||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
|
||||
conv_nd(dims, channels, self.out_channels,
|
||||
eff_kernel, padding=eff_padding),
|
||||
)
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
|
@@ -64,14 +70,16 @@ class ResBlock(TimestepBlock):
|
|||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, self.out_channels, self.out_channels,
|
||||
kernel_size, padding=padding)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
|
@@ -100,6 +108,7 @@ class ResBlock(TimestepBlock):
|
|||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class DiffusionTts(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
@@ -143,12 +152,12 @@ class DiffusionTts(nn.Module):
|
|||
out_channels=2, # mean and variance
|
||||
dropout=0,
|
||||
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
||||
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
||||
channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
||||
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
|
||||
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
|
||||
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
|
||||
token_conditioning_resolutions=(1,16,),
|
||||
attention_resolutions=(512,1024,2048),
|
||||
token_conditioning_resolutions=(1, 16,),
|
||||
attention_resolutions=(512, 1024, 2048),
|
||||
conv_resample=True,
|
||||
dims=1,
|
||||
use_fp16=False,
|
||||
|
@@ -159,10 +168,12 @@ class DiffusionTts(nn.Module):
|
|||
scale_factor=2,
|
||||
time_embed_dim_multiplier=4,
|
||||
freeze_main_net=False,
|
||||
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
|
||||
# Uses kernels with width of 1 in several places rather than 3.
|
||||
efficient_convs=True,
|
||||
use_scale_shift_norm=True,
|
||||
# Parameters for regularization.
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
# Parameters for super-sampling.
|
||||
super_sampling=False,
|
||||
super_sampling_max_noising_factor=.1,
|
||||
|
@@ -173,7 +184,8 @@ class DiffusionTts(nn.Module):
|
|||
num_heads_upsample = num_heads
|
||||
|
||||
if super_sampling:
|
||||
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
|
||||
# In super-sampling mode, the LR input is concatenated directly onto the input.
|
||||
in_channels *= 2
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
|
@@ -224,10 +236,12 @@ class DiffusionTts(nn.Module):
|
|||
rotary_pos_emb=True,
|
||||
)
|
||||
))
|
||||
self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
|
||||
self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
|
||||
self.latent_converter = nn.Conv1d(
|
||||
in_latent_channels, conditioning_dim, 1)
|
||||
self.aligned_latent_padding_embedding = nn.Parameter(
|
||||
torch.randn(1, in_latent_channels, 1))
|
||||
if in_channels > 60: # It's a spectrogram.
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, conditioning_dim, 3, padding=1, stride=2),
|
||||
CheckpointedXTransformerEncoder(
|
||||
needs_permute=True,
|
||||
max_seq_len=-1,
|
||||
|
@ -242,25 +256,33 @@ class DiffusionTts(nn.Module):
|
|||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
)
|
||||
))
|
||||
))
|
||||
else:
|
||||
self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
|
||||
attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
|
||||
self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
|
||||
self.conditioning_conv = nn.Conv1d(
|
||||
conditioning_dim*2, conditioning_dim, 1)
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, conditioning_dim, 1))
|
||||
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
AttentionBlock(conditioning_dim, num_heads=num_heads,
|
||||
num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
AttentionBlock(conditioning_dim, num_heads=num_heads,
|
||||
num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
|
||||
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
)
|
||||
self.conditioning_expansion = conditioning_expansion
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, in_channels, model_channels,
|
||||
kernel_size, padding=padding)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@@ -371,7 +393,8 @@ class DiffusionTts(nn.Module):
|
|||
if level and i == num_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
|
||||
Upsample(ch, conv_resample, dims=dims,
|
||||
out_channels=out_ch, factor=scale_factor)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
|
@@ -380,7 +403,8 @@ class DiffusionTts(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels,
|
||||
kernel_size, padding=padding)),
|
||||
)
|
||||
|
||||
if self.freeze_main_net:
|
||||
|
@@ -410,13 +434,14 @@ class DiffusionTts(nn.Module):
|
|||
cm = ceil_multiple(x.shape[-1], self.alignment_size)
if cm != 0:
pc = (cm-x.shape[-1])/x.shape[-1]
x = F.pad(x, (0,cm-x.shape[-1]))
x = F.pad(x, (0, cm-x.shape[-1]))
# Also fix aligned_latent, which is aligned to x.
if is_latent(aligned_conditioning):
aligned_conditioning = torch.cat([aligned_conditioning,
self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
else:
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
aligned_conditioning = F.pad(
aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
return x, aligned_conditioning
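This alignment fix pads the waveform up to a multiple of the model's alignment size and pads the conditioning by the same proportion. A sketch with a hypothetical ceil_multiple (the real helper is imported from dlas.scripts.audio.gen.use_diffuse_tts; its semantics are assumed here to return 0 when no padding is needed):

import torch
import torch.nn.functional as F

def ceil_multiple(n, base):
    # Assumed semantics: smallest multiple of `base` >= n, or 0 if already exact.
    m = (n + base - 1) // base * base
    return 0 if m == n else m

x = torch.randn(2, 1, 22050)                  # waveform
cond = torch.randn(2, 1024, 100)              # latent aligned to the waveform
cm = ceil_multiple(x.shape[-1], 2048)
if cm != 0:
    pc = (cm - x.shape[-1]) / x.shape[-1]     # proportional padding
    x = F.pad(x, (0, cm - x.shape[-1]))
    cond = F.pad(cond, (0, int(pc * cond.shape[-1])))
assert x.shape[-1] % 2048 == 0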
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
|
||||
|
@@ -435,9 +460,12 @@ class DiffusionTts(nn.Module):
|
|||
if self.super_sampling_enabled:
|
||||
assert lr_input is not None
|
||||
if self.training and self.super_sampling_max_noising_factor > 0:
|
||||
noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
|
||||
lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
|
||||
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
|
||||
noising_factor = random.uniform(
|
||||
0, self.super_sampling_max_noising_factor)
|
||||
lr_input = torch.randn_like(
|
||||
lr_input) * noising_factor + lr_input
|
||||
lr_input = F.interpolate(
|
||||
lr_input, size=(x.shape[-1],), mode='nearest')
|
||||
x = torch.cat([x, lr_input], dim=1)
|
||||
|
||||
# Shuffle aligned_latent to BxCxS format
|
||||
|
@@ -451,11 +479,13 @@ class DiffusionTts(nn.Module):
|
|||
with autocast(x.device.type, enabled=self.enable_fp16):
|
||||
|
||||
hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
time_emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels))
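timestep_embedding (from dlas.models.diffusion.nn) produces the standard sinusoidal encoding of the diffusion step before time_embed projects it. A minimal re-derivation following the usual DDPM-style formulation; the repo's version may differ in details such as odd-dimension handling:

import math
import torch

def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    # timesteps: (N,) integer tensor; returns (N, dim) float embeddings.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) *
                      torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_timestep_embedding(torch.LongTensor([600, 600]), 128)
assert emb.shape == (2, 128)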
|
||||
|
||||
# Note: this block does not need to repeated on inference, since it is not timestep-dependent.
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
||||
code_emb = self.unconditioned_embedding.repeat(
|
||||
x.shape[0], 1, 1)
|
||||
else:
|
||||
cond_emb = self.contextual_embedder(conditioning_input)
|
||||
if len(cond_emb.shape) == 3: # Just take the first element.
|
||||
|
@@ -464,8 +494,10 @@ class DiffusionTts(nn.Module):
|
|||
code_emb = self.latent_converter(aligned_conditioning)
|
||||
else:
|
||||
code_emb = self.code_converter(aligned_conditioning)
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
|
||||
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1,
|
||||
1, code_emb.shape[-1])
|
||||
code_emb = self.conditioning_conv(
|
||||
torch.cat([cond_emb, code_emb], dim=1))
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
|
@@ -474,15 +506,18 @@ class DiffusionTts(nn.Module):
|
|||
code_emb)
|
||||
|
||||
# Everything after this comment is timestep dependent.
|
||||
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
code_emb = torch.repeat_interleave(
|
||||
code_emb, self.conditioning_expansion, dim=-1)
|
||||
code_emb = self.conditioning_timestep_integrator(
|
||||
code_emb, time_emb)
|
||||
|
||||
first = True
|
||||
time_emb = time_emb.float()
|
||||
h = x
|
||||
for k, module in enumerate(self.input_blocks):
|
||||
if isinstance(module, nn.Conv1d):
|
||||
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
||||
h_tok = F.interpolate(module(code_emb), size=(
|
||||
h.shape[-1]), mode='nearest')
|
||||
h = h + h_tok
|
||||
else:
|
||||
with autocast(x.device.type, enabled=self.enable_fp16 and not first):
|
||||
|
@@ -501,7 +536,8 @@ class DiffusionTts(nn.Module):
|
|||
|
||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
extraneous_addition = 0
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
params = [self.aligned_latent_padding_embedding,
self.unconditioned_embedding] + list(self.latent_converter.parameters())
for p in params:
extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0
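Adding a zero-weighted mean of possibly-unused parameters keeps every parameter in the autograd graph, which stops DistributedDataParallel from complaining about parameters that received no gradient. A toy sketch of the trick with a hypothetical optional branch:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(8, 8)
        self.sometimes_unused = nn.Linear(8, 8)

    def forward(self, x):
        out = self.used(x)
        # Touch the optional branch with zero weight so DDP still sees a grad.
        extraneous = sum(p.mean() for p in self.sometimes_unused.parameters())
        return out + extraneous * 0

net = Net()
y = net(torch.randn(4, 8))
y.sum().backward()
assert net.sometimes_unused.weight.grad is not None   # grad exists (all zeros)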
|
||||
|
@@ -516,14 +552,14 @@ def register_diffusion_tts9(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 1, 32868)
|
||||
aligned_latent = torch.randn(2,388,1024)
|
||||
aligned_sequence = torch.randint(0,8192,(2,388))
|
||||
aligned_latent = torch.randn(2, 388, 1024)
|
||||
aligned_sequence = torch.randint(0, 8192, (2, 388))
|
||||
cond = torch.randn(2, 1, 44000)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = DiffusionTts(128,
|
||||
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
||||
channel_mult=[1, 1.5, 2, 3, 4, 6, 8],
|
||||
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
||||
token_conditioning_resolutions=[1,4,16,64],
|
||||
token_conditioning_resolutions=[1, 4, 16, 64],
|
||||
attention_resolutions=[],
|
||||
num_heads=8,
|
||||
kernel_size=3,
|
||||
|
@@ -535,4 +571,3 @@ if __name__ == '__main__':
|
|||
o = model(clip, ts, aligned_latent, cond)
|
||||
# Test with sequence aligned conditioning
|
||||
o = model(clip, ts, aligned_sequence, cond)
|
||||
|
||||
|
|
|
@@ -5,18 +5,22 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
import torch_intermediary as ml
|
||||
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
|
||||
from models.lucidrains.x_transformers import RelativePositionBias
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (QKVAttentionLegacy,
|
||||
TimestepBlock,
|
||||
TimestepEmbedSequential)
|
||||
from dlas.models.lucidrains.x_transformers import RelativePositionBias
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
return t.dtype == torch.float
|
||||
|
||||
|
||||
def is_sequence(t):
|
||||
return t.dtype == torch.long
|
||||
|
||||
|
@@ -54,7 +58,8 @@ class AttentionBlock(nn.Module):
|
|||
|
||||
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
|
||||
if relative_pos_embeddings:
|
||||
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
|
||||
self.relative_pos_embeddings = RelativePositionBias(scale=(
|
||||
channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
|
||||
else:
|
||||
self.relative_pos_embeddings = None
|
||||
|
||||
|
@@ -92,7 +97,8 @@ class ResBlock(TimestepBlock):
|
|||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
|
||||
conv_nd(dims, channels, self.out_channels,
|
||||
eff_kernel, padding=eff_padding),
|
||||
)
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
|
@@ -107,14 +113,16 @@ class ResBlock(TimestepBlock):
|
|||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, self.out_channels, self.out_channels,
|
||||
kernel_size, padding=padding)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
|
@@ -147,8 +155,10 @@ class DiffusionLayer(TimestepBlock):
|
|||
class DiffusionLayer(TimestepBlock):
|
||||
def __init__(self, model_channels, dropout, num_heads):
|
||||
super().__init__()
|
||||
self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
|
||||
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
||||
self.resblk = ResBlock(model_channels, model_channels, dropout,
|
||||
model_channels, dims=1, use_scale_shift_norm=True)
|
||||
self.attn = AttentionBlock(
|
||||
model_channels, num_heads, relative_pos_embeddings=True)
|
||||
|
||||
def forward(self, x, time_emb):
|
||||
y = self.resblk(x, time_emb)
|
||||
|
@@ -170,7 +180,8 @@ class DiffusionTtsFlat(nn.Module):
|
|||
freeze_everything_except_autoregressive_inputs=False,
|
||||
# Parameters for regularization.
|
||||
layer_drop=.1,
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
# This implements a mechanism similar to what is used in classifier-free training.
|
||||
unconditioned_percentage=.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@@ -198,33 +209,48 @@ class DiffusionTtsFlat(nn.Module):
|
|||
# nn.Embedding
|
||||
self.code_embedding = ml.Embedding(in_tokens, model_channels)
|
||||
self.code_converter = nn.Sequential(
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
)
|
||||
self.code_norm = normalization(model_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads,
|
||||
relative_pos_embeddings=True),
|
||||
)
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
||||
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
|
||||
nn.Conv1d(
|
||||
model_channels, model_channels*2, 3, padding=1, stride=2),
|
||||
AttentionBlock(
|
||||
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(
|
||||
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(
|
||||
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(
|
||||
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
|
||||
self.unconditioned_embedding = nn.Parameter(
|
||||
torch.randn(1, model_channels, 1))
|
||||
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
||||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
)
|
||||
self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
|
||||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
self.integrating_conv = nn.Conv1d(
|
||||
model_channels*2, model_channels, kernel_size=1)
|
||||
self.mel_head = nn.Conv1d(
|
||||
model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
|
||||
[ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
|
||||
|
@@ -232,7 +258,8 @@ class DiffusionTtsFlat(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(conv_nd(1, model_channels,
|
||||
out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
if freeze_everything_except_autoregressive_inputs:
|
||||
|
@ -263,25 +290,30 @@ class DiffusionTtsFlat(nn.Module):
|
|||
conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.contextual_embedder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
cond_emb = conds.mean(dim=-1)
|
||||
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
||||
if is_latent(aligned_conditioning):
|
||||
code_emb = self.latent_conditioner(aligned_conditioning)
|
||||
else:
|
||||
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
||||
code_emb = self.code_embedding(
|
||||
aligned_conditioning).permute(0, 2, 1)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
code_emb = self.code_norm(
|
||||
code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
unconditioned_batches = torch.zeros(
|
||||
(code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
||||
code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
||||
expanded_code_emb = F.interpolate(
|
||||
code_emb, size=expected_seq_len, mode='nearest')
|
||||
|
||||
if not return_code_pred:
|
||||
return expanded_code_emb
|
||||
|
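Editor's note: a minimal sketch, under assumed shapes, of the masking step above — during training a random subset of batch items has its conditioning swapped for the learned unconditioned embedding (classifier-free-guidance-style dropout), and the result is stretched to the diffusion sequence length with nearest-neighbour interpolation.

import torch
import torch.nn.functional as F

b, c, n, expected_seq_len = 4, 8, 25, 100
unconditioned_percentage = 0.1
code_emb = torch.randn(b, c, n)
unconditioned_embedding = torch.randn(1, c, 1)  # stand-in for the nn.Parameter

mask = torch.rand(b, 1, 1) < unconditioned_percentage
code_emb = torch.where(mask, unconditioned_embedding.repeat(b, 1, 1), code_emb)
expanded = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
print(expanded.shape)  # torch.Size([4, 8, 100])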
@ -291,7 +323,6 @@ class DiffusionTtsFlat(nn.Module):
|
|||
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
||||
return expanded_code_emb, mel_pred
|
||||
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
@ -304,27 +335,36 @@ class DiffusionTtsFlat(nn.Module):
|
|||
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
|
||||
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
||||
assert precomputed_aligned_embeddings is not None or (
|
||||
aligned_conditioning is not None and conditioning_input is not None)
|
||||
# These two are mutually exclusive.
|
||||
assert not (
|
||||
return_code_pred and precomputed_aligned_embeddings is not None)
|
||||
|
||||
unused_params = list(self.mel_head.parameters())
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
code_emb = self.unconditioned_embedding.repeat(
|
||||
x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(
|
||||
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
else:
|
||||
if precomputed_aligned_embeddings is not None:
|
||||
code_emb = precomputed_aligned_embeddings
|
||||
else:
|
||||
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
|
||||
code_emb, mel_pred = self.timestep_independent(
|
||||
aligned_conditioning, conditioning_input, x.shape[-1], True)
|
||||
if is_latent(aligned_conditioning):
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(
|
||||
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
else:
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
unused_params.extend(
|
||||
list(self.latent_conditioner.parameters()))
|
||||
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
time_emb = self.time_embed(
|
||||
timestep_embedding(timesteps, self.model_channels))
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
x = self.inp_block(x)
|
||||
x = torch.cat([x, code_emb], dim=1)
|
||||
|
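Editor's note: time_embed above consumes timestep_embedding(); a minimal sketch of the standard sinusoidal embedding most diffusion codebases use is below. The exact variant in dlas.models.diffusion.nn (ordering of sin/cos, odd-dimension handling) is an assumption.

import math
import torch

def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_timestep_embedding(torch.LongTensor([600, 600]), 1024)
print(emb.shape)  # torch.Size([2, 1024])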
@ -356,10 +396,12 @@ class DiffusionTtsFlat(nn.Module):
|
|||
conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.contextual_embedder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
return conds.mean(dim=-1)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_diffusion_tts_flat(opt_net, opt):
|
||||
return DiffusionTtsFlat(**opt_net['kwargs'])
|
||||
|
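Editor's note: a hypothetical sketch of the @register_model pattern used above — a registry decorator mapping a builder name to a factory taking (opt_net, opt). The real dlas.trainer.networks implementation may differ; this only illustrates the call shape.

_REGISTRY = {}

def register_model(fn):
    _REGISTRY[fn.__name__] = fn
    return fn

@register_model
def register_my_net(opt_net, opt):
    return ("built", opt_net["kwargs"])

print(_REGISTRY["register_my_net"]({"kwargs": {}}, {}))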
@ -367,17 +409,17 @@ def register_diffusion_tts_flat(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 100, 400)
|
||||
aligned_latent = torch.randn(2,388,512)
|
||||
aligned_sequence = torch.randint(0,8192,(2,100))
|
||||
aligned_latent = torch.randn(2, 388, 512)
|
||||
aligned_sequence = torch.randint(0, 8192, (2, 100))
|
||||
cond = torch.randn(2, 100, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=True, num_heads=16,
|
||||
layer_drop=0, unconditioned_percentage=0)
|
||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=True, num_heads=16,
|
||||
layer_drop=0, unconditioned_percentage=0)
|
||||
# Test with latent aligned conditioning
|
||||
#o = model(clip, ts, aligned_latent, cond)
|
||||
# o = model(clip, ts, aligned_latent, cond)
|
||||
# Test with sequence aligned conditioning
|
||||
#o = model(clip, ts, aligned_sequence, cond)
|
||||
# o = model(clip, ts, aligned_sequence, cond)
|
||||
|
||||
with torch.no_grad():
|
||||
proj = torch.randn(2, 100, 1024).cuda()
|
||||
|
@ -389,4 +431,3 @@ if __name__ == '__main__':
|
|||
for k in range(1000):
|
||||
model(clip, ts, precomputed_aligned_embeddings=ti)
|
||||
print(f"Elapsed: {time()-start}")
|
||||
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
|
||||
Downsample, Upsample
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from trainer.networks import register_model
|
||||
from dlas.models.diffusion.fp16_util import (convert_module_to_f16,
|
||||
convert_module_to_f32)
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (AttentionBlock,
|
||||
AttentionPool2d, Downsample,
|
||||
ResBlock,
|
||||
TimestepEmbedSequential,
|
||||
Upsample)
|
||||
from dlas.trainer.networks import register_model
|
||||
|
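Editor's note: these import hunks are the core of the PIP-ification — repo-relative imports become absolute imports through the installed "dlas" package. Assuming the package is installed (e.g. pip install -e .), the same utilities resolve like this:

from dlas.models.diffusion.nn import normalization, timestep_embedding
from dlas.trainer.networks import register_model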
||||
|
||||
class DiffusionVocoder(nn.Module):
|
||||
|
@ -46,11 +51,12 @@ class DiffusionVocoder(nn.Module):
|
|||
in_channels=1,
|
||||
out_channels=2, # mean and variance
|
||||
spectrogram_channels=80,
|
||||
spectrogram_conditioning_level=3, # Level at which spectrogram conditioning is applied to the waveform.
|
||||
# Level at which spectrogram conditioning is applied to the waveform.
|
||||
spectrogram_conditioning_level=3,
|
||||
dropout=0,
|
||||
# 106496 -> 26624 -> 6656 -> 1664 -> 416 -> 104 -> 26 for ~5secs@22050Hz
|
||||
channel_mult=(1, 2, 4, 8, 16, 32, 64),
|
||||
attention_resolutions=(16,32,64),
|
||||
attention_resolutions=(16, 32, 64),
|
||||
conv_resample=True,
|
||||
dims=1,
|
||||
num_classes=None,
|
||||
|
@ -95,7 +101,8 @@ class DiffusionVocoder(nn.Module):
|
|||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, in_channels, model_channels,
|
||||
kernel_size, padding=padding)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -104,7 +111,8 @@ class DiffusionVocoder(nn.Module):
|
|||
ch = model_channels
|
||||
ds = 1
|
||||
|
||||
spec_chs = channel_mult[spectrogram_conditioning_level] * model_channels
|
||||
spec_chs = channel_mult[spectrogram_conditioning_level] * \
|
||||
model_channels
|
||||
self.spectrogram_conditioner = nn.Sequential(
|
||||
conv_nd(dims, self.spectrogram_channels, spec_chs, 1),
|
||||
normalization(spec_chs),
|
||||
|
@ -119,7 +127,8 @@ class DiffusionVocoder(nn.Module):
|
|||
|
||||
for level, mult in enumerate(channel_mult):
|
||||
if level == spectrogram_conditioning_level+1:
|
||||
ch *= 2 # At this level, the spectrogram is concatenated onto the input.
|
||||
# At this level, the spectrogram is concatenated onto the input.
|
||||
ch *= 2
|
||||
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
|
@ -248,7 +257,8 @@ class DiffusionVocoder(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels,
|
||||
kernel_size, padding=padding)),
|
||||
)
|
||||
|
||||
def convert_to_fp16(self):
|
||||
|
@ -278,14 +288,16 @@ class DiffusionVocoder(nn.Module):
|
|||
"""
|
||||
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
|
||||
hs = []
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
emb = self.time_embed(timestep_embedding(
|
||||
timesteps, self.model_channels))
|
||||
conditioning = self.spectrogram_conditioner(spectrogram)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for k, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb)
|
||||
if k == self.input_block_injection_point:
|
||||
cond = nn.functional.interpolate(conditioning, size=h.shape[-self.dims:], mode='nearest')
|
||||
cond = nn.functional.interpolate(
|
||||
conditioning, size=h.shape[-self.dims:], mode='nearest')
|
||||
h = torch.cat([h, cond], dim=1)
|
||||
h = self.convergence_conv(h)
|
||||
hs.append(h)
|
||||
|
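Editor's note: a minimal sketch (assumed shapes and channel counts) of the injection step above — the processed spectrogram features are nearest-interpolated to the waveform latent's length, concatenated on channels, and merged back to the working width, which is what convergence_conv does.

import torch
import torch.nn as nn

ch = 16
convergence_conv = nn.Conv1d(ch * 2, ch, kernel_size=1)  # stand-in for the real layer

h = torch.randn(2, ch, 512)              # waveform latent at the injection level
conditioning = torch.randn(2, ch, 160)   # processed spectrogram features
cond = nn.functional.interpolate(conditioning, size=h.shape[-1:], mode='nearest')
h = convergence_conv(torch.cat([h, cond], dim=1))
print(h.shape)  # torch.Size([2, 16, 512])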
|
|
@ -1,11 +1,14 @@
|
|||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, ResBlock, TimestepEmbedSequential, \
|
||||
Downsample, Upsample
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from trainer.networks import register_model
|
||||
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
|
||||
ResBlock,
|
||||
TimestepEmbedSequential,
|
||||
Upsample)
|
||||
from dlas.trainer.networks import register_model
|
||||
|
||||
|
||||
class DiscreteSpectrogramConditioningBlock(nn.Module):
|
||||
|
@ -23,6 +26,7 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
|
|||
:param x: bxcxS waveform latent
|
||||
:param codes: bxN discrete codes, N <= S
|
||||
"""
|
||||
|
||||
def forward(self, x, dvae_in):
|
||||
b, c, S = x.shape
|
||||
_, q, N = dvae_in.shape
|
||||
|
@ -70,12 +74,12 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
discrete_codes=512,
|
||||
dropout=0,
|
||||
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
||||
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
||||
channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
||||
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
|
||||
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
|
||||
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
|
||||
spectrogram_conditioning_resolutions=(512,),
|
||||
attention_resolutions=(512,1024,2048),
|
||||
attention_resolutions=(512, 1024, 2048),
|
||||
conv_resample=True,
|
||||
dims=1,
|
||||
use_fp16=False,
|
||||
|
@ -122,10 +126,11 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
self.conditioning_enabled = conditioning_inputs_provided
|
||||
if conditioning_inputs_provided:
|
||||
self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
|
||||
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
|
||||
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
|
||||
|
||||
seqlyr = TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
|
||||
conv_nd(dims, in_channels, model_channels,
|
||||
kernel_size, padding=padding)
|
||||
)
|
||||
seqlyr.level = 0
|
||||
self.input_blocks = nn.ModuleList([seqlyr])
|
||||
|
@ -137,7 +142,8 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
|
||||
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
|
||||
if ds in spectrogram_conditioning_resolutions:
|
||||
spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch, 2 ** level)
|
||||
spec_cond_block = DiscreteSpectrogramConditioningBlock(
|
||||
discrete_codes, ch, 2 ** level)
|
||||
self.input_blocks.append(spec_cond_block)
|
||||
spectrogram_blocks.append(spec_cond_block)
|
||||
ch *= 2
|
||||
|
@ -172,21 +178,21 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
upblk = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
kernel_size=kernel_size,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
|
||||
)
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
kernel_size=kernel_size,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
|
||||
)
|
||||
)
|
||||
upblk.level = 2 ** level
|
||||
self.input_blocks.append(upblk)
|
||||
ch = out_ch
|
||||
|
@ -270,7 +276,8 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels,
|
||||
kernel_size, padding=padding)),
|
||||
)
|
||||
|
||||
if freeze_layers_below is not None:
|
||||
|
@ -297,7 +304,8 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
del p.DO_NOT_TRAIN
|
||||
p.requires_grad = True
|
||||
unfrozen_params += 1
|
||||
print(f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.")
|
||||
print(
|
||||
f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.")
|
||||
|
||||
def forward(self, x, timesteps, spectrogram, conditioning_input=None):
|
||||
"""
|
||||
|
@ -313,7 +321,8 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
assert conditioning_input is not None
|
||||
|
||||
hs = []
|
||||
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
emb1 = self.time_embed(timestep_embedding(
|
||||
timesteps, self.model_channels))
|
||||
if self.conditioning_enabled:
|
||||
emb2 = self.contextual_embedder(conditioning_input)
|
||||
emb = emb1 + emb2
|
||||
|
@ -363,19 +372,20 @@ if __name__ == '__main__':
|
|||
move_all_layers_down(path, 'diffuse_new_lyr.pth', layers_to_be_added=2)
|
||||
|
||||
clip = torch.randn(2, 1, 40960)
|
||||
spec = torch.randn(2,80,160)
|
||||
spec = torch.randn(2, 80, 160)
|
||||
cond = torch.randn(2, 1, 40960)
|
||||
ts = torch.LongTensor([555, 556])
|
||||
model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1,1,1.5,2, 3, 4, 6, 8, 8, 8, 8 ],
|
||||
num_res_blocks=[1,2, 2, 2, 2, 2, 2, 2, 2, 1, 1 ], spectrogram_conditioning_resolutions=[2,512],
|
||||
dropout=.05, attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
|
||||
model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8],
|
||||
num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1], spectrogram_conditioning_resolutions=[2, 512],
|
||||
dropout=.05, attention_resolutions=[512, 1024], num_heads=4, kernel_size=3, scale_factor=2,
|
||||
conditioning_inputs_provided=True, conditioning_input_dim=80, time_embed_dim_multiplier=4,
|
||||
discrete_codes=80, freeze_layers_below=1)
|
||||
loading_errors = model.load_state_dict(torch.load('diffuse_new_lyr.pth'), strict=False)
|
||||
loading_errors = model.load_state_dict(
|
||||
torch.load('diffuse_new_lyr.pth'), strict=False)
|
||||
new_params = loading_errors.missing_keys
|
||||
new_params_trained = []
|
||||
existing_params_trained = []
|
||||
for n,p in model.named_parameters():
|
||||
for n, p in model.named_parameters():
|
||||
if not hasattr(p, 'DO_NOT_TRAIN'):
|
||||
if n in new_params:
|
||||
new_params_trained.append(n)
|
||||
|
|
|
@ -1,24 +1,26 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||
from transformers.utils.model_parallel_utils import (assert_device_map,
|
||||
get_device_map)
|
||||
|
||||
from models.arch_util import AttentionBlock
|
||||
from models.audio.tts.transformer_builders import build_hf_gpt_transformer
|
||||
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import AttentionBlock
|
||||
from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer
|
||||
from dlas.models.lucidrains.x_transformers import (RotaryEmbedding,
|
||||
apply_rotary_pos_emb)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
import torch_intermediary as ml
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""
|
||||
Basic residual convolutional block that uses GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
|
@ -48,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
|
||||
def parallelize(self, device_map=None):
|
||||
self.device_map = (
|
||||
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
||||
get_device_map(len(self.transformer.h),
|
||||
range(torch.cuda.device_count()))
|
||||
if device_map is None
|
||||
else device_map
|
||||
)
|
||||
|
@ -121,7 +124,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
):
|
||||
assert self.cached_mel_emb is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
# Training not supported by this inference model.
|
||||
assert labels is None
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Create embedding
|
||||
|
@ -131,13 +135,15 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
text_emb = self.embeddings(text_inputs)
|
||||
text_emb = text_emb + self.text_pos_embedding(text_emb)
|
||||
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
|
||||
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
|
||||
mel_emb = self.cached_mel_emb.repeat_interleave(
|
||||
text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
|
||||
else:
|
||||
mel_emb = self.cached_mel_emb
|
||||
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||
else:
|
||||
emb = self.embeddings(input_ids)
|
||||
emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
|
||||
emb = emb + self.text_pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1]-mel_len, attention_mask.device)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs_embeds=emb,
|
||||
|
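Editor's note: a small sketch of the repeat_interleave above (assumed sizes) — the conditioning embedding cached once per sample is tiled to match a batch that generate() has expanded (e.g. by num_return_sequences) before being concatenated with the text embedding.

import torch

cached_mel_emb = torch.randn(2, 5, 16)    # (samples, cond_len, dim)
text_emb = torch.randn(6, 7, 16)          # each sample expanded 3x by generation
mel_emb = cached_mel_emb.repeat_interleave(text_emb.shape[0] // cached_mel_emb.shape[0], 0)
emb = torch.cat([mel_emb, text_emb], dim=1)
print(emb.shape)  # torch.Size([6, 12, 16])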
@ -182,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
|
||||
for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
@ -199,7 +206,8 @@ class ConditioningEncoder(nn.Module):
|
|||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(AttentionBlock(embedding_dim,
|
||||
num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
@ -219,23 +227,27 @@ class MelEncoder(nn.Module):
|
|||
super().__init__()
|
||||
self.channels = channels
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
||||
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//4, channels//2,
|
||||
kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//16, channels//2),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//2, channels,
|
||||
kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//8, channels),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
)
|
||||
self.reduction = 4
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for e in self.encoder:
|
||||
x = e(x)
|
||||
return x.permute(0,2,1)
|
||||
return x.permute(0, 2, 1)
|
||||
|
||||
|
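Editor's note: a reduced sketch (assumed channel counts) of the MelEncoder shape contract — two stride-2 convolutions give the 4x temporal reduction, and the output is permuted to (batch, time, channels) for the GPT stack.

import torch
import torch.nn as nn

mel_channels, channels = 80, 64
encoder = nn.Sequential(
    nn.Conv1d(mel_channels, channels // 2, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
)
mel = torch.randn(2, mel_channels, 400)
out = encoder(mel).permute(0, 2, 1)
print(out.shape)  # torch.Size([2, 100, 64])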
||||
class UnifiedVoice(nn.Module):
|
||||
|
@ -276,25 +288,32 @@ class UnifiedVoice(nn.Module):
|
|||
self.layers = layers
|
||||
self.heads = heads
|
||||
self.max_conditioning_inputs = max_conditioning_inputs
|
||||
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
|
||||
self.max_mel_tokens = -1 if max_mel_tokens == - \
|
||||
1 else max_mel_tokens+2+self.max_conditioning_inputs
|
||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
|
||||
self.model_dim = model_dim
|
||||
self.mel_length_compression = mel_length_compression
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
self.conditioning_encoder = ConditioningEncoder(
|
||||
80, model_dim, num_attn_heads=heads)
|
||||
self.average_conditioning_embeddings = average_conditioning_embeddings
|
||||
self.tortoise_compat = tortoise_compat # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b
|
||||
# credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b
|
||||
self.tortoise_compat = tortoise_compat
|
||||
# nn.Embedding
|
||||
self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
|
||||
if use_mel_codes_as_input:
|
||||
# nn.Embedding
|
||||
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
|
||||
else:
|
||||
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
|
||||
self.mel_embedding = MelEncoder(
|
||||
model_dim, resblocks_per_reduction=1)
|
||||
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
||||
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
|
||||
build_hf_gpt_transformer(
|
||||
layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
|
||||
if train_solo_embeddings:
|
||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||
self.mel_solo_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||
self.text_solo_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||
else:
|
||||
self.mel_solo_embedding = 0
|
||||
self.text_solo_embedding = 0
|
||||
|
@ -303,7 +322,6 @@ class UnifiedVoice(nn.Module):
|
|||
self.text_head = ml.Linear(model_dim, self.number_text_tokens)
|
||||
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
|
||||
|
||||
|
||||
# Initialize the embeddings per the GPT-2 scheme
|
||||
embeddings = [self.text_embedding]
|
||||
if use_mel_codes_as_input:
|
||||
|
@ -328,8 +346,8 @@ class UnifiedVoice(nn.Module):
|
|||
}
|
||||
|
||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1,0), value=start_token)
|
||||
tar = F.pad(input, (0,1), value=stop_token)
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
return inp, tar
|
||||
|
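Editor's note: a tiny worked example of build_aligned_inputs_and_targets — the input gets the start token prepended, the target gets the stop token appended, giving the usual shifted teacher-forcing pair. Token values are illustrative.

import torch
import torch.nn.functional as F

start_token, stop_token = 8192, 8193
codes = torch.tensor([[5, 6, 7]])
inp = F.pad(codes, (1, 0), value=start_token)   # tensor([[8192, 5, 6, 7]])
tar = F.pad(codes, (0, 1), value=stop_token)    # tensor([[5, 6, 7, 8193]])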
||||
def set_mel_padding(self, mel_input_tokens, wav_lengths):
|
||||
|
@ -341,22 +359,26 @@ class UnifiedVoice(nn.Module):
|
|||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
mel_lengths = wav_lengths // self.mel_length_compression
|
||||
for b in range(len(mel_lengths)):
|
||||
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
||||
# Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
||||
actual_end = mel_lengths[b] + 1
|
||||
if actual_end < mel_input_tokens.shape[-1]:
|
||||
mel_input_tokens[b, actual_end:] = self.stop_mel_token
|
||||
return mel_input_tokens
|
||||
|
||||
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
||||
emb = torch.cat([speech_conditioning_inputs,
|
||||
first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
||||
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True,
|
||||
output_attentions=get_attns)
|
||||
if get_attns:
|
||||
return gpt_out.attentions
|
||||
|
||||
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
|
||||
# The first logit is tied to the speech_conditioning_input
|
||||
enc = gpt_out.last_hidden_state[:, 1:]
|
||||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
|
@ -364,11 +386,11 @@ class UnifiedVoice(nn.Module):
|
|||
|
||||
first_logits = enc[:, :first_inputs.shape[1]]
|
||||
first_logits = first_head(first_logits)
|
||||
first_logits = first_logits.permute(0,2,1)
|
||||
first_logits = first_logits.permute(0, 2, 1)
|
||||
if second_inputs is not None:
|
||||
second_logits = enc[:, -second_inputs.shape[1]:]
|
||||
second_logits = second_head(second_logits)
|
||||
second_logits = second_logits.permute(0,2,1)
|
||||
second_logits = second_logits.permute(0, 2, 1)
|
||||
return first_logits, second_logits
|
||||
else:
|
||||
return first_logits
|
||||
|
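Editor's note: a short sketch of why the head outputs are permuted to (0, 2, 1) above — F.cross_entropy expects class logits in dimension 1, i.e. (batch, vocab, seq_len), while the transformer and its Linear heads emit (batch, seq_len, vocab). Sizes are stand-ins.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 10, 8194)            # (B, S, vocab) straight from the head
targets = torch.randint(0, 8194, (2, 10))    # (B, S)
loss = F.cross_entropy(logits.permute(0, 2, 1), targets)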
@ -394,24 +416,31 @@ class UnifiedVoice(nn.Module):
|
|||
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||
# chopping the inputs by the maximum actual length.
|
||||
max_text_len = text_lengths.max()
|
||||
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(
|
||||
text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
|
||||
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
||||
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
|
||||
mel_codes = F.pad(mel_codes[:, :max_mel_len],
|
||||
(0, 1), value=self.stop_mel_token)
|
||||
if raw_mels is not None:
|
||||
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
||||
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(
|
||||
text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
|
||||
mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
if raw_mels is not None:
|
||||
mel_inp = F.pad(raw_mels, (0, 8))
|
||||
else:
|
||||
|
@ -421,13 +450,17 @@ class UnifiedVoice(nn.Module):
|
|||
|
||||
sub = -2 if self.tortoise_compat else -1
|
||||
if text_first:
|
||||
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
|
||||
text_logits, mel_logits = self.get_logits(
|
||||
conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return mel_logits[:, :sub] # Despite the name, these are not logits.
|
||||
# Despite the name, these are not logits.
|
||||
return mel_logits[:, :sub]
|
||||
else:
|
||||
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
|
||||
mel_logits, text_logits = self.get_logits(
|
||||
conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return text_logits[:, :sub] # Despite the name, these are not logits
|
||||
# Despite the name, these are not logits
|
||||
return text_logits[:, :sub]
|
||||
|
||||
if return_attentions:
|
||||
return mel_logits
|
||||
|
@ -443,18 +476,23 @@ class UnifiedVoice(nn.Module):
|
|||
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||
# chopping the inputs by the maximum actual length.
|
||||
max_text_len = text_lengths.max()
|
||||
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(
|
||||
text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
|
||||
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(
|
||||
text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
|
||||
text_logits = self.get_logits(conds, text_emb, self.text_head)
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
return loss_text.mean()
|
||||
|
@ -468,26 +506,31 @@ class UnifiedVoice(nn.Module):
|
|||
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||
# chopping the inputs by the maximum actual length.
|
||||
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
||||
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
|
||||
mel_codes = F.pad(mel_codes[:, :max_mel_len],
|
||||
(0, 1), value=self.stop_mel_token)
|
||||
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||
if raw_mels is not None:
|
||||
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
||||
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
|
||||
mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
if raw_mels is not None:
|
||||
mel_inp = F.pad(raw_mels, (0, 4))
|
||||
else:
|
||||
mel_inp = mel_codes
|
||||
mel_emb = self.mel_embedding(mel_inp)
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
|
||||
mel_emb = mel_emb + \
|
||||
self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
|
||||
mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
|
||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||
return loss_mel.mean()
|
||||
|
@ -507,17 +550,22 @@ class UnifiedVoice(nn.Module):
|
|||
n_head=self.heads,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True)
|
||||
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
|
||||
self.inference_model = GPT2InferenceModel(
|
||||
gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
|
||||
self.gpt.wte = self.mel_embedding
|
||||
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(
|
||||
text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
|
@ -525,8 +573,9 @@ class UnifiedVoice(nn.Module):
|
|||
emb = torch.cat([conds, text_emb], dim=1)
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:,-1] = self.start_mel_token
|
||||
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],),
|
||||
fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||
max_length=seq_length, output_attentions=return_attentions, return_dict_in_generate=True, **hf_generate_kwargs)
|
||||
|
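Editor's note: a minimal sketch of the priming above (assumed lengths) — generate() receives dummy token ids whose length matches the cached [conds, text] embedding, with only the last position set to the start-of-mel token so autoregressive decoding starts from it.

import torch

start_mel_token = 8192
batch, conds_len, emb_len = 2, 1, 31
fake_inputs = torch.full((batch, conds_len + emb_len), fill_value=1, dtype=torch.long)
fake_inputs[:, -1] = start_mel_token
print(fake_inputs.shape, fake_inputs[0, -1].item())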
@ -535,15 +584,16 @@ class UnifiedVoice(nn.Module):
|
|||
else:
|
||||
return gen.sequences[:, fake_inputs.shape[1]:]
|
||||
|
||||
|
||||
# Turns the (utterly insane) output of HF.generate() into a far more sane output:
|
||||
# [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence
|
||||
|
||||
def make_hf_generate_attentions_sane(self, attentions):
|
||||
layers = [[] for _ in range(len(attentions[0]))]
|
||||
full_attention_size = attentions[-1][0].shape[-1]
|
||||
for i, gen in enumerate(attentions):
|
||||
for j, lyr in enumerate(gen):
|
||||
layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
|
||||
layers[j].append(
|
||||
F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
|
||||
catted = []
|
||||
for lyr in layers:
|
||||
catted.append(torch.cat(lyr, dim=2))
|
||||
|
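Editor's note: a compact sketch of the reshaping above (assumed head/step counts) — HF generate() returns attentions indexed [step][layer], and each step's key length grows by one, so each tensor is right-padded to the final size before concatenating along the query axis into one (B, H, S, S)-style map per layer.

import torch
import torch.nn.functional as F

steps = [[torch.randn(1, 2, 1, s)] for s in (3, 4, 5)]   # one layer, three decode steps
full = steps[-1][0].shape[-1]
layer0 = torch.cat(
    [F.pad(step[0], (0, full - step[0].shape[-1])) for step in steps], dim=2)
print(layer0.shape)  # torch.Size([1, 2, 3, 5])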
@ -562,18 +612,21 @@ class UnifiedVoice(nn.Module):
|
|||
for l, layer in enumerate(attentions):
|
||||
dec_context = layer[:, :, num_context:, :]
|
||||
# Mask out everything that isn't text (including the start token, which gets a LOT of attention)
|
||||
dec_context[:,:,:,:text_padding+1] = 0
|
||||
dec_context[:,:,:,num_context:] = 0
|
||||
dec_context[:, :, :, :text_padding+1] = 0
|
||||
dec_context[:, :, :, num_context:] = 0
|
||||
for h in range(dec_context.shape[1]):
|
||||
dec_context_indices = torch.argmax(dec_context[0,h], dim=-1)
|
||||
dec_context_indices = torch.argmax(dec_context[0, h], dim=-1)
|
||||
print(f'layer_{l};head_{h}: ' + str(dec_context_indices))
|
||||
for t, att_tok in enumerate(attentions):
|
||||
combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
|
||||
combined_attention_weights = torch.zeros(
|
||||
(codes.shape[0], num_text), device=codes.device)
|
||||
for lyr in att_tok:
|
||||
token_to_text_attentions = lyr[:, :, -1, text_padding:(text_padding + num_text)].sum(dim=1)
|
||||
token_to_text_attentions = lyr[:, :, -1,
|
||||
text_padding:(text_padding + num_text)].sum(dim=1)
|
||||
combined_attention_weights = combined_attention_weights + token_to_text_attentions
|
||||
break
|
||||
most_attended_text_token = combined_attention_weights.argmax(dim=-1)
|
||||
most_attended_text_token = combined_attention_weights.argmax(
|
||||
dim=-1)
|
||||
results[:, t] = most_attended_text_token
|
||||
eos_token_mask = (codes != self.stop_mel_token)
|
||||
return results * eos_token_mask
|
||||
|
@ -585,10 +638,11 @@ def register_unified_voice2(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
|
||||
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True,
|
||||
max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
|
||||
l = gpt(torch.randn(2, 3, 80, 800),
|
||||
torch.randint(high=256, size=(2,120)),
|
||||
torch.randint(high=256, size=(2, 120)),
|
||||
torch.tensor([32, 120]),
|
||||
torch.randint(high=8192, size=(2,250)),
|
||||
torch.tensor([250*256,195*256]))
|
||||
#gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
|
||||
torch.randint(high=8192, size=(2, 250)),
|
||||
torch.tensor([250*256, 195*256]))
|
||||
# gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
|
||||
|
|
|
@ -1,24 +1,26 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_intermediary as ml
|
||||
|
||||
from transformers import GPT2Config, GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||
from transformers.utils.model_parallel_utils import (assert_device_map,
|
||||
get_device_map)
|
||||
|
||||
from models.arch_util import AttentionBlock
|
||||
from models.audio.tts.transformer_builders import build_hf_gpt_transformer
|
||||
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import AttentionBlock
|
||||
from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer
|
||||
from dlas.models.lucidrains.x_transformers import (RotaryEmbedding,
|
||||
apply_rotary_pos_emb)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""
|
||||
Basic residual convolutional block that uses GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
|
@ -48,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
|
||||
def parallelize(self, device_map=None):
|
||||
self.device_map = (
|
||||
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
||||
get_device_map(len(self.transformer.h),
|
||||
range(torch.cuda.device_count()))
|
||||
if device_map is None
|
||||
else device_map
|
||||
)
|
||||
|
@ -120,7 +123,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
):
|
||||
assert self.cached_prior_emb is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
# Training not supported by this inference model.
|
||||
assert labels is None
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Create embedding
|
||||
|
@ -128,15 +132,18 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
if input_ids.shape[1] != 1:
|
||||
posterior_inputs = input_ids[:, prior_len:]
|
||||
posterior_emb = self.embeddings(posterior_inputs)
|
||||
posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb)
|
||||
posterior_emb = posterior_emb + \
|
||||
self.posterior_pos_embedding(posterior_emb)
|
||||
if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]:
|
||||
prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
|
||||
prior_emb = self.cached_prior_emb.repeat_interleave(
|
||||
posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
|
||||
else:
|
||||
prior_emb = self.cached_prior_emb
|
||||
emb = torch.cat([prior_emb, posterior_emb], dim=1)
|
||||
else:
|
||||
emb = self.embeddings(input_ids)
|
||||
emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device)
|
||||
emb = emb + self.posterior_pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1] - prior_len, attention_mask.device)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs_embeds=emb,
|
||||
|
@ -181,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
|
||||
for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
@ -198,7 +206,8 @@ class ConditioningEncoder(nn.Module):
|
|||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(AttentionBlock(embedding_dim,
|
||||
num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
@ -218,23 +227,27 @@ class MelEncoder(nn.Module):
|
|||
super().__init__()
|
||||
self.channels = channels
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
||||
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//4, channels//2,
|
||||
kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//16, channels//2),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//2, channels,
|
||||
kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//8, channels),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
)
|
||||
self.reduction = 4
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for e in self.encoder:
|
||||
x = e(x)
|
||||
return x.permute(0,2,1)
|
||||
return x.permute(0, 2, 1)
|
||||
|
||||
|
||||
class UnifiedVoice(nn.Module):
|
||||
|
@ -260,7 +273,8 @@ class UnifiedVoice(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
self.number_text_tokens = number_text_tokens
|
||||
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
|
||||
self.start_text_token = number_text_tokens * \
|
||||
types if start_text_token is None else start_text_token
|
||||
self.stop_text_token = 0
|
||||
self.number_mel_codes = number_mel_codes
|
||||
self.start_mel_token = start_mel_token
|
||||
|
@ -268,17 +282,21 @@ class UnifiedVoice(nn.Module):
|
|||
self.layers = layers
|
||||
self.heads = heads
|
||||
self.max_conditioning_inputs = max_conditioning_inputs
|
||||
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
|
||||
self.max_mel_tokens = -1 if max_mel_tokens == - \
|
||||
1 else max_mel_tokens+2+self.max_conditioning_inputs
|
||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
|
||||
self.model_dim = model_dim
|
||||
self.mel_length_compression = mel_length_compression
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
self.conditioning_encoder = ConditioningEncoder(
|
||||
80, model_dim, num_attn_heads=heads)
|
||||
# nn.Embedding
|
||||
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
|
||||
self.text_embedding = ml.Embedding(
|
||||
self.number_text_tokens*types+1, model_dim)
|
||||
# nn.Embedding
|
||||
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
|
||||
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
||||
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
|
||||
build_hf_gpt_transformer(
|
||||
layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
|
||||
|
||||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
|
||||
|
@ -306,8 +324,8 @@ class UnifiedVoice(nn.Module):
|
|||
}
|
||||
|
||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1,0), value=start_token)
|
||||
tar = F.pad(input, (0,1), value=stop_token)
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
return inp, tar
|
||||
|
||||
def set_mel_padding(self, mel_input_tokens, wav_lengths):
|
||||
|
@ -319,45 +337,48 @@ class UnifiedVoice(nn.Module):
|
|||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
mel_lengths = wav_lengths // self.mel_length_compression
|
||||
for b in range(len(mel_lengths)):
|
||||
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
||||
# Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
||||
actual_end = mel_lengths[b] + 1
|
||||
if actual_end < mel_input_tokens.shape[-1]:
|
||||
mel_input_tokens[b, actual_end:] = self.stop_mel_token
|
||||
return mel_input_tokens
|
||||
|
||||
def get_logits(self, speech_conditioning_inputs, text_inputs, text_head, mel_inputs, mel_head, aligned_head, return_latent=False):
|
||||
if mel_inputs is not None:
|
||||
emb = torch.cat([speech_conditioning_inputs, text_inputs, mel_inputs], dim=1)
|
||||
emb = torch.cat([speech_conditioning_inputs,
|
||||
text_inputs, mel_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, text_inputs], dim=1)
|
||||
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
|
||||
|
||||
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
|
||||
# The first logit is tied to the speech_conditioning_input
|
||||
enc = gpt_out.last_hidden_state[:, 1:]
|
||||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1] + text_inputs.shape[1]], enc[:, -mel_inputs.shape[1]:]
|
||||
|
||||
text_logits = enc[:, :text_inputs.shape[1]]
|
||||
text_logits = text_head(text_logits).permute(0,2,1)
|
||||
text_logits = text_head(text_logits).permute(0, 2, 1)
|
||||
|
||||
mel_logits = enc[:, -mel_inputs.shape[1]:]
|
||||
aligned_logits = aligned_head(mel_logits).permute(0,2,1)
|
||||
mel_logits = mel_head(mel_logits).permute(0,2,1)
|
||||
aligned_logits = aligned_head(mel_logits).permute(0, 2, 1)
|
||||
mel_logits = mel_head(mel_logits).permute(0, 2, 1)
|
||||
|
||||
return text_logits, mel_logits, aligned_logits
|
||||
|
||||
|
||||
def get_conditioning_latent(self, speech_conditioning_input):
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
return conds
|
||||
|
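Editor's note: a minimal sketch of get_conditioning_latent above — each conditioning clip is encoded independently, the per-clip vectors are stacked, and their mean becomes a single (batch, 1, dim) latent. The encoder is replaced by random vectors here.

import torch

batch, n_clips, dim = 2, 3, 16
per_clip = [torch.randn(batch, dim) for _ in range(n_clips)]   # stand-ins for encoder outputs
conds = torch.stack(per_clip, dim=1)        # (batch, n_clips, dim)
conds = conds.mean(dim=1).unsqueeze(1)      # (batch, 1, dim)
print(conds.shape)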
||||
|
||||
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, aligned_codes, types=None, return_latent=False):
|
||||
"""
|
||||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||
|
@ -376,19 +397,22 @@ class UnifiedVoice(nn.Module):
|
|||
if types is not None:
|
||||
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
||||
|
||||
|
||||
conds = self.get_conditioning_latent(speech_conditioning_input)
|
||||
|
||||
ac_expansion_factor = mel_codes.shape[-1] / aligned_codes.shape[-1]
|
||||
aligned_codes = aligned_codes.repeat(1, ac_expansion_factor)
|
||||
_, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0)
|
||||
_, aligned_targets = self.build_aligned_inputs_and_targets(
|
||||
aligned_codes, 0, 0)
|
||||
|
||||
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(
|
||||
text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
|
||||
mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
mel_inp = mel_codes
|
||||
mel_emb = self.mel_embedding(mel_inp)
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
||||
|
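Editor's note: a hedged sketch of the aligned-code expansion in the hunk above — aligned_codes is repeated so its length matches mel_codes before targets are built. torch.Tensor.repeat needs integer factors, so an integer length ratio (floor division) is assumed here.

import torch

mel_codes = torch.randint(0, 8192, (2, 250))
aligned_codes = torch.randint(0, 256, (2, 50))
factor = mel_codes.shape[-1] // aligned_codes.shape[-1]   # 5
aligned_codes = aligned_codes.repeat(1, factor)           # (2, 250)
print(aligned_codes.shape)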
@ -396,7 +420,8 @@ class UnifiedVoice(nn.Module):
|
|||
text_logits, mel_logits, aligned_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head,
|
||||
self.aligned_head, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
# Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
return mel_logits[:, :-2]
|
||||
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||
|
@ -418,25 +443,31 @@ class UnifiedVoice(nn.Module):
|
|||
n_head=self.heads,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True)
|
||||
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
|
||||
self.inference_model = GPT2InferenceModel(
|
||||
gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
|
||||
self.gpt.wte = self.mel_embedding
|
||||
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(
|
||||
text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
|
||||
emb = torch.cat([conds, text_emb], dim=1)
|
||||
self.inference_model.store_prior_emb(emb)
|
||||
|
||||
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:,-1] = self.start_mel_token
|
||||
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],),
|
||||
fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||
max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs)
|
||||
|
@ -449,11 +480,12 @@ def register_unified_voice3(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
|
||||
mel = torch.randint(high=8192, size=(2,250))
|
||||
ac = torch.randint(high=256, size=(2,250*1024//443))
|
||||
gpt = UnifiedVoice(model_dim=256, heads=4,
|
||||
max_conditioning_inputs=4, types=2)
|
||||
mel = torch.randint(high=8192, size=(2, 250))
|
||||
ac = torch.randint(high=256, size=(2, 250*1024//443))
|
||||
l = gpt(torch.randn(2, 3, 80, 800),
|
||||
torch.randint(high=256, size=(2,120)),
|
||||
torch.randint(high=256, size=(2, 120)),
|
||||
torch.tensor([32, 120]),
|
||||
mel, torch.tensor([250*256,195*256]), ac,
|
||||
mel, torch.tensor([250*256, 195*256]), ac,
|
||||
types=torch.tensor([0, 1]))
|
||||
|
|
|
@ -4,20 +4,23 @@ import torch.nn.functional as F
|
|||
from transformers import GPT2Config, GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||
from transformers.utils.model_parallel_utils import (assert_device_map,
|
||||
get_device_map)
|
||||
|
||||
from models.arch_util import AttentionBlock
|
||||
from models.audio.tts.transformer_builders import build_hf_gpt_transformer
|
||||
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import torch_intermediary as ml
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.arch_util import AttentionBlock
|
||||
from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer
|
||||
from dlas.models.lucidrains.x_transformers import (RotaryEmbedding,
|
||||
apply_rotary_pos_emb)
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""
|
||||
Basic residual convolutional block that uses GroupNorm.
|
||||
"""
|
||||
|
||||
def __init__(self, chan):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
|
@ -47,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
|
||||
def parallelize(self, device_map=None):
|
||||
self.device_map = (
|
||||
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
||||
get_device_map(len(self.transformer.h),
|
||||
range(torch.cuda.device_count()))
|
||||
if device_map is None
|
||||
else device_map
|
||||
)
|
||||
|
@ -119,7 +123,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
):
|
||||
assert self.cached_prior_emb is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
# Training not supported by this inference model.
|
||||
assert labels is None
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Create embedding
|
||||
|
@ -127,15 +132,18 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
if input_ids.shape[1] != 1:
|
||||
posterior_inputs = input_ids[:, prior_len:]
|
||||
posterior_emb = self.embeddings(posterior_inputs)
|
||||
posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb)
|
||||
posterior_emb = posterior_emb + \
|
||||
self.posterior_pos_embedding(posterior_emb)
|
||||
if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]:
|
||||
prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
|
||||
prior_emb = self.cached_prior_emb.repeat_interleave(
|
||||
posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
|
||||
else:
|
||||
prior_emb = self.cached_prior_emb
|
||||
emb = torch.cat([prior_emb, posterior_emb], dim=1)
|
||||
else:
|
||||
emb = self.embeddings(input_ids)
|
||||
emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device)
|
||||
emb = emb + self.posterior_pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1] - prior_len, attention_mask.device)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs_embeds=emb,
|
||||
|
@ -180,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
|
||||
for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
@ -197,7 +206,8 @@ class ConditioningEncoder(nn.Module):
|
|||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
attn.append(AttentionBlock(embedding_dim,
|
||||
num_attn_heads, do_checkpoint=do_checkpointing))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
|
@ -217,23 +227,27 @@ class MelEncoder(nn.Module):
|
|||
super().__init__()
|
||||
self.channels = channels
|
||||
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
||||
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//4, channels//2,
|
||||
kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//16, channels//2),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels//2, channels,
|
||||
kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels//8, channels),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
)
|
||||
self.reduction = 4
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for e in self.encoder:
|
||||
x = e(x)
|
||||
return x.permute(0,2,1)
|
||||
return x.permute(0, 2, 1)
|
||||
|
||||
|
||||
class UnifiedVoice(nn.Module):
|
||||
|
@ -243,7 +257,8 @@ class UnifiedVoice(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
self.number_text_tokens = number_text_tokens
|
||||
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
|
||||
self.start_text_token = number_text_tokens * \
|
||||
types if start_text_token is None else start_text_token
|
||||
self.stop_text_token = 0
|
||||
self.number_mel_codes = number_mel_codes
|
||||
self.start_mel_token = start_mel_token
|
||||
|
@ -251,17 +266,21 @@ class UnifiedVoice(nn.Module):
|
|||
self.layers = layers
|
||||
self.heads = heads
|
||||
self.max_conditioning_inputs = max_conditioning_inputs
|
||||
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
|
||||
self.max_mel_tokens = -1 if max_mel_tokens == - \
|
||||
1 else max_mel_tokens+2+self.max_conditioning_inputs
|
||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
|
||||
self.model_dim = model_dim
|
||||
self.mel_length_compression = mel_length_compression
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
self.conditioning_encoder = ConditioningEncoder(
|
||||
80, model_dim, num_attn_heads=heads)
|
||||
# nn.Embedding
|
||||
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
|
||||
self.text_embedding = ml.Embedding(
|
||||
self.number_text_tokens*types+1, model_dim)
|
||||
# nn.Embedding
|
||||
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
|
||||
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
||||
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
|
||||
build_hf_gpt_transformer(
|
||||
layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
|
||||
|
||||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
|
||||
|
@ -289,8 +308,8 @@ class UnifiedVoice(nn.Module):
|
|||
}
|
||||
|
||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1,0), value=start_token)
|
||||
tar = F.pad(input, (0,1), value=stop_token)
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
return inp, tar
|
||||
|
||||
def set_mel_padding(self, mel_input_tokens, wav_lengths):
|
||||
|
@ -302,17 +321,20 @@ class UnifiedVoice(nn.Module):
|
|||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
mel_lengths = wav_lengths // self.mel_length_compression
|
||||
for b in range(len(mel_lengths)):
|
||||
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
||||
# Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
||||
actual_end = mel_lengths[b] + 1
|
||||
if actual_end < mel_input_tokens.shape[-1]:
|
||||
mel_input_tokens[b, actual_end:] = self.stop_mel_token
|
||||
return mel_input_tokens
|
||||
|
||||
def get_logits(self, speech_conditioning_inputs, first_inputs, second_inputs, return_latent=False):
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
||||
emb = torch.cat([speech_conditioning_inputs,
|
||||
first_inputs, second_inputs], dim=1)
|
||||
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
|
||||
|
||||
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
|
||||
# The first logit is tied to the speech_conditioning_input
|
||||
enc = gpt_out.last_hidden_state[:, 1:]
|
||||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
|
@ -320,29 +342,29 @@ class UnifiedVoice(nn.Module):
|
|||
|
||||
text_logits = enc[:, :first_inputs.shape[1]]
|
||||
text_logits = self.text_head(text_logits)
|
||||
text_logits = text_logits.permute(0,2,1)
|
||||
text_logits = text_logits.permute(0, 2, 1)
|
||||
|
||||
mel_logits = enc[:, -second_inputs.shape[1]:]
|
||||
mel_logits = self.mel_head(mel_logits)
|
||||
mel_logits = mel_logits.permute(0,2,1)
|
||||
mel_logits = mel_logits.permute(0, 2, 1)
|
||||
|
||||
alignment_logits = enc[:, -second_inputs.shape[1]:]
|
||||
alignment_logits = self.alignment_head(alignment_logits)
|
||||
alignment_logits = alignment_logits.permute(0,2,1)
|
||||
alignment_logits = alignment_logits.permute(0, 2, 1)
|
||||
|
||||
return text_logits, mel_logits, alignment_logits
|
||||
|
||||
|
||||
def get_conditioning_latent(self, speech_conditioning_input):
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
return conds
|
||||
|
||||
|
||||
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, ctc_codes, wav_lengths, types=None, return_latent=False):
|
||||
"""
|
||||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||
|
@ -368,24 +390,30 @@ class UnifiedVoice(nn.Module):
|
|||
ctc_codes[b][j] = last_code
|
||||
else:
|
||||
last_code = ctc_codes[b][j]
|
||||
alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(), size=(mel_codes.shape[-1],), mode='nearest').long().squeeze()
|
||||
alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(), size=(
|
||||
mel_codes.shape[-1],), mode='nearest').long().squeeze()
|
||||
|
||||
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
||||
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
|
||||
alignment_targets = F.pad(alignment_targets, (0,2), value=0)
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
|
||||
alignment_targets = F.pad(alignment_targets, (0, 2), value=0)
|
||||
|
||||
conds = self.get_conditioning_latent(speech_conditioning_input)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(
|
||||
text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
|
||||
mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||
mel_inp = mel_codes
|
||||
mel_emb = self.mel_embedding(mel_inp)
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
||||
|
||||
text_logits, mel_logits, alignment_logits = self.get_logits(conds, text_emb, mel_emb, return_latent=return_latent)
|
||||
text_logits, mel_logits, alignment_logits = self.get_logits(
|
||||
conds, text_emb, mel_emb, return_latent=return_latent)
|
||||
if return_latent:
|
||||
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
# Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
||||
return mel_logits[:, :-2]
|
||||
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||
|
@ -407,25 +435,31 @@ class UnifiedVoice(nn.Module):
|
|||
n_head=self.heads,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True)
|
||||
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
|
||||
self.inference_model = GPT2InferenceModel(
|
||||
gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
|
||||
self.gpt.wte = self.mel_embedding
|
||||
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token)
|
||||
text_emb = self.text_embedding(
|
||||
text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
|
||||
speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(
|
||||
speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
|
||||
emb = torch.cat([conds, text_emb], dim=1)
|
||||
self.inference_model.store_prior_emb(emb)
|
||||
|
||||
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:,-1] = self.start_mel_token
|
||||
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],),
|
||||
fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||
max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs)
|
||||
|
@ -438,10 +472,11 @@ def register_unified_voice4(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
|
||||
gpt = UnifiedVoice(model_dim=256, heads=4,
|
||||
max_conditioning_inputs=4, types=2)
|
||||
l = gpt(torch.randn(2, 3, 80, 800),
|
||||
torch.randint(high=256, size=(2,120)),
|
||||
torch.randint(high=256, size=(2, 120)),
|
||||
torch.tensor([32, 120]),
|
||||
torch.randint(high=8192, size=(2,250)),
|
||||
torch.tensor([250*256,195*256]),
|
||||
torch.randint(high=8192, size=(2, 250)),
|
||||
torch.tensor([250*256, 195*256]),
|
||||
types=torch.tensor([0, 1]))
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum
|
||||
|
||||
from models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from trainer.injectors.spec_augment import spec_augment
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
import torch_intermediary as ml
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from dlas.trainer.injectors.spec_augment import spec_augment
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import opt_get
|
||||
|
||||
|
||||
def exists(val):
|
||||
|
@ -17,7 +18,7 @@ def exists(val):
|
|||
|
||||
def masked_mean(t, mask, dim=1):
|
||||
t = t.masked_fill(~mask[:, :, None], 0.)
|
||||
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
|
||||
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
|
||||
|
||||
|
||||
class VoiceCLIP(nn.Module):
|
||||
|
@ -36,7 +37,8 @@ class VoiceCLIP(nn.Module):
|
|||
super().__init__()
|
||||
self.encoder = AudioMiniEncoder(80, encoder_output)
|
||||
if pretrained_encoder_dict_path is not None:
|
||||
self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
|
||||
self.encoder.load_state_dict(
|
||||
torch.load(pretrained_encoder_dict_path))
|
||||
self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False)
|
||||
self.temperature = nn.Parameter(torch.tensor(1.))
|
||||
self.mel_compression_ratio = mel_compression_ratio
|
||||
|
@ -47,7 +49,8 @@ class VoiceCLIP(nn.Module):
|
|||
speech_lengths,
|
||||
return_loss=True
|
||||
):
|
||||
half_length = min(speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2
|
||||
half_length = min(
|
||||
speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2
|
||||
|
||||
# Extract two speech MELs from the same clip, apply some random noise to them and also apply specaugment to them.
|
||||
first_half = speech_mels[:, :, :half_length]
|
||||
|
@ -81,7 +84,8 @@ class VoiceCLIP(nn.Module):
|
|||
second_emb = self.encoder(second_half)
|
||||
second_latents = self.to_latent(second_emb)
|
||||
|
||||
first_latents, second_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (first_latents, second_latents))
|
||||
first_latents, second_latents = map(lambda t: F.normalize(
|
||||
t, p=2, dim=-1), (first_latents, second_latents))
|
||||
|
||||
temp = self.temperature.exp()
|
||||
|
||||
|
@ -90,8 +94,10 @@ class VoiceCLIP(nn.Module):
|
|||
return sim
|
||||
|
||||
sim = einsum('i d, j d -> i j', first_latents, second_latents) * temp
|
||||
labels = torch.arange(first_latents.shape[0], device=first_latents.device)
|
||||
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
||||
labels = torch.arange(
|
||||
first_latents.shape[0], device=first_latents.device)
|
||||
loss = (F.cross_entropy(sim, labels) +
|
||||
F.cross_entropy(sim.t(), labels)) / 2
|
||||
return loss
|
||||
|
||||
def inference(self, speech_mels):
|
||||
|
@ -111,6 +117,6 @@ def register_voice_to_voice_clip(opt_net, opt):
|
|||
if __name__ == '__main__':
|
||||
clip = VoiceCLIP()
|
||||
for k in range(1000):
|
||||
clip(torch.randn((2,80,156)),
|
||||
torch.randint(130*1024,156*1024,(2,)),
|
||||
return_loss=True)
|
||||
clip(torch.randn((2, 80, 156)),
|
||||
torch.randint(130*1024, 156*1024, (2,)),
|
||||
return_loss=True)
|
||||
|
|
|
@ -3,11 +3,11 @@ import functools
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
|
||||
from x_transformers import ContinuousTransformerWrapper, Decoder, Encoder
|
||||
|
||||
from models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from trainer.networks import register_model
|
||||
import torch_intermediary as ml
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from dlas.trainer.networks import register_model
|
||||
|
||||
|
||||
class CheckpointedLayer(nn.Module):
|
||||
|
@ -15,13 +15,15 @@ class CheckpointedLayer(nn.Module):
|
|||
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
||||
checkpoint for all other args.
|
||||
"""
|
||||
|
||||
def __init__(self, wrap):
|
||||
super().__init__()
|
||||
self.wrap = wrap
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
|
||||
# This would screw up checkpointing.
|
||||
assert not (isinstance(v, torch.Tensor) and v.requires_grad)
|
||||
partial = functools.partial(self.wrap, **kwargs)
|
||||
return torch.utils.checkpoint.checkpoint(partial, x, *args)
|
||||
|
||||
|
@ -31,20 +33,22 @@ class CheckpointedXTransformer(nn.Module):
|
|||
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
||||
to channels-last that XTransformer expects.
|
||||
"""
|
||||
|
||||
def __init__(self, **xtransformer_kwargs):
|
||||
super().__init__()
|
||||
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
|
||||
|
||||
for i in range(len(self.transformer.attn_layers.layers)):
|
||||
n, b, r = self.transformer.attn_layers.layers[i]
|
||||
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
||||
self.transformer.attn_layers.layers[i] = nn.ModuleList(
|
||||
[n, CheckpointedLayer(b), r])
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.transformer(x, **kwargs)
|
||||
|
||||
|
||||
class Wav2VecMatcher(nn.Module):
|
||||
W2V_COMPRESSION=320
|
||||
W2V_COMPRESSION = 320
|
||||
|
||||
def __init__(self,
|
||||
model_dim,
|
||||
|
@ -56,41 +60,42 @@ class Wav2VecMatcher(nn.Module):
|
|||
|
||||
WAV2VEC_CHANNELS = 1024
|
||||
self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1,
|
||||
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
|
||||
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
|
||||
# nn.Embedding
|
||||
self.text_embedding = ml.Embedding(num_text_tokens, model_dim)
|
||||
self.encoder = CheckpointedXTransformer(
|
||||
max_seq_len=-1,
|
||||
use_pos_emb=False,
|
||||
attn_layers=Encoder(
|
||||
dim=model_dim,
|
||||
depth=encoder_depth,
|
||||
heads=model_dim//64,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_emb_dim=True,
|
||||
)
|
||||
max_seq_len=-1,
|
||||
use_pos_emb=False,
|
||||
attn_layers=Encoder(
|
||||
dim=model_dim,
|
||||
depth=encoder_depth,
|
||||
heads=model_dim//64,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_emb_dim=True,
|
||||
)
|
||||
self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim))
|
||||
self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim))
|
||||
)
|
||||
self.decoder_start_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, model_dim))
|
||||
self.decoder_stop_embedding = nn.Parameter(torch.randn(1, model_dim))
|
||||
self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
|
||||
self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
|
||||
self.decoder = CheckpointedXTransformer(
|
||||
max_seq_len=-1, # Should be unused
|
||||
use_pos_emb=False,
|
||||
attn_layers=Decoder(
|
||||
dim=model_dim,
|
||||
depth=decoder_depth,
|
||||
heads=model_dim//64,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
cross_attend=True,
|
||||
)
|
||||
max_seq_len=-1, # Should be unused
|
||||
use_pos_emb=False,
|
||||
attn_layers=Decoder(
|
||||
dim=model_dim,
|
||||
depth=decoder_depth,
|
||||
heads=model_dim//64,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
cross_attend=True,
|
||||
)
|
||||
)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
|
@ -111,28 +116,32 @@ class Wav2VecMatcher(nn.Module):
|
|||
enc_inputs = torch.cat([cond_emb.unsqueeze(1), text_emb], dim=1)
|
||||
dec_context = self.encoder(enc_inputs)
|
||||
w2v_values = self.w2v_value_encoder(w2v_logits)
|
||||
dec_inputs = torch.cat([self.decoder_start_embedding.repeat(w2v_values.shape[0],1,1), w2v_values], dim=1)
|
||||
dec_inputs = torch.cat([self.decoder_start_embedding.repeat(
|
||||
w2v_values.shape[0], 1, 1), w2v_values], dim=1)
|
||||
dec_out = self.decoder(dec_inputs, context=dec_context)[:, :-1]
|
||||
w2v_queries = self.w2v_query_encoder(w2v_logits)
|
||||
|
||||
# Compute losses, A CLIP-like dot product matcher and a mechanism to force pad prediction.
|
||||
b,l,c = dec_out.shape
|
||||
b, l, c = dec_out.shape
|
||||
keys_uncompressed = dec_out.reshape(b*l, c)
|
||||
queries_uncompressed = w2v_queries.reshape(b*l, c)
|
||||
dot = torch.einsum("i c, j c -> i j", keys_uncompressed, queries_uncompressed)
|
||||
dot = torch.einsum("i c, j c -> i j",
|
||||
keys_uncompressed, queries_uncompressed)
|
||||
labels = torch.arange(0, b*l, 1, device=dot.device)
|
||||
ce_loss1 = F.cross_entropy(dot, labels, reduction="none")
|
||||
ce_loss2 = F.cross_entropy(dot.t(), labels, reduction="none")
|
||||
mse_pad_loss = F.mse_loss(keys_uncompressed, self.decoder_stop_embedding.repeat(b*l,1), reduction="none").sum(dim=-1)
|
||||
mse_pad_loss = F.mse_loss(keys_uncompressed, self.decoder_stop_embedding.repeat(
|
||||
b*l, 1), reduction="none").sum(dim=-1)
|
||||
|
||||
# Create a mask based on w2v_lengths that will be used to ensure the encodings of padding tokens are not considered in the cross entropy loss
|
||||
loss_mask = torch.ones((b,l), device=ce_loss1.device)
|
||||
loss_mask = torch.ones((b, l), device=ce_loss1.device)
|
||||
w2v_lengths = clip_lengths // self.W2V_COMPRESSION
|
||||
for i in range(b):
|
||||
loss_mask[i, w2v_lengths[i]:] = 0
|
||||
loss_mask_collapsed = loss_mask.reshape(b*l)
|
||||
|
||||
ce_loss = (ce_loss1 * loss_mask_collapsed + ce_loss2 * loss_mask_collapsed).mean()
|
||||
ce_loss = (ce_loss1 * loss_mask_collapsed +
|
||||
ce_loss2 * loss_mask_collapsed).mean()
|
||||
mse_loss = (mse_pad_loss * (loss_mask_collapsed == 0)).mean()
|
||||
|
||||
return ce_loss, mse_loss
|
||||
|
@ -151,15 +160,18 @@ class Wav2VecMatcher(nn.Module):
|
|||
dec_out = self.decoder(dec_inputs, context=dec_context)
|
||||
|
||||
# Check if that was EOS.
|
||||
l2 = F.mse_loss(dec_out[:,-1], self.decoder_stop_embedding)
|
||||
l2 = F.mse_loss(dec_out[:, -1], self.decoder_stop_embedding)
|
||||
if l2 < .1: # TODO: fix threshold.
|
||||
break
|
||||
|
||||
# Find a matching w2v logit from the given iterable.
|
||||
matching_logit_index = self.find_matching_w2v_logit(dec_out[:,-1], w2v_logit_iterable)
|
||||
matching_logit_index = self.find_matching_w2v_logit(
|
||||
dec_out[:, -1], w2v_logit_iterable)
|
||||
matching_logit = w2v_logit_iterable[matching_logit_index]
|
||||
dec_inputs = torch.cat([dec_inputs, self.w2v_value_encoder(matching_logit).unsqueeze(1)], dim=1)
|
||||
produced_audio = torch.cat([produced_audio, audio_clip_iterable[matching_logit_index]], dim=-1)
|
||||
dec_inputs = torch.cat(
|
||||
[dec_inputs, self.w2v_value_encoder(matching_logit).unsqueeze(1)], dim=1)
|
||||
produced_audio = torch.cat(
|
||||
[produced_audio, audio_clip_iterable[matching_logit_index]], dim=-1)
|
||||
return produced_audio
|
||||
|
||||
|
||||
|
@ -170,9 +182,9 @@ def register_w2v_matcher(opt_net, opt):
|
|||
|
||||
if __name__ == '__main__':
|
||||
model = Wav2VecMatcher(512, 8, 8)
|
||||
toks = torch.randint(0, 100, (4,100))
|
||||
tok_lens = torch.tensor([50,60,70,80])
|
||||
cond = torch.randn(4,1,44000)
|
||||
logits = torch.randn(4,120,1024)
|
||||
logit_lens = torch.tensor([60,70,80,90])
|
||||
model(toks, cond, logits, tok_lens, logit_lens)
|
||||
toks = torch.randint(0, 100, (4, 100))
|
||||
tok_lens = torch.tensor([50, 60, 70, 80])
|
||||
cond = torch.randn(4, 1, 44000)
|
||||
logits = torch.randn(4, 120, 1024)
|
||||
logit_lens = torch.tensor([60, 70, 80, 90])
|
||||
model(toks, cond, logits, tok_lens, logit_lens)
|
||||
|
|
|
@ -2,8 +2,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from models.audio.vocoders.univnet.lvcnet import LVCBlock
|
||||
from trainer.networks import register_model
|
||||
from dlas.models.audio.vocoders.univnet.lvcnet import LVCBlock
|
||||
from dlas.trainer.networks import register_model
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
@ -11,7 +11,7 @@ MAX_WAV_VALUE = 32768.0
|
|||
class UnivNetGenerator(nn.Module):
|
||||
"""UnivNet Generator"""
|
||||
|
||||
def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
|
||||
def __init__(self, noise_dim=64, channel_size=32, dilations=[1, 3, 9, 27], strides=[8, 8, 4], lReLU_slope=.2, kpnet_conv_size=3,
|
||||
# Below are MEL configurations options that this generator requires.
|
||||
hop_length=256, n_mel_channels=100):
|
||||
super(UnivNetGenerator, self).__init__()
|
||||
|
@ -38,11 +38,13 @@ class UnivNetGenerator(nn.Module):
|
|||
)
|
||||
|
||||
self.conv_pre = \
|
||||
nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
|
||||
|
||||
self.conv_post = nn.Sequential(
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
|
@ -84,11 +86,13 @@ class UnivNetGenerator(nn.Module):
|
|||
def inference(self, c, z=None):
|
||||
# pad input mel with zeros to cut artifact
|
||||
# see https://github.com/seungwonpark/melgan/issues/8
|
||||
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
|
||||
zero = torch.full(
|
||||
(c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
|
||||
mel = torch.cat((c, zero), dim=2)
|
||||
|
||||
if z is None:
|
||||
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
|
||||
z = torch.randn(c.shape[0], self.noise_dim,
|
||||
mel.size(2)).to(mel.device)
|
||||
|
||||
audio = self.forward(mel, z)
|
||||
audio = audio[:, :, :-(self.hop_length * 10)]
|
||||
|
@ -112,5 +116,6 @@ if __name__ == '__main__':
|
|||
print(y.shape)
|
||||
assert y.shape == torch.Size([3, 1, 2560])
|
||||
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
pytorch_total_params = sum(p.numel()
|
||||
for p in model.parameters() if p.requires_grad)
|
||||
print(pytorch_total_params)
|
||||
|
|
|
@ -35,12 +35,15 @@ class KernelPredictor(torch.nn.Module):
|
|||
self.conv_kernel_size = conv_kernel_size
|
||||
self.conv_layers = conv_layers
|
||||
|
||||
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
||||
kpnet_kernel_channels = conv_in_channels * \
|
||||
conv_out_channels * conv_kernel_size * conv_layers # l_w
|
||||
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
||||
|
||||
self.input_conv = nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
||||
getattr(nn, kpnet_nonlinear_activation)(
|
||||
**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
|
||||
self.residual_convs = nn.ModuleList()
|
||||
|
@ -52,11 +55,13 @@ class KernelPredictor(torch.nn.Module):
|
|||
nn.utils.weight_norm(
|
||||
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
||||
bias=True)),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
getattr(nn, kpnet_nonlinear_activation)(
|
||||
**kpnet_nonlinear_activation_params),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
||||
bias=True)),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
getattr(nn, kpnet_nonlinear_activation)(
|
||||
**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
)
|
||||
self.kernel_conv = nn.utils.weight_norm(
|
||||
|
@ -170,7 +175,8 @@ class LVCBlock(torch.nn.Module):
|
|||
for i, conv in enumerate(self.conv_blocks):
|
||||
output = conv(x) # (B, c_g, stride * L')
|
||||
|
||||
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
||||
# (B, 2 * c_g, c_g, kernel_size, cond_length)
|
||||
k = kernels[:, i, :, :, :, :]
|
||||
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
||||
|
||||
output = self.location_variable_convolution(output, k, b,
|
||||
|
@ -194,23 +200,29 @@ class LVCBlock(torch.nn.Module):
|
|||
'''
|
||||
batch, _, in_length = x.shape
|
||||
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
||||
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
||||
assert in_length == (
|
||||
kernel_length * hop_size), "length of (x, kernel) is not matched"
|
||||
|
||||
padding = dilation * int((kernel_size - 1) / 2)
|
||||
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
|
||||
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||
# (batch, in_channels, in_length + 2*padding)
|
||||
x = F.pad(x, (padding, padding), 'constant', 0)
|
||||
# (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||
x = x.unfold(2, hop_size + 2 * padding, hop_size)
|
||||
|
||||
if hop_size < dilation:
|
||||
x = F.pad(x, (0, dilation), 'constant', 0)
|
||||
x = x.unfold(3, dilation,
|
||||
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
||||
x = x[:, :, :, :, :hop_size]
|
||||
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||
# (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||
x = x.transpose(3, 4)
|
||||
# (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||
x = x.unfold(4, kernel_size, 1)
|
||||
|
||||
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
|
||||
o = o.to(memory_format=torch.channels_last_3d)
|
||||
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
||||
bias = bias.unsqueeze(-1).unsqueeze(-1).to(
|
||||
memory_format=torch.channels_last_3d)
|
||||
o = o + bias
|
||||
o = o.contiguous().view(batch, out_channels, -1)
|
||||
|
||||
|
@ -220,4 +232,4 @@ class LVCBlock(torch.nn.Module):
|
|||
self.kernel_predictor.remove_weight_norm()
|
||||
nn.utils.remove_weight_norm(self.convt_pre[1])
|
||||
for block in self.conv_blocks:
|
||||
nn.utils.remove_weight_norm(block[1])
|
||||
nn.utils.remove_weight_norm(block[1])
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import sys
|
||||
|
||||
from models.audio.tts.tacotron2.stft import STFT
|
||||
import torch
|
||||
|
||||
from dlas.models.audio.tts.tacotron2.stft import STFT
|
||||
|
||||
sys.path.append('tacotron2')
|
||||
import torch
|
||||
|
||||
|
||||
class Denoiser(torch.nn.Module):
|
||||
|
|
|
@ -25,11 +25,12 @@
|
|||
#
|
||||
# *****************************************************************************
|
||||
import copy
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import torch.nn.functional as F
|
||||
|
||||
from trainer.networks import register_model
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from dlas.trainer.networks import register_model
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
|
@ -57,7 +58,8 @@ class WaveGlowLoss(torch.nn.Module):
|
|||
log_s_total = log_s_total + torch.sum(log_s)
|
||||
log_det_W_total += log_det_W_list[i]
|
||||
|
||||
loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
|
||||
loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - \
|
||||
log_s_total - log_det_W_total
|
||||
return loss/(z.size(0)*z.size(1)*z.size(2))
|
||||
|
||||
|
||||
|
@ -67,6 +69,7 @@ class Invertible1x1Conv(torch.nn.Module):
|
|||
of its weight matrix. If reverse=True it does convolution with
|
||||
inverse
|
||||
"""
|
||||
|
||||
def __init__(self, c):
|
||||
super(Invertible1x1Conv, self).__init__()
|
||||
self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
|
||||
|
@ -77,7 +80,7 @@ class Invertible1x1Conv(torch.nn.Module):
|
|||
|
||||
# Ensure determinant is 1.0 not -1.0
|
||||
if torch.det(W) < 0:
|
||||
W[:,0] = -1*W[:,0]
|
||||
W[:, 0] = -1*W[:, 0]
|
||||
W = W.view(c, c, 1)
|
||||
self.conv.weight.data = W
|
||||
|
||||
|
@ -110,11 +113,12 @@ class WN(torch.nn.Module):
|
|||
from WaveNet is the convolutions need not be causal. There is also no dilation
|
||||
size reset. The dilation only doubles on each layer
|
||||
"""
|
||||
|
||||
def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
|
||||
kernel_size):
|
||||
super(WN, self).__init__()
|
||||
assert(kernel_size % 2 == 1)
|
||||
assert(n_channels % 2 == 0)
|
||||
assert (kernel_size % 2 == 1)
|
||||
assert (n_channels % 2 == 0)
|
||||
self.n_layers = n_layers
|
||||
self.n_channels = n_channels
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
|
@ -142,14 +146,14 @@ class WN(torch.nn.Module):
|
|||
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
|
||||
# last one is not necessary
|
||||
if i < n_layers - 1:
|
||||
res_skip_channels = 2*n_channels
|
||||
else:
|
||||
res_skip_channels = n_channels
|
||||
res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
|
||||
res_skip_layer = torch.nn.utils.weight_norm(
|
||||
res_skip_layer, name='weight')
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def forward(self, forward_input):
|
||||
|
@ -164,13 +168,13 @@ class WN(torch.nn.Module):
|
|||
spect_offset = i*2*self.n_channels
|
||||
acts = fused_add_tanh_sigmoid_multiply(
|
||||
self.in_layers[i](audio),
|
||||
spect[:,spect_offset:spect_offset+2*self.n_channels,:],
|
||||
spect[:, spect_offset:spect_offset+2*self.n_channels, :],
|
||||
n_channels_tensor)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
audio = audio + res_skip_acts[:,:self.n_channels,:]
|
||||
output = output + res_skip_acts[:,self.n_channels:,:]
|
||||
audio = audio + res_skip_acts[:, :self.n_channels, :]
|
||||
output = output + res_skip_acts[:, self.n_channels:, :]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
|
||||
|
@ -185,7 +189,7 @@ class WaveGlow(torch.nn.Module):
|
|||
self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
|
||||
n_mel_channels,
|
||||
1024, stride=256)
|
||||
assert(n_group % 2 == 0)
|
||||
assert (n_group % 2 == 0)
|
||||
self.n_flows = n_flows
|
||||
self.n_group = n_group
|
||||
self.n_early_every = n_early_every
|
||||
|
@ -215,7 +219,7 @@ class WaveGlow(torch.nn.Module):
|
|||
|
||||
# Upsample spectrogram to size of audio
|
||||
spect = self.upsample(spect)
|
||||
assert(spect.size(2) >= audio.size(1))
|
||||
assert (spect.size(2) >= audio.size(1))
|
||||
if spect.size(2) > audio.size(1):
|
||||
spect = spect[:, :, :audio.size(1)]
|
||||
|
||||
|
@ -229,15 +233,15 @@ class WaveGlow(torch.nn.Module):
|
|||
|
||||
for k in range(self.n_flows):
|
||||
if k % self.n_early_every == 0 and k > 0:
|
||||
output_audio.append(audio[:,:self.n_early_size,:])
|
||||
audio = audio[:,self.n_early_size:,:]
|
||||
output_audio.append(audio[:, :self.n_early_size, :])
|
||||
audio = audio[:, self.n_early_size:, :]
|
||||
|
||||
audio, log_det_W = self.convinv[k](audio)
|
||||
log_det_W_list.append(log_det_W)
|
||||
|
||||
n_half = int(audio.size(1)/2)
|
||||
audio_0 = audio[:,:n_half,:]
|
||||
audio_1 = audio[:,n_half:,:]
|
||||
audio_0 = audio[:, :n_half, :]
|
||||
audio_1 = audio[:, n_half:, :]
|
||||
|
||||
output = self.WN[k]((audio_0, spect))
|
||||
log_s = output[:, n_half:, :]
|
||||
|
@ -245,10 +249,10 @@ class WaveGlow(torch.nn.Module):
|
|||
audio_1 = torch.exp(log_s)*audio_1 + b
|
||||
log_s_list.append(log_s)
|
||||
|
||||
audio = torch.cat([audio_0, audio_1],1)
|
||||
audio = torch.cat([audio_0, audio_1], 1)
|
||||
|
||||
output_audio.append(audio)
|
||||
return torch.cat(output_audio,1), log_s_list, log_det_W_list
|
||||
return torch.cat(output_audio, 1), log_s_list, log_det_W_list
|
||||
|
||||
def infer(self, spect, sigma=1.0):
|
||||
spect = self.upsample(spect)
|
||||
|
@ -272,26 +276,29 @@ class WaveGlow(torch.nn.Module):
|
|||
|
||||
for k in reversed(range(self.n_flows)):
|
||||
n_half = int(audio.size(1)/2)
|
||||
audio_0 = audio[:,:n_half,:]
|
||||
audio_1 = audio[:,n_half:,:]
|
||||
audio_0 = audio[:, :n_half, :]
|
||||
audio_1 = audio[:, n_half:, :]
|
||||
|
||||
output = self.WN[k]((audio_0, spect))
|
||||
|
||||
s = output[:, n_half:, :]
|
||||
b = output[:, :n_half, :]
|
||||
audio_1 = (audio_1 - b)/torch.exp(s)
|
||||
audio = torch.cat([audio_0, audio_1],1)
|
||||
audio = torch.cat([audio_0, audio_1], 1)
|
||||
|
||||
audio = self.convinv[k](audio, reverse=True)
|
||||
|
||||
if k % self.n_early_every == 0 and k > 0:
|
||||
if spect.type() == 'torch.cuda.HalfTensor':
|
||||
z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
|
||||
z = torch.cuda.HalfTensor(spect.size(
|
||||
0), self.n_early_size, spect.size(2)).normal_()
|
||||
else:
|
||||
z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
|
||||
audio = torch.cat((sigma*z, audio),1)
|
||||
z = torch.cuda.FloatTensor(spect.size(
|
||||
0), self.n_early_size, spect.size(2)).normal_()
|
||||
audio = torch.cat((sigma*z, audio), 1)
|
||||
|
||||
audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
|
||||
audio = audio.permute(0, 2, 1).contiguous().view(
|
||||
audio.size(0), -1).data
|
||||
return audio
|
||||
|
||||
@staticmethod
|
||||
|
@ -315,4 +322,4 @@ def remove(conv_list):
|
|||
|
||||
@register_model
|
||||
def register_nv_waveglow(opt_net, opt):
|
||||
return WaveGlow(**opt_net['args'])
|
||||
return WaveGlow(**opt_net['args'])
|
||||
|
|
|
@ -8,11 +8,11 @@
|
|||
https://arxiv.org/abs/1512.03385v1
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_intermediary as ml
|
||||
|
||||
from trainer.networks import register_model
|
||||
import torch.nn as nn
|
||||
|
||||
import dlas.torch_intermediary as ml
|
||||
from dlas.trainer.networks import register_model
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
|
@ -20,53 +20,60 @@ class BasicBlock(nn.Module):
|
|||
|
||||
"""
|
||||
|
||||
#BasicBlock and BottleNeck block
|
||||
#have different output size
|
||||
#we use class attribute expansion
|
||||
#to distinct
|
||||
# BasicBlock and BottleNeck block
|
||||
# have different output size
|
||||
# we use class attribute expansion
|
||||
# to distinct
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
|
||||
#residual function
|
||||
# residual function
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3,
|
||||
stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
|
||||
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion,
|
||||
kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
#shortcut
|
||||
# shortcut
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
#the shortcut output dimension is not the same with residual function
|
||||
#use 1*1 convolution to match the dimension
|
||||
# the shortcut output dimension is not the same with residual function
|
||||
# use 1*1 convolution to match the dimension
|
||||
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
"""Residual block for resnet over 50 layers
|
||||
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super().__init__()
|
||||
self.residual_function = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
|
||||
nn.Conv2d(out_channels, out_channels, stride=stride,
|
||||
kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
|
||||
nn.Conv2d(out_channels, out_channels *
|
||||
BottleNeck.expansion, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
|
||||
)
|
||||
|
||||
|
@ -74,13 +81,15 @@ class BottleNeck(nn.Module):
|
|||
|
||||
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
|
||||
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion,
|
||||
stride=stride, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, num_block, num_classes=100):
|
||||
|
@ -92,8 +101,8 @@ class ResNet(nn.Module):
|
|||
nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True))
|
||||
#we use a different inputsize than the original paper
|
||||
#so conv2_x's stride is 1
|
||||
# we use a different inputsize than the original paper
|
||||
# so conv2_x's stride is 1
|
||||
self.conv2_x = self._make_layer(block, 32, num_block[0], 1)
|
||||
self.conv3_x = self._make_layer(block, 64, num_block[1], 2)
|
||||
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
|
||||
|
@ -138,30 +147,33 @@ class ResNet(nn.Module):
|
|||
|
||||
return output
|
||||
|
||||
|
||||
@register_model
|
||||
def register_cifar_resnet18(opt_net, opt):
|
||||
""" return a ResNet 18 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [2, 2, 2, 2])
|
||||
|
||||
|
||||
def resnet34():
|
||||
""" return a ResNet 34 object
|
||||
"""
|
||||
return ResNet(BasicBlock, [3, 4, 6, 3])
|
||||
|
||||
|
||||
def resnet50():
|
||||
""" return a ResNet 50 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 6, 3])
|
||||
|
||||
|
||||
def resnet101():
|
||||
""" return a ResNet 101 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 4, 23, 3])
|
||||
|
||||
|
||||
def resnet152():
|
||||
""" return a ResNet 152 object
|
||||
"""
|
||||
return ResNet(BottleNeck, [3, 8, 36, 3])
|
||||
|
||||
|
||||
|
|
|
@ -1,18 +1,17 @@
|
|||
# A direct copy of torchvision's resnet.py modified to support gradient checkpointing.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
import torchvision
|
||||
import torch_intermediary as ml
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
|
||||
import dlas.torch_intermediary as ml
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
from dlas.trainer.networks import register_model
|
||||
from dlas.utils.util import checkpoint
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
Some files were not shown because too many files have changed in this diff.