mrq 2023-03-21 15:39:28 +07:00
parent fe24641763
commit a4afad8837
264 changed files with 8514 additions and 5847 deletions

.gitignore (vendored): 72 changed lines

@ -1,30 +1,3 @@
dlas/experiments/*
dlas/codes/*.txt
dlas/codes/wandb/*
dlas/codes/pretrained_models/*
dlas/codes/scripts/audio/pretrained_models/*
results/*
tb_logger/*
datasets/*
options/*
data/*
.vscode
*.html
*.png
*.jpg
*.gif
*.pth
*.pytorch
*.zip
*.cu
*.pt
*.pth
*.pdf
*.tsv
# template
# Byte-compiled / optimized / DLL files
__pycache__/
@ -36,6 +9,7 @@ __pycache__/
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
@ -47,12 +21,9 @@ lib64/
parts/
sdist/
var/
wheels/
pretrained/*
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
@ -72,9 +43,8 @@ htmlcov/
.cache
nosetests.xml
coverage.xml
*.cover
*,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
@ -83,14 +53,6 @@ coverage.xml
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
@ -98,36 +60,8 @@ docs/_build/
# PyBuilder
target/
# Jupyter Notebook
#Ipython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/

@ -1 +1 @@
recursive-include codes/*
recursive-include dlas/*

@ -3,7 +3,7 @@ import torch
import torch.utils.data
from munch import munchify
from utils.util import opt_get
from dlas.utils.util import opt_get
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None, shuffle=True):
@ -33,77 +33,90 @@ def create_dataset(dataset_opt, return_collate=False):
# datasets for image restoration
if mode == 'fullimage':
from data.images.full_image_dataset import FullImageDataset as D
from dlas.data.images.full_image_dataset import FullImageDataset as D
elif mode == 'single_image_extensible':
from data.images.single_image_dataset import SingleImageDataset as D
from dlas.data.images.single_image_dataset import \
SingleImageDataset as D
elif mode == 'multi_frame_extensible':
from data.images.multi_frame_dataset import MultiFrameDataset as D
from dlas.data.images.multi_frame_dataset import MultiFrameDataset as D
elif mode == 'combined':
from data.combined_dataset import CombinedDataset as D
from dlas.data.combined_dataset import CombinedDataset as D
elif mode == 'multiscale':
from data.images.multiscale_dataset import MultiScaleDataset as D
from dlas.data.images.multiscale_dataset import MultiScaleDataset as D
elif mode == 'paired_frame':
from data.images.paired_frame_dataset import PairedFrameDataset as D
from dlas.data.images.paired_frame_dataset import \
PairedFrameDataset as D
elif mode == 'stylegan2':
from data.images.stylegan2_dataset import Stylegan2Dataset as D
from dlas.data.images.stylegan2_dataset import Stylegan2Dataset as D
elif mode == 'imagefolder':
from data.images.image_folder_dataset import ImageFolderDataset as D
from dlas.data.images.image_folder_dataset import \
ImageFolderDataset as D
elif mode == 'torch_dataset':
from data.torch_dataset import TorchDataset as D
elif mode == 'byol_dataset':
from data.images.byol_attachment import ByolDatasetWrapper as D
from dlas.data.images.byol_attachment import ByolDatasetWrapper as D
elif mode == 'byol_structured_dataset':
from data.images.byol_attachment import StructuredCropDatasetWrapper as D
from dlas.data.images.byol_attachment import \
StructuredCropDatasetWrapper as D
elif mode == 'random_aug_wrapper':
from data.images.byol_attachment import DatasetRandomAugWrapper as D
from dlas.data.images.byol_attachment import \
DatasetRandomAugWrapper as D
elif mode == 'random_dataset':
from data.images.random_dataset import RandomDataset as D
from dlas.data.images.random_dataset import RandomDataset as D
elif mode == 'zipfile':
from data.images.zip_file_dataset import ZipFileDataset as D
from dlas.data.images.zip_file_dataset import ZipFileDataset as D
elif mode == 'nv_tacotron':
from data.audio.nv_tacotron_dataset import TextWavLoader as D
from data.audio.nv_tacotron_dataset import TextMelCollate as C
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.nv_tacotron_dataset import TextMelCollate as C
from dlas.data.audio.nv_tacotron_dataset import TextWavLoader as D
from dlas.models.audio.tts.tacotron2 import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
if opt_get(dataset_opt, ['needs_collate'], True):
collate = C()
elif mode == 'paired_voice_audio':
from data.audio.paired_voice_audio_dataset import TextWavLoader as D
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.paired_voice_audio_dataset import \
TextWavLoader as D
from dlas.models.audio.tts.tacotron2 import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
elif mode == 'fast_paired_voice_audio':
from data.audio.fast_paired_dataset import FastPairedVoiceDataset as D
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.fast_paired_dataset import \
FastPairedVoiceDataset as D
from dlas.models.audio.tts.tacotron2 import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
elif mode == 'fast_paired_voice_audio_with_phonemes':
from data.audio.fast_paired_dataset_with_phonemes import FastPairedVoiceDataset as D
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.fast_paired_dataset_with_phonemes import \
FastPairedVoiceDataset as D
from dlas.models.audio.tts.tacotron2 import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
elif mode == 'gpt_tts':
from data.audio.gpt_tts_dataset import GptTtsDataset as D
from data.audio.gpt_tts_dataset import GptTtsCollater as C
from dlas.data.audio.gpt_tts_dataset import GptTtsCollater as C
from dlas.data.audio.gpt_tts_dataset import GptTtsDataset as D
collate = C(dataset_opt)
elif mode == 'unsupervised_audio':
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset as D
from dlas.data.audio.unsupervised_audio_dataset import \
UnsupervisedAudioDataset as D
elif mode == 'unsupervised_audio_with_noise':
from data.audio.audio_with_noise_dataset import AudioWithNoiseDataset as D
from dlas.data.audio.audio_with_noise_dataset import \
AudioWithNoiseDataset as D
elif mode == 'preprocessed_mel':
from data.audio.preprocessed_mel_dataset import PreprocessedMelDataset as D
from dlas.data.audio.preprocessed_mel_dataset import \
PreprocessedMelDataset as D
elif mode == 'grand_conjoined_voice':
from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D
from data.zero_pad_dict_collate import ZeroPadDictCollate as C
from dlas.data.audio.grand_conjoined_dataset import \
GrandConjoinedDataset as D
from dlas.data.zero_pad_dict_collate import ZeroPadDictCollate as C
if opt_get(dataset_opt, ['needs_collate'], False):
collate = C()
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
raise NotImplementedError(
'Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)
if return_collate:
@ -115,9 +128,10 @@ def create_dataset(dataset_opt, return_collate=False):
def get_dataset_debugger(dataset_opt):
mode = dataset_opt['mode']
if mode == 'paired_voice_audio':
from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger
from dlas.data.audio.paired_voice_audio_dataset import \
PairedVoiceDebugger
return PairedVoiceDebugger()
elif mode == 'fast_paired_voice_audio':
from data.audio.fast_paired_dataset import FastPairedVoiceDebugger
from dlas.data.audio.fast_paired_dataset import FastPairedVoiceDebugger
return FastPairedVoiceDebugger()
return None
return None
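
For orientation, create_dataset above dispatches on the 'mode' key of the dataset options and lazily imports the matching Dataset class; create_dataloader then wraps the result. A minimal, hypothetical invocation sketch (option keys other than 'mode' vary per dataset and are only indicated here), mirroring the __main__ blocks in the files below:

from data import create_dataloader, create_dataset

opt = {
    'mode': 'torch_dataset',  # any key handled by the if/elif chain above
    'phase': 'train',         # the remaining keys are consumed by the selected class
    'n_workers': 0,
    'batch_size': 4,
    # ... dataset-specific options would go here ...
}
ds = create_dataset(opt)                        # returns D(opt) for the chosen mode
dl = create_dataloader(ds, opt, shuffle=True)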

@ -2,18 +2,17 @@ import random
import sys
from math import pi
import librosa
import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm
import torch.nn.functional as F
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset, load_audio
from data.util import load_paths_from_cache, find_files_of_type, is_audio_file
# Just all ones.
from utils.util import opt_get
from dlas.data.audio.unsupervised_audio_dataset import (
UnsupervisedAudioDataset, load_audio)
from dlas.data.util import (find_files_of_type, is_audio_file,
load_paths_from_cache)
from dlas.utils.util import opt_get
def _integration_fn_fully_enabled(n):
@ -23,7 +22,7 @@ def _integration_fn_fully_enabled(n):
# Randomly assigns up to 5 blocks of the output tensor the value '1'. Rest is zero
def _integration_fn_spiky(n):
fn = torch.zeros((n,))
spikes = random.randint(1,5)
spikes = random.randint(1, 5)
for _ in range(spikes):
sz = random.randint(n//8, n//2)
pos = random.randint(0, n)
@ -35,18 +34,19 @@ def _integration_fn_spiky(n):
# Uses a sinusoidal ramp up and down (of random length) to a peak which is held for a random duration.
def _integration_fn_smooth(n):
center = random.randint(1, n-2)
max_duration=n-center-1
max_duration = n-center-1
duration = random.randint(max_duration//4, max_duration)
end = center+duration
ramp_up_sz = random.randint(n//16,n//4)
ramp_up = torch.sin(pi*torch.arange(0,ramp_up_sz)/(2*ramp_up_sz))
ramp_up_sz = random.randint(n//16, n//4)
ramp_up = torch.sin(pi*torch.arange(0, ramp_up_sz)/(2*ramp_up_sz))
if ramp_up_sz > center:
ramp_up = ramp_up[(ramp_up_sz-center):]
ramp_up_sz = center
ramp_down_sz = random.randint(n//16,n//4)
ramp_down = torch.flip(torch.sin(pi*torch.arange(0,ramp_down_sz)/(2*ramp_down_sz)), dims=[0])
ramp_down_sz = random.randint(n//16, n//4)
ramp_down = torch.flip(
torch.sin(pi*torch.arange(0, ramp_down_sz)/(2*ramp_down_sz)), dims=[0])
if ramp_down_sz > (n-end):
ramp_down = ramp_down[:(n-end)]
ramp_down_sz = n-end
@ -71,16 +71,22 @@ def load_rir(path, sr, max_sz):
Wraps an unsupervised_audio_dataset and applies noise to the output clips, then provides labels depending on what
noise was added.
'''
class AudioWithNoiseDataset(Dataset):
def __init__(self, opt):
self.underlying_dataset = UnsupervisedAudioDataset(opt)
self.env_noise_paths = load_paths_from_cache(opt['env_noise_paths'], opt['env_noise_cache'])
self.music_paths = load_paths_from_cache(opt['music_paths'], opt['music_cache'])
self.openair_paths = find_files_of_type('img', opt['openair_path'], qualifier=is_audio_file)[0]
self.env_noise_paths = load_paths_from_cache(
opt['env_noise_paths'], opt['env_noise_cache'])
self.music_paths = load_paths_from_cache(
opt['music_paths'], opt['music_cache'])
self.openair_paths = find_files_of_type(
'img', opt['openair_path'], qualifier=is_audio_file)[0]
self.min_volume = opt_get(opt, ['min_noise_volume'], .2)
self.max_volume = opt_get(opt, ['max_noise_volume'], .5)
self.sampling_rate = self.underlying_dataset.sampling_rate
self.use_gpu_for_reverb_compute = opt_get(opt, ['use_gpu_for_reverb_compute'], True)
self.use_gpu_for_reverb_compute = opt_get(
opt, ['use_gpu_for_reverb_compute'], True)
self.openair_kernels = None
self.current_item_fetch = 0
self.fetch_error_count = 0
@ -90,7 +96,8 @@ class AudioWithNoiseDataset(Dataset):
# Load the openair reverbs as CUDA tensors.
self.openair_kernels = []
for oa in self.openair_paths:
self.openair_kernels.append(load_rir(oa, self.underlying_dataset.sampling_rate, self.underlying_dataset.sampling_rate*2).cuda())
self.openair_kernels.append(load_rir(
oa, self.underlying_dataset.sampling_rate, self.underlying_dataset.sampling_rate*2).cuda())
def __getitem__(self, item):
if self.current_item_fetch != item:
@ -113,10 +120,11 @@ class AudioWithNoiseDataset(Dataset):
clip = clip * clipvol
label = random.randint(0, 4) # Currently excludes GSM corruption.
#label = 3
# label = 3
if label > 0 and label < 4: # 0 is basically "leave it alone"
aug_needed = True
augvol = (random.random() * (self.max_volume-self.min_volume) + self.min_volume)
augvol = (random.random() * (self.max_volume -
self.min_volume) + self.min_volume)
if label == 1:
# Add environmental noise.
augpath = random.choice(self.env_noise_paths)
@ -131,13 +139,15 @@ class AudioWithNoiseDataset(Dataset):
# This can take two forms:
if padding_room < 22000 or random.random() < .5:
# (1) The voices talk over one another. If there is no padding room, we always take this choice.
intg_fns = [_integration_fn_smooth, _integration_fn_fully_enabled]
intg_fns = [_integration_fn_smooth,
_integration_fn_fully_enabled]
else:
# (2) There are simply two voices in the clip, separated from one another.
# This is a special case that does not use the same logic as the rest of the augmentations.
aug = load_audio(augpath, self.underlying_dataset.sampling_rate)
aug = load_audio(
augpath, self.underlying_dataset.sampling_rate)
# Pad with some random silence
aug = F.pad(aug, (random.randint(20,4000), 0))
aug = F.pad(aug, (random.randint(20, 4000), 0))
# Fit what we can given the padding room we have.
aug = aug[:, :padding_room]
clip = torch.cat([clip, aug], dim=1)
@ -146,7 +156,8 @@ class AudioWithNoiseDataset(Dataset):
out['clip_lengths'] = torch.tensor(clip.shape[-1])
aug_needed = False
if aug_needed:
aug = load_audio(augpath, self.underlying_dataset.sampling_rate)
aug = load_audio(
augpath, self.underlying_dataset.sampling_rate)
if aug.shape[1] > clip.shape[1]:
n, cn = aug.shape[1], clip.shape[1]
gap = n-cn
@ -157,7 +168,8 @@ class AudioWithNoiseDataset(Dataset):
if aug.shape[1] < clip.shape[1]:
gap = clip.shape[1] - aug.shape[1]
placement = random.randint(0, gap-1)
aug = torch.nn.functional.pad(aug, (placement, gap-placement))
aug = torch.nn.functional.pad(
aug, (placement, gap-placement))
clip = clip + aug
elif label == 4:
# Perform reverb (to simulate being in a large room with an omni-mic). This is performed by convolving
@ -166,19 +178,23 @@ class AudioWithNoiseDataset(Dataset):
rir = random.choice(self.openair_kernels)
else:
augpath = random.choice(self.openair_paths)
rir = load_rir(augpath, self.underlying_dataset.sampling_rate, clip.shape[-1])
rir = load_rir(
augpath, self.underlying_dataset.sampling_rate, clip.shape[-1])
clip = torch.nn.functional.pad(clip, (rir.shape[1]-1, 0))
if self.use_gpu_for_reverb_compute:
clip = clip.cuda()
clip = torch.nn.functional.conv1d(clip.unsqueeze(0), rir.unsqueeze(0)).squeeze(0).cpu()
clip = torch.nn.functional.conv1d(
clip.unsqueeze(0), rir.unsqueeze(0)).squeeze(0).cpu()
elif label == 5:
# Apply the GSM codec to simulate cellular phone audio.
clip = torchaudio.functional.apply_codec(clip, self.underlying_dataset.sampling_rate, format="gsm")
clip = torchaudio.functional.apply_codec(
clip, self.underlying_dataset.sampling_rate, format="gsm")
except:
if self.fetch_error_count > 10:
print(f"Exception encountered processing {item}, re-trying because this is often just a failed aug.")
print(
f"Exception encountered processing {item}, re-trying because this is often just a failed aug.")
print(sys.exc_info())
#raise # Uncomment to surface exceptions.
# raise # Uncomment to surface exceptions.
self.fetch_error_count += 1
return self[item]
@ -187,7 +203,7 @@ class AudioWithNoiseDataset(Dataset):
clip = F.pad(clip, (0, padding_room))
out['clip'] = clip
out['label'] = label
#out['aug'] = aug
# out['aug'] = aug
out['augpath'] = augpath
out['augvol'] = augvol
out['clipvol'] = clipvol
@ -216,14 +232,15 @@ if __name__ == '__main__':
'openair_path': 'D:\\data\\audio\\openair\\resampled',
'use_gpu_for_reverb_compute': False,
}
from data import create_dataset, create_dataloader, util
from data import create_dataloader, create_dataset, util
ds = create_dataset(params)
dl = create_dataloader(ds, params, pin_memory=False)
i = 0
for b in tqdm(dl):
for b_ in range(b['clip'].shape[0]):
#torchaudio.save(f'{i}_clip_{b_}_{b["label"][b_].item()}.wav', b['clip'][b_][:, :b['clip_lengths'][b_]], ds.sampling_rate)
#torchaudio.save(f'{i}_clip_{b_}_aug.wav', b['aug'][b_], ds.sampling_rate)
print(f'{i} aug path: {b["augpath"][b_]} aug volume: {b["augvol"][b_]} clip volume: {b["clipvol"][b_]}')
# torchaudio.save(f'{i}_clip_{b_}_{b["label"][b_].item()}.wav', b['clip'][b_][:, :b['clip_lengths'][b_]], ds.sampling_rate)
# torchaudio.save(f'{i}_clip_{b_}_aug.wav', b['aug'][b_], ds.sampling_rate)
print(
f'{i} aug path: {b["augpath"][b_]} aug volume: {b["augvol"][b_]} clip volume: {b["clipvol"][b_]}')
i += 1
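
As a side note, the label == 4 branch above applies reverb by convolving the clip with a room impulse response, left-padding the clip by kernel_size - 1 so the output keeps its original length. A standalone sketch of just that step (the toy tensors stand in for load_audio/load_rir outputs):

import torch
import torch.nn.functional as F

def apply_reverb(clip: torch.Tensor, rir: torch.Tensor) -> torch.Tensor:
    # clip: (1, T) mono audio; rir: (1, K) impulse response
    clip = F.pad(clip, (rir.shape[1] - 1, 0))   # left-pad so the output length stays T
    # conv1d expects (batch, channels, time); the RIR acts as the kernel
    return F.conv1d(clip.unsqueeze(0), rir.unsqueeze(0)).squeeze(0)

clip = torch.randn(1, 22050)                    # one second at 22.05 kHz
rir = torch.rand(1, 2048) * 0.01                # toy impulse response
wet = apply_reverb(clip, rir)
assert wet.shape == clip.shape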

@ -12,13 +12,15 @@ import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2CTCTokenizer
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from utils.util import opt_get
from dlas.data.audio.paired_voice_audio_dataset import CharacterTokenizer
from dlas.data.audio.unsupervised_audio_dataset import (load_audio,
load_similar_clips)
from dlas.utils.util import opt_get
def parse_tsv_aligned_codes(line, base_path):
fpt = line.strip().split('\t')
def convert_string_list_to_tensor(strlist):
if strlist.startswith('['):
strlist = strlist[1:]
@ -43,6 +45,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
The upshot is that this dataset loads extremely quickly and consumes almost no system memory.
"""
def __init__(self, hparams):
self.paths = hparams['path']
if not isinstance(self.paths, list):
@ -52,26 +55,33 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.types = opt_get(hparams, ['types'], [0 for _ in self.paths])
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
self.conditioning_candidates = opt_get(
hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(
hparams, ['conditioning_length'], 44100)
self.produce_ctc_metadata = opt_get(
hparams, ['produce_ctc_metadata'], False)
self.debug_failures = opt_get(
hparams, ['debug_loading_failures'], False)
self.text_cleaners = hparams.text_cleaners
self.sample_rate = hparams.sample_rate
self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050
self.max_wav_len = opt_get(hparams, ['max_wav_length'], None)
self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
self.load_aligned_codes = opt_get(
hparams, ['load_aligned_codes'], False)
if self.max_wav_len is not None:
self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio
self.max_text_len = opt_get(hparams, ['max_text_length'], None)
assert self.max_wav_len is not None and self.max_text_len is not None
self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], False)
if self.use_bpe_tokenizer:
from data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(
hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else:
self.tokenizer = CharacterTokenizer()
self.skipped_items = 0 # records how many items are skipped when accessing an index.
# records how many items are skipped when accessing an index.
self.skipped_items = 0
self.load_times = torch.zeros((256,))
self.load_ind = 0
@ -110,7 +120,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
try: # This can fail when seeking to a UTF-8 escape byte.
f.readline()
except:
return self.load_random_line(depth=depth + 1), type # On failure, just recurse and try again.
# On failure, just recurse and try again.
return self.load_random_line(depth=depth + 1), type
l2 = f.readline()
if l2:
@ -119,14 +130,16 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return parse_tsv_aligned_codes(l2, base_path), type
except:
print(f"error parsing random offset: {sys.exc_info()}")
return self.load_random_line(depth=depth+1), type # On failure, just recurse and try again.
# On failure, just recurse and try again.
return self.load_random_line(depth=depth+1), type
def get_ctc_metadata(self, codes):
grouped = groupby(codes.tolist())
rcodes, repeats, seps = [], [], [0]
for val, group in grouped:
if val == 0:
seps[-1] = len(list(group)) # This is a very important distinction! It means the padding belongs to the character preceding it.
# This is a very important distinction! It means the padding belongs to the character preceding it.
seps[-1] = len(list(group))
else:
rcodes.append(val)
repeats.append(len(list(group)))
@ -142,7 +155,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if rcodes.shape[0] < self.max_text_len:
gap = self.max_text_len - rcodes.shape[0]
rcodes = F.pad(rcodes, (0, gap))
repeats = F.pad(repeats, (0, gap), value=1) # The minimum value for repeats is 1, hence this is the pad value too.
# The minimum value for repeats is 1, hence this is the pad value too.
repeats = F.pad(repeats, (0, gap), value=1)
seps = F.pad(seps, (0, gap))
elif rcodes.shape[0] > self.max_text_len:
rcodes = rcodes[:self.max_text_len]
@ -165,7 +179,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if text is None or len(text.strip()) == 0:
raise ValueError
cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
n=self.conditioning_candidates) if self.load_conditioning else (None, False)
n=self.conditioning_candidates) if self.load_conditioning else (None, False)
except:
if self.skipped_items > 100:
raise # Rethrow if we have nested too far.
@ -179,12 +193,13 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.skipped_items = 0
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result.
if self.debug_failures:
print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0,len(self)-1)
print(
f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0, len(self)-1)
return self[rv]
orig_output = wav.shape[-1]
orig_text_len = tseq.shape[0]
@ -192,7 +207,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if wav.shape[-1] != self.max_wav_len:
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
# These codes are aligned to audio inputs, so make sure to pad them as well.
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
aligned_codes = F.pad(
aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
if tseq.shape[0] != self.max_text_len:
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
@ -223,7 +239,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return res
def __len__(self):
return self.total_size_bytes // 1000 # 1000 cuts down a TSV file to the actual length pretty well.
# 1000 cuts down a TSV file to the actual length pretty well.
return self.total_size_bytes // 1000
class FastPairedVoiceDebugger:
@ -243,7 +260,8 @@ class FastPairedVoiceDebugger:
if isinstance(state, dict):
self.total_items = opt_get(state, ['total_items'], 0)
self.loaded_items = opt_get(state, ['loaded_items'], 0)
self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
self.self_conditioning_items = opt_get(
state, ['self_conditioning_items'], 0)
def update(self, batch):
self.total_items += batch['wav'].shape[0]
@ -252,7 +270,8 @@ class FastPairedVoiceDebugger:
for filename in batch['filenames']:
self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
if 'conditioning' in batch.keys():
self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
self.self_conditioning_items += batch['conditioning_contains_self'].sum(
).item()
def get_debugging_map(self):
return {
@ -269,13 +288,13 @@ if __name__ == '__main__':
params = {
'mode': 'fast_paired_voice_audio',
'path': ['y:/libritts/train-other-500/transcribed-oco.tsv',
'y:/libritts/train-clean-100/transcribed-oco.tsv',
'y:/libritts/train-clean-360/transcribed-oco.tsv',
'y:/clips/books1/transcribed-oco.tsv',
'y:/clips/books2/transcribed-oco.tsv',
'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
'y:/clips/podcasts-1/transcribed-oco.tsv',],
'types': [0,1,1,1,2,2,0],
'y:/libritts/train-clean-100/transcribed-oco.tsv',
'y:/libritts/train-clean-360/transcribed-oco.tsv',
'y:/clips/books1/transcribed-oco.tsv',
'y:/clips/books2/transcribed-oco.tsv',
'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv',
'y:/clips/podcasts-1/transcribed-oco.tsv',],
'types': [0, 1, 1, 1, 2, 2, 0],
'phase': 'train',
'n_workers': 0,
'batch_size': batch_sz,
@ -289,11 +308,12 @@ if __name__ == '__main__':
'load_aligned_codes': True,
'produce_ctc_metadata': True,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav',
b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
@ -304,8 +324,8 @@ if __name__ == '__main__':
max_pads, max_repeats = 0, 0
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
#max_pads = max(max_pads, b['ctc_pads'].max())
#max_repeats = max(max_repeats, b['ctc_repeats'].max())
# max_pads = max(max_pads, b['ctc_pads'].max())
# max_repeats = max(max_repeats, b['ctc_repeats'].max())
print(f'{i} {ib} {b["real_text"][ib]}')
save(b, i, ib, 'wav')
save(b, i, ib, 'conditioning', 0)
@ -314,4 +334,3 @@ if __name__ == '__main__':
if i > 15:
break
print(max_pads, max_repeats)
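
The speed and memory claims in the FastPairedVoiceDataset docstring above come from never reading the TSVs into memory: each fetch seeks to a random byte offset, discards the partial line it lands in, and parses the next complete line, retrying on UTF-8 mid-character landings or end of file. A rough standalone sketch of that idea (simplified; not the dataset's exact logic):

import os
import random

def sample_random_line(path: str, depth: int = 0) -> str:
    assert depth < 10, 'too many failed attempts to sample a line'
    size = os.path.getsize(path)
    with open(path, 'r', encoding='utf-8') as f:
        f.seek(random.randint(0, max(size - 1, 0)))
        try:
            f.readline()               # discard the (likely partial) line we landed in
        except UnicodeDecodeError:     # seeked into the middle of a multi-byte character
            return sample_random_line(path, depth + 1)
        line = f.readline()
    if not line:                       # landed too close to the end of the file
        return sample_random_line(path, depth + 1)
    return line.strip()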

@ -12,13 +12,15 @@ import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2Processor
from data.audio.paired_voice_audio_dataset import CharacterTokenizer
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from utils.util import opt_get
from dlas.data.audio.paired_voice_audio_dataset import CharacterTokenizer
from dlas.data.audio.unsupervised_audio_dataset import (load_audio,
load_similar_clips)
from dlas.utils.util import opt_get
def parse_tsv_aligned_codes(line, base_path):
fpt = line.strip().split('\t')
def convert_string_list_to_tensor(strlist):
if strlist.startswith('['):
strlist = strlist[1:]
@ -43,10 +45,12 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
The upshot is that this dataset loads extremely quickly and consumes almost no system memory.
"""
def __init__(self, hparams):
self.paths = hparams['path']
phoneme_paths = hparams['phoneme_paths']
self.paths = [(p, False) for p in self.paths] + [(p, True) for p in phoneme_paths]
self.paths = [(p, False) for p in self.paths] + [(p, True)
for p in phoneme_paths]
self.paths_size_bytes = [os.path.getsize(p) for p, _ in self.paths]
self.total_size_bytes = sum(self.paths_size_bytes)
@ -54,28 +58,36 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.normal_text_end_token = hparams['normal_text_end_token']
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False)
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
self.conditioning_candidates = opt_get(
hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(
hparams, ['conditioning_length'], 44100)
self.produce_ctc_metadata = opt_get(
hparams, ['produce_ctc_metadata'], False)
self.debug_failures = opt_get(
hparams, ['debug_loading_failures'], False)
self.text_cleaners = hparams.text_cleaners
self.sample_rate = hparams.sample_rate
self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050
self.max_wav_len = opt_get(hparams, ['max_wav_length'], None)
self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
self.load_aligned_codes = opt_get(
hparams, ['load_aligned_codes'], False)
if self.max_wav_len is not None:
self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio
self.max_text_len = opt_get(hparams, ['max_text_length'], None)
assert self.max_wav_len is not None and self.max_text_len is not None
self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], False)
if self.use_bpe_tokenizer:
from data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(
hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else:
self.tokenizer = CharacterTokenizer()
self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer
self.ipa_phoneme_tokenizer.do_phonemize = False
self.skipped_items = 0 # records how many items are skipped when accessing an index.
# records how many items are skipped when accessing an index.
self.skipped_items = 0
self.load_times = torch.zeros((256,))
self.load_ind = 0
@ -117,7 +129,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
try: # This can fail when seeking to a UTF-8 escape byte.
f.readline()
except:
return self.load_random_line(depth=depth + 1) # On failure, just recurse and try again.
# On failure, just recurse and try again.
return self.load_random_line(depth=depth + 1)
l2 = f.readline()
if l2:
@ -126,14 +139,16 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return parse_tsv_aligned_codes(l2, base_path), type, is_phonetic
except:
print(f"error parsing random offset: {sys.exc_info()}")
return self.load_random_line(depth=depth+1) # On failure, just recurse and try again.
# On failure, just recurse and try again.
return self.load_random_line(depth=depth+1)
def get_ctc_metadata(self, codes):
grouped = groupby(codes.tolist())
rcodes, repeats, seps = [], [], [0]
for val, group in grouped:
if val == 0:
seps[-1] = len(list(group)) # This is a very important distinction! It means the padding belongs to the character preceding it.
# This is a very important distinction! It means the padding belongs to the character preceding it.
seps[-1] = len(list(group))
else:
rcodes.append(val)
repeats.append(len(list(group)))
@ -149,7 +164,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if rcodes.shape[0] < self.max_text_len:
gap = self.max_text_len - rcodes.shape[0]
rcodes = F.pad(rcodes, (0, gap))
repeats = F.pad(repeats, (0, gap), value=1) # The minimum value for repeats is 1, hence this is the pad value too.
# The minimum value for repeats is 1, hence this is the pad value too.
repeats = F.pad(repeats, (0, gap), value=1)
seps = F.pad(seps, (0, gap))
elif rcodes.shape[0] > self.max_text_len:
rcodes = rcodes[:self.max_text_len]
@ -171,7 +187,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if text is None or len(text.strip()) == 0:
raise ValueError
cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate,
n=self.conditioning_candidates) if self.load_conditioning else (None, False)
n=self.conditioning_candidates) if self.load_conditioning else (None, False)
except:
if self.skipped_items > 100:
raise # Rethrow if we have nested too far.
@ -185,12 +201,13 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.skipped_items = 0
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result.
if self.debug_failures:
print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0,len(self)-1)
print(
f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0, len(self)-1)
return self[rv]
# Shift phonetic token and aligned_code tokens over.
@ -206,7 +223,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
if wav.shape[-1] != self.max_wav_len:
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
# These codes are aligned to audio inputs, so make sure to pad them as well.
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
aligned_codes = F.pad(
aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
if tseq.shape[0] != self.max_text_len:
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
@ -237,7 +255,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return res
def __len__(self):
return self.total_size_bytes // 1000 # 1000 cuts down a TSV file to the actual length pretty well.
# 1000 cuts down a TSV file to the actual length pretty well.
return self.total_size_bytes // 1000
class FastPairedVoiceDebugger:
@ -257,7 +276,8 @@ class FastPairedVoiceDebugger:
if isinstance(state, dict):
self.total_items = opt_get(state, ['total_items'], 0)
self.loaded_items = opt_get(state, ['loaded_items'], 0)
self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
self.self_conditioning_items = opt_get(
state, ['self_conditioning_items'], 0)
def update(self, batch):
self.total_items += batch['wav'].shape[0]
@ -266,7 +286,8 @@ class FastPairedVoiceDebugger:
for filename in batch['filenames']:
self.unique_files.add(hashlib.sha256(filename.encode('utf-8')))
if 'conditioning' in batch.keys():
self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
self.self_conditioning_items += batch['conditioning_contains_self'].sum(
).item()
def get_debugging_map(self):
return {
@ -284,7 +305,7 @@ if __name__ == '__main__':
'mode': 'fast_paired_voice_audio_with_phonemes',
'path': ['y:/libritts/train-clean-100/transcribed-oco.tsv',],
'phoneme_paths': ['y:/libritts/train-other-500/transcribed-phoneme-oco.tsv'],
'types': [0,0],
'types': [0, 0],
'normal_text_end_token': 256,
'phase': 'train',
'n_workers': 0,
@ -299,11 +320,12 @@ if __name__ == '__main__':
'load_aligned_codes': False,
'debug_loading_failures': True,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav',
b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
@ -314,14 +336,13 @@ if __name__ == '__main__':
max_pads, max_repeats = 0, 0
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
#max_pads = max(max_pads, b['ctc_pads'].max())
#max_repeats = max(max_repeats, b['ctc_repeats'].max())
# max_pads = max(max_pads, b['ctc_pads'].max())
# max_repeats = max(max_repeats, b['ctc_repeats'].max())
print(f'{i} {ib} {b["real_text"][ib]}')
#save(b, i, ib, 'wav')
#save(b, i, ib, 'conditioning', 0)
#save(b, i, ib, 'conditioning', 1)
# save(b, i, ib, 'wav')
# save(b, i, ib, 'conditioning', 0)
# save(b, i, ib, 'conditioning', 1)
pass
if i > 15:
break
print(max_pads, max_repeats)

@ -6,9 +6,8 @@ import torch.utils.data
from torch import LongTensor
from tqdm import tqdm
from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2 import symbols
from models.audio.tts.tacotron2 import text_to_sequence
from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text, symbols,
text_to_sequence)
class GptTtsDataset(torch.utils.data.Dataset):
@ -21,7 +20,7 @@ class GptTtsDataset(torch.utils.data.Dataset):
def __init__(self, opt):
self.path = os.path.dirname(opt['path'])
self.audiopaths_and_text = load_filepaths_and_text(opt['path'])
self.text_cleaners=['english_cleaners']
self.text_cleaners = ['english_cleaners']
self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
self.MEL_START_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-3])
@ -32,7 +31,8 @@ class GptTtsDataset(torch.utils.data.Dataset):
audiopath_and_text = self.audiopaths_and_text[index]
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
text = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
text = torch.cat([self.TEXT_START_TOKEN, text, self.TEXT_STOP_TOKEN], dim=0)
text = torch.cat([self.TEXT_START_TOKEN, text,
self.TEXT_STOP_TOKEN], dim=0)
# Fetch quantized MELs
quant_path = audiopath.replace('wavs/', 'quantized_mels/') + '.pth'
@ -57,8 +57,9 @@ class GptTtsCollater():
def __call__(self, batch):
text_lens = [len(x[0]) for x in batch]
#max_text_len = max(text_lens)
max_text_len = self.MAX_SYMBOLS_PER_PHRASE # This forces all outputs to have the full 200 characters. Testing if this makes a difference.
# max_text_len = max(text_lens)
# This forces all outputs to have the full 200 characters. Testing if this makes a difference.
max_text_len = self.MAX_SYMBOLS_PER_PHRASE
mel_lens = [len(x[1]) for x in batch]
max_mel_len = max(mel_lens)
texts = []
@ -70,7 +71,8 @@ class GptTtsCollater():
text = F.pad(text, (0, max_text_len-len(text)), value=0)
text = torch.where(text == 0, text_range_embedding, text)
texts.append(text)
qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN))
qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)),
value=self.MEL_PAD_TOKEN))
filenames = [j[2] for j in batch]
@ -96,7 +98,7 @@ if __name__ == '__main__':
'batch_size': 16,
'mel_vocab_size': 512,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
ds, c = create_dataset(params, return_collate=True)
dl = create_dataloader(ds, params, collate_fn=c)
@ -107,5 +109,5 @@ if __name__ == '__main__':
for b in tqdm(dl):
max_mel = max(max_mel, b['padded_qmel'].shape[2])
max_text = max(max_text, b['padded_text'].shape[1])
m=torch.stack(m)
m = torch.stack(m)
print(m.mean(), m.std())

@ -7,14 +7,15 @@ import torchaudio
from munch import munchify
from tqdm import tqdm
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset
from data.text.hf_datasets_wrapper import HfDataset
from utils.util import opt_get
from dlas.data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset
from dlas.data.text.hf_datasets_wrapper import HfDataset
from dlas.utils.util import opt_get
def build_paired_voice_dataset(args):
from data.audio.paired_voice_audio_dataset import TextWavLoader as D
from models.audio.tts.tacotron2 import create_hparams
from dlas.data.audio.paired_voice_audio_dataset import TextWavLoader as D
default_params = create_hparams()
default_params.update(args)
dataset_opt = munchify(default_params)
@ -33,6 +34,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
Performs tokenization at this level, ignoring any tokenization performed by upstream datasets.
"""
def __init__(self, opt):
sample_rate = 22050 # Fixed.
paired_dataset_args = opt['paired_dataset_args']
@ -47,7 +49,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
self.max_solo_text_length = opt['max_solo_text_length']
self.collate = opt_get(opt, ['needs_collate'], False)
self.sample_rate = sample_rate
self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0)
self.num_conditioning_candidates = opt_get(
opt, ['num_conditioning_candidates'], 0)
self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000)
load_conditioning = self.num_conditioning_candidates > 0
@ -75,7 +78,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
def fetch_text_at(self, i):
try:
txt = self.text[i % len(self.text)]['text']
assert '*' not in txt # This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There really isn't a linguistic use for this character anyways.
# This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There really isn't a linguistic use for this character anyways.
assert '*' not in txt
tok = self.speech_and_text.get_text(txt)
padding_required = self.max_solo_text_length - tok.shape[0]
if padding_required < 0:
@ -137,7 +141,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
sp = self.speech[i % len(self.speech)]
# Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise.
sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
sp['clip_lengths'] = sp['clip_lengths'].clamp(0, self.max_solo_audio_length)
sp['clip_lengths'] = sp['clip_lengths'].clamp(
0, self.max_solo_audio_length)
return self.optionally_add_conditioning_candidates({
'paired_audio': snt['wav'],
'paired_audio_lengths': snt['wav_lengths'],
@ -205,7 +210,7 @@ if __name__ == '__main__':
'use_bpe_tokenizer': False,
},
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
os.remove('test_cache_delete_me2.pth')
ds, c = create_dataset(train_params, return_collate=True)
@ -213,7 +218,8 @@ if __name__ == '__main__':
def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav',
b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
@ -224,16 +230,17 @@ if __name__ == '__main__':
m = None
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
#save(b, i, ib, 'paired_audio')
#save(b, i, ib, 'paired_audio_conditioning', 0)
#save(b, i, ib, 'paired_audio_conditioning', 1)
print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
#save(b, i, ib, 'speech_audio')
#save(b, i, ib, 'speech_audio_conditioning', 0)
#save(b, i, ib, 'speech_audio_conditioning', 1)
#print(f'Text: {b["text_text"][ib]}')
#print(f'Text decoded: {decode(b, ib, "text_tokens")}')
# save(b, i, ib, 'paired_audio')
# save(b, i, ib, 'paired_audio_conditioning', 0)
# save(b, i, ib, 'paired_audio_conditioning', 1)
print(
f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
print(
f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
# save(b, i, ib, 'speech_audio')
# save(b, i, ib, 'speech_audio_conditioning', 0)
# save(b, i, ib, 'speech_audio_conditioning', 1)
# print(f'Text: {b["text_text"][ib]}')
# print(f'Text decoded: {decode(b, ib, "text_tokens")}')
if i > 5:
break

@ -7,32 +7,36 @@ import torch.utils.data
import torchaudio
from tqdm import tqdm
from data.audio.unsupervised_audio_dataset import load_audio
from data.util import find_files_of_type, is_audio_file
from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2 import text_to_sequence
from utils.util import opt_get
from dlas.data.audio.unsupervised_audio_dataset import load_audio
from dlas.data.util import find_files_of_type, is_audio_file
from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text,
text_to_sequence)
from dlas.utils.util import opt_get
def load_tsv(filename):
with open(filename, encoding='utf-8') as f:
components = [line.strip().split('\t') for line in f]
base = os.path.dirname(filename)
filepaths_and_text = [[os.path.join(base, f'{component[1]}'), component[0]] for component in components]
filepaths_and_text = [
[os.path.join(base, f'{component[1]}'), component[0]] for component in components]
return filepaths_and_text
def load_mozilla_cv(filename):
with open(filename, encoding='utf-8') as f:
components = [line.strip().split('\t') for line in f][1:] # First line is the header
components = [line.strip().split('\t')
for line in f][1:] # First line is the header
base = os.path.dirname(filename)
filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2]] for component in components]
filepaths_and_text = [[os.path.join(
base, f'clips/{component[1]}'), component[2]] for component in components]
return filepaths_and_text
def load_voxpopuli(filename):
with open(filename, encoding='utf-8') as f:
lines = [line.strip().split('\t') for line in f][1:] # First line is the header
lines = [line.strip().split('\t')
for line in f][1:] # First line is the header
base = os.path.dirname(filename)
filepaths_and_text = []
for line in lines:
@ -40,7 +44,8 @@ def load_voxpopuli(filename):
continue
file, raw_text, norm_text, speaker_id, split, gender = line
year = file[:4]
filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
filepaths_and_text.append(
[os.path.join(base, year, f'{file}.ogg.wav'), raw_text])
return filepaths_and_text
@ -56,8 +61,10 @@ class TextWavLoader(torch.utils.data.Dataset):
assert len(self.path) == len(fetcher_mode)
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 3)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.conditioning_candidates = opt_get(
hparams, ['num_conditioning_candidates'], 3)
self.conditioning_length = opt_get(
hparams, ['conditioning_length'], 44100)
self.audiopaths_and_text = []
for p, fm in zip(self.path, fetcher_mode):
if fm == 'lj' or fm == 'libritts':
@ -65,10 +72,12 @@ class TextWavLoader(torch.utils.data.Dataset):
elif fm == 'tsv':
fetcher_fn = load_tsv
elif fm == 'mozilla_cv':
assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv
# Conditioning inputs are incompatible with mozilla_cv
assert not self.load_conditioning
fetcher_fn = load_mozilla_cv
elif fm == 'voxpopuli':
assert not self.load_conditioning # Conditioning inputs are incompatible with voxpopuli
# Conditioning inputs are incompatible with voxpopuli
assert not self.load_conditioning
fetcher_fn = load_voxpopuli
else:
raise NotImplementedError()
@ -96,10 +105,13 @@ class TextWavLoader(torch.utils.data.Dataset):
return text_norm
def load_conditioning_candidates(self, path):
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
candidates = find_files_of_type(
'img', os.path.dirname(path), qualifier=is_audio_file)[0]
# Sanity check to ensure we aren't loading "related files" that aren't actually related.
assert len(candidates) < 50000
if len(candidates) == 0:
print(f"No conditioning candidates found for {path} (not even the clip itself??)")
print(
f"No conditioning candidates found for {path} (not even the clip itself??)")
raise NotImplementedError()
# Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates.
related_clips = []
@ -110,25 +122,28 @@ class TextWavLoader(torch.utils.data.Dataset):
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start+self.conditioning_length]
rel_clip = rel_clip[:, rand_start:rand_start +
self.conditioning_length]
related_clips.append(rel_clip)
return torch.stack(related_clips, dim=0)
def __getitem__(self, index):
try:
tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index])
cond = self.load_conditioning_candidates(self.audiopaths_and_text[index][0]) if self.load_conditioning else None
tseq, wav, text, path = self.get_wav_text_pair(
self.audiopaths_and_text[index])
cond = self.load_conditioning_candidates(
self.audiopaths_and_text[index][0]) if self.load_conditioning else None
except:
print(f"error loading {self.audiopaths_and_text[index][0]}")
return self[index+1]
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result.
#if wav is not None:
# if wav is not None:
# print(f"Exception {index} wav_len:{wav.shape[-1]} text_len:{tseq.shape[0]} fname: {path}")
rv = random.randint(0,len(self)-1)
rv = random.randint(0, len(self)-1)
return self[rv]
orig_output = wav.shape[-1]
orig_text_len = tseq.shape[0]
@ -157,6 +172,7 @@ class TextWavLoader(torch.utils.data.Dataset):
class TextMelCollate():
""" Zero-pads model inputs and targets based on number of frames per step
"""
def __call__(self, batch):
"""Collate's training batch from normalized text and wav
PARAMS
@ -226,7 +242,7 @@ if __name__ == '__main__':
'num_conditioning_candidates': 3,
'conditioning_length': 44100,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
ds, c = create_dataset(params, return_collate=True)
dl = create_dataloader(ds, params, collate_fn=c)
@ -240,4 +256,5 @@ if __name__ == '__main__':
print(f'{i} {ib} {b["real_text"][ib]}')
torchaudio.save(f'{i}_clip_{ib}.wav', b['wav'][ib], ds.sample_rate)
for c in range(3):
torchaudio.save(f'{i}_clip_{ib}_cond{c}.wav', b['conditioning'][ib, c], ds.sample_rate)
torchaudio.save(f'{i}_clip_{ib}_cond{c}.wav',
b['conditioning'][ib, c], ds.sample_rate)
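
The conditioning-candidate handling in load_conditioning_candidates above reduces to a fixed-length crop-or-pad of each related clip; a small standalone sketch of that step (function name and shapes are illustrative):

import random
import torch
import torch.nn.functional as F

def crop_or_pad(clip: torch.Tensor, target_len: int) -> torch.Tensor:
    # clip: (channels, samples)
    gap = clip.shape[-1] - target_len
    if gap < 0:
        return F.pad(clip, (0, -gap))               # too short: right-pad with zeros
    if gap > 0:
        start = random.randint(0, gap)              # too long: take a random window
        return clip[:, start:start + target_len]
    return clip

x = torch.randn(1, 30000)
assert crop_or_pad(x, 44100).shape[-1] == 44100
assert crop_or_pad(x, 22050).shape[-1] == 22050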

@ -8,10 +8,13 @@ import torch.utils.data
import torchaudio
from tqdm import tqdm
from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from models.audio.tts.tacotron2 import load_filepaths_and_text, load_filepaths_and_text_type
from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text
from utils.util import opt_get
from dlas.data.audio.unsupervised_audio_dataset import (load_audio,
load_similar_clips)
from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text,
load_filepaths_and_text_type,
sequence_to_text,
text_to_sequence)
from dlas.utils.util import opt_get
def load_tsv_type(filename, type):
@ -24,10 +27,12 @@ def load_tsv_type(filename, type):
if len(components) < 2:
bad_lines += 1
if bad_lines > 1000:
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
print(
f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
raise ValueError
continue
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]] + [type])
filepaths_and_text.append(
[os.path.join(base, f'{components[1]}'), components[0]] + [type])
return filepaths_and_text
@ -41,10 +46,12 @@ def load_tsv(filename):
if len(components) < 2:
bad_lines += 1
if bad_lines > 1000:
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
print(
f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
raise ValueError
continue
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]])
filepaths_and_text.append(
[os.path.join(base, f'{components[1]}'), components[0]])
return filepaths_and_text
@ -67,10 +74,12 @@ def load_tsv_aligned_codes_type(filename, type):
if len(components) < 3:
bad_lines += 1
if bad_lines > 1000:
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
print(
f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
raise ValueError
continue
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
filepaths_and_text.append([os.path.join(
base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type])
return filepaths_and_text
@ -84,24 +93,29 @@ def load_tsv_aligned_codes(filename):
if len(components) < 3:
bad_lines += 1
if bad_lines > 1000:
print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
print(
f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}')
raise ValueError
continue
filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])])
filepaths_and_text.append([os.path.join(
base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])])
return filepaths_and_text
def load_mozilla_cv(filename, type):
with open(filename, encoding='utf-8') as f:
components = [line.strip().split('\t') for line in f][1:] # First line is the header
components = [line.strip().split('\t')
for line in f][1:] # First line is the header
base = os.path.dirname(filename)
filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2], type] for component in components]
filepaths_and_text = [[os.path.join(
base, f'clips/{component[1]}'), component[2], type] for component in components]
return filepaths_and_text
def load_voxpopuli(filename, type):
with open(filename, encoding='utf-8') as f:
lines = [line.strip().split('\t') for line in f][1:] # First line is the header
lines = [line.strip().split('\t')
for line in f][1:] # First line is the header
base = os.path.dirname(filename)
filepaths_and_text = []
for line in lines:
@ -109,7 +123,8 @@ def load_voxpopuli(filename, type):
continue
file, raw_text, norm_text, speaker_id, split, gender = line
year = file[:4]
filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
filepaths_and_text.append(
[os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type])
return filepaths_and_text
@ -134,11 +149,16 @@ class TextWavLoader(torch.utils.data.Dataset):
assert len(self.path) == len(fetcher_mode)
self.load_conditioning = opt_get(hparams, ['load_conditioning'], False)
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False)
self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
self.conditioning_candidates = opt_get(
hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(
hparams, ['conditioning_length'], 44100)
self.debug_failures = opt_get(
hparams, ['debug_loading_failures'], False)
self.load_aligned_codes = opt_get(
hparams, ['load_aligned_codes'], False)
self.aligned_codes_to_audio_ratio = opt_get(
hparams, ['aligned_codes_ratio'], 443)
self.audiopaths_and_text = []
for p, fm, type in zip(self.path, fetcher_mode, self.types):
if fm == 'lj' or fm == 'libritts':
@ -146,10 +166,12 @@ class TextWavLoader(torch.utils.data.Dataset):
elif fm == 'tsv':
fetcher_fn = load_tsv_aligned_codes_type if self.load_aligned_codes else load_tsv_type
elif fm == 'mozilla_cv':
assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv
# Conditioning inputs are incompatible with mozilla_cv
assert not self.load_conditioning
fetcher_fn = load_mozilla_cv
elif fm == 'voxpopuli':
assert not self.load_conditioning # Conditioning inputs are incompatible with voxpopuli
# Conditioning inputs are incompatible with voxpopuli
assert not self.load_conditioning
fetcher_fn = load_voxpopuli
else:
raise NotImplementedError()
@ -165,11 +187,13 @@ class TextWavLoader(torch.utils.data.Dataset):
assert self.max_wav_len is not None and self.max_text_len is not None
self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True)
if self.use_bpe_tokenizer:
from data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer
self.tokenizer = VoiceBpeTokenizer(opt_get(
hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json'))
else:
self.tokenizer = CharacterTokenizer()
self.skipped_items = 0 # records how many items are skipped when accessing an index.
# records how many items are skipped when accessing an index.
self.skipped_items = 0
def get_wav_text_pair(self, audiopath_and_text):
# separate filename and text
@ -191,19 +215,21 @@ class TextWavLoader(torch.utils.data.Dataset):
def __getitem__(self, index):
self.skipped_items += 1
try:
tseq, wav, text, path, type = self.get_wav_text_pair(self.audiopaths_and_text[index])
tseq, wav, text, path, type = self.get_wav_text_pair(
self.audiopaths_and_text[index])
if text is None or len(text.strip()) == 0:
raise ValueError
if wav is None or wav.shape[-1] < (.6 * self.sample_rate):
# Ultra short clips are also useless (and can cause problems within some models).
raise ValueError
cond, cond_is_self = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate,
n=self.conditioning_candidates) if self.load_conditioning else (None, False)
n=self.conditioning_candidates) if self.load_conditioning else (None, False)
except:
if self.skipped_items > 100:
raise # Rethrow if we have nested too far.
if self.debug_failures:
print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
print(
f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}")
return self[(index+1) % len(self)]
if self.load_aligned_codes:
@ -213,12 +239,13 @@ class TextWavLoader(torch.utils.data.Dataset):
self.skipped_items = 0
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result.
if self.debug_failures:
print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0,len(self)-1)
print(
f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
rv = random.randint(0, len(self)-1)
return self[rv]
orig_output = wav.shape[-1]
orig_text_len = tseq.shape[0]
@ -226,7 +253,8 @@ class TextWavLoader(torch.utils.data.Dataset):
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
if self.load_aligned_codes:
# These codes are aligned to audio inputs, so make sure to pad them as well.
aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
aligned_codes = F.pad(
aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0]))
if tseq.shape[0] != self.max_text_len:
tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0]))
res = {
@ -265,13 +293,15 @@ class PairedVoiceDebugger:
if isinstance(state, dict):
self.total_items = opt_get(state, ['total_items'], 0)
self.loaded_items = opt_get(state, ['loaded_items'], 0)
self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0)
self.self_conditioning_items = opt_get(
state, ['self_conditioning_items'], 0)
def update(self, batch):
self.total_items += batch['wav'].shape[0]
self.loaded_items += batch['skipped_items'].sum().item()
if 'conditioning' in batch.keys():
self.self_conditioning_items += batch['conditioning_contains_self'].sum().item()
self.self_conditioning_items += batch['conditioning_contains_self'].sum(
).item()
def get_debugging_map(self):
return {
@ -299,11 +329,12 @@ if __name__ == '__main__':
'use_bpe_tokenizer': True,
'load_aligned_codes': False,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
def save(b, i, ib, key, c=None):
if c is not None:
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050)
torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav',
b[key][ib][c], 22050)
else:
torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
@ -317,4 +348,3 @@ if __name__ == '__main__':
save(b, i, ib, 'wav')
if i > 5:
break
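# A hedged configuration sketch for TextWavLoader above. Only the option names read
# via opt_get() in this hunk are certain ('load_conditioning', 'num_conditioning_candidates',
# 'conditioning_length', 'debug_loading_failures', 'load_aligned_codes', 'aligned_codes_ratio',
# 'use_bpe_tokenizer', 'tokenizer_vocab'); the remaining keys and the example path are
# assumptions about how the rest of the options dict is laid out.
example_opt = {
    'path': ['Y:\\clips\\books1-transcribed.tsv'],   # hypothetical dataset path
    'fetcher_mode': ['tsv'],                         # lj / libritts / tsv / mozilla_cv / voxpopuli
    'load_conditioning': True,
    'num_conditioning_candidates': 1,                # default
    'conditioning_length': 44100,                    # default, in audio samples
    'load_aligned_codes': False,
    'aligned_codes_ratio': 443,                      # audio samples per aligned code
    'use_bpe_tokenizer': True,
    'tokenizer_vocab': '../experiments/bpe_lowercase_asr_256.json',
    'debug_loading_failures': True,
}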

@ -9,14 +9,15 @@ import torchaudio
import torchvision
from tqdm import tqdm
from utils.util import opt_get
from dlas.utils.util import opt_get
class PreprocessedMelDataset(torch.utils.data.Dataset):
def __init__(self, opt):
path = opt['path']
cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case.
# When multiple paths are specified, cache_path must be provided or loading will fail.
cache_path = opt['cache_path']
if os.path.exists(cache_path):
self.paths = torch.load(cache_path)
else:
@ -36,8 +37,8 @@ class PreprocessedMelDataset(torch.utils.data.Dataset):
padding_needed = self.pad_to - mel.shape[-1]
mask = torch.zeros_like(mel)
if padding_needed > 0:
mel = F.pad(mel, (0,padding_needed))
mask = F.pad(mask, (0,padding_needed), value=1)
mel = F.pad(mel, (0, padding_needed))
mask = F.pad(mask, (0, padding_needed), value=1)
output = {
'mel': mel,
@ -62,13 +63,13 @@ if __name__ == '__main__':
'n_workers': 0,
'batch_size': 16,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
ds = create_dataset(params)
dl = create_dataloader(ds, params)
i = 0
for b in tqdm(dl):
#pass
# pass
torchvision.utils.save_image((b['mel'].unsqueeze(1)+1)/2, f'{i}.png')
i += 1
if i > 20:
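# Standalone sketch of the pad-to-length behavior used by PreprocessedMelDataset above:
# mels shorter than pad_to are right-padded with zeros and a mask marks the padded
# frames with 1 (real frames stay 0). The pad_to value below is an assumed example,
# not the dataset's configured default.
import torch
import torch.nn.functional as F


def pad_mel_with_mask(mel: torch.Tensor, pad_to: int):
    padding_needed = pad_to - mel.shape[-1]
    mask = torch.zeros_like(mel)
    if padding_needed > 0:
        mel = F.pad(mel, (0, padding_needed))
        mask = F.pad(mask, (0, padding_needed), value=1)
    return mel, mask


mel, mask = pad_mel_with_mask(torch.randn(80, 312), pad_to=400)
assert mel.shape[-1] == 400 and mask[:, 312:].all()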

@ -3,15 +3,16 @@ import random
import sys
import torch
import torch.utils.data
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from audio2numpy import open_audio
from tqdm import tqdm
from data.util import find_files_of_type, is_audio_file, load_paths_from_cache
from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get
from dlas.data.util import (find_files_of_type, is_audio_file,
load_paths_from_cache)
from dlas.models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
from dlas.utils.util import opt_get
def load_audio(audiopath, sampling_rate):
@ -53,17 +54,21 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, fallback_to_self=T
similarities = torch.load(sim_path)
fname = os.path.basename(path)
if fname in similarities.keys():
candidates = [os.path.join(os.path.dirname(path), s) for s in similarities[fname]]
candidates = [os.path.join(os.path.dirname(path), s)
for s in similarities[fname]]
else:
print(f'Similarities list found for {path} but {fname} was not in that list.')
#candidates.append(path) # Always include self as a possible similar clip.
print(
f'Similarities list found for {path} but {fname} was not in that list.')
# candidates.append(path) # Always include self as a possible similar clip.
if len(candidates) == 0:
if fallback_to_self:
candidates = [path]
else:
candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0]
candidates = find_files_of_type(
'img', os.path.dirname(path), qualifier=is_audio_file)[0]
assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related.
# Sanity check to ensure we aren't loading "related files" that aren't actually related.
assert len(candidates) < 50000
if len(candidates) == 0:
print(f"No conditioning candidates found for {path}")
raise NotImplementedError()
@ -92,7 +97,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
def __init__(self, opt):
path = opt['path']
cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case.
# When multiple paths are specified, cache_path must be provided or loading will fail.
cache_path = opt['cache_path']
exclusions = []
if 'exclusions' in opt.keys():
for exc in opt['exclusions']:
@ -102,7 +108,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
assert isinstance(ew, list)
not_ew = opt_get(opt, ['not_endswith'], [])
assert isinstance(not_ew, list)
self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, endswith=ew, not_endswith=not_ew)
self.audiopaths = load_paths_from_cache(
path, cache_path, exclusions, endswith=ew, not_endswith=not_ew)
# Parse options
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
@ -121,7 +128,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
self.extra_samples = opt_get(opt, ['extra_samples'], 0)
self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 44000)
self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)
self.debug_loading_failures = opt_get(
opt, ['debug_loading_failures'], True)
def get_audio_for_index(self, index):
audiopath = self.audiopaths[index]
@ -144,8 +152,9 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
alt_files, alt_is_self = self.get_related_audio_for_index(index)
except:
if self.debug_loading_failures:
print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
return self[random.randint(0,len(self))]
print(
f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
            return self[random.randint(0, len(self) - 1)]
# When generating resampled clips, skew is a bias that tries to spread them out from each other, reducing their
# influence on one another.
@ -157,10 +166,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
for sk in skew:
if self.pad_to is not None:
if audio_norm.shape[-1] <= self.pad_to:
clips.append(torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1])))
clips.append(torch.nn.functional.pad(
audio_norm, (0, self.pad_to - audio_norm.shape[-1])))
else:
gap = audio_norm.shape[-1] - self.pad_to
start = min(max(random.randint(0, gap-1) + sk * gap // 2, 0), gap-1)
start = min(
max(random.randint(0, gap-1) + sk * gap // 2, 0), gap-1)
clips.append(audio_norm[:, start:start+self.pad_to])
else:
clips.append(audio_norm)
@ -200,16 +211,18 @@ if __name__ == '__main__':
'n_workers': 1,
'batch_size': 16,
}
from data import create_dataset, create_dataloader
from data import create_dataloader, create_dataset
ds = create_dataset(params)
dl = create_dataloader(ds, params)
i = 0
for b in tqdm(dl):
for b_ in range(b['clip'].shape[0]):
#pass
torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
torchaudio.save(f'{i}_alt_clip_{b_}.wav', b['alt_clips'][b_], ds.sampling_rate)
# pass
torchaudio.save(f'{i}_clip_{b_}.wav',
b['clip'][b_], ds.sampling_rate)
torchaudio.save(f'{i}_alt_clip_{b_}.wav',
b['alt_clips'][b_], ds.sampling_rate)
i += 1
if i > 200:
break
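# Sketch of the skew logic above: when several resampled clips are cut from one file,
# each clip's start offset is nudged by sk * gap // 2, which biases the clips toward
# different regions of the audio while clamping them inside the valid range. The
# skew values in (-1, 0, 1) are an assumption about how callers populate `skew`.
import random


def pick_clip_start(total_len: int, pad_to: int, sk: int) -> int:
    gap = total_len - pad_to
    return min(max(random.randint(0, gap - 1) + sk * gap // 2, 0), gap - 1)


# e.g. a 10s file at 22050Hz cut down to 2s clips:
total, clip = 10 * 22050, 2 * 22050
starts = [pick_clip_start(total, clip, sk) for sk in (-1, 0, 1)]
# sk=-1 tends to land near the start of the file, sk=1 near the end, sk=0 anywhere.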

@ -1,16 +1,17 @@
import json
import re
import torch
import json
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from data.audio.paired_voice_audio_dataset import load_mozilla_cv, load_voxpopuli, load_tsv
from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2.text.cleaners import english_cleaners
from dlas.data.audio.paired_voice_audio_dataset import (load_mozilla_cv,
load_tsv,
load_voxpopuli)
from dlas.models.audio.tts.tacotron2 import load_filepaths_and_text
from dlas.models.audio.tts.tacotron2.text.cleaners import english_cleaners
def remove_extraneous_punctuation(word):
@ -21,7 +22,8 @@ def remove_extraneous_punctuation(word):
'': '-', '`': '\'',
'ʼ': '\''
}
replace = re.compile("|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL)
replace = re.compile("|".join([re.escape(k) for k in sorted(
replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL)
word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
# TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
@ -29,27 +31,32 @@ def remove_extraneous_punctuation(word):
word = extraneous.sub('', word)
return word
def expand_numbers(text):
return normalize_numbers(text)
return normalize_numbers(text)
def lowercase(text):
return text.lower()
return text.lower()
_whitespace_re = re.compile(r'\s+')
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
return unidecode(text)
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
class VoiceBpeTokenizer:
def __init__(self, vocab_file, preprocess=None):
@ -72,23 +79,24 @@ class VoiceBpeTokenizer:
kks = pykakasi.kakasi()
results = kks.convert(txt)
txt = " ".join([ result['kana'] for result in results ])
txt = " ".join([result['kana'] for result in results])
txt = basic_cleaners(txt)
else:
txt = english_cleaners(txt)
return txt
def encode(self, txt):
if self.preprocess:
txt = self.preprocess_text(txt)
txt = self.preprocess_text(txt)
txt = txt.replace(' ', '[SPACE]')
return self.tokenizer.encode(txt).ids
def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
txt = self.tokenizer.decode(
seq, skip_special_tokens=False).replace(' ', '')
txt = txt.replace('[SPACE]', ' ')
txt = txt.replace('[STOP]', '')
txt = txt.replace('[UNK]', '')
@ -117,10 +125,11 @@ def build_text_file_from_priors(priors, output):
def train():
with open('all_texts.txt', 'r', encoding='utf-8') as at:
ttsd = at.readlines()
#bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train']
# bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train']
#allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—–ʼ]+$')
# allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—–ʼ]+$')
allowed_characters_re = re.compile(r'^[a-z!:;"/, \-\(\)\.\'\?ʼ]+$')
def preprocess_word(word, report=False):
word = english_cleaners(word)
word = remove_extraneous_punctuation(word)
@ -135,16 +144,19 @@ def train():
for i in range(0, len(ttsd), batch_size):
yield [preprocess_word(t, True) for t in ttsd[i:i+batch_size]]
#print("Processing bookcorpus.")
#for i in range(0, len(bcd), batch_size):
# print("Processing bookcorpus.")
# for i in range(0, len(bcd), batch_size):
# yield [preprocess_word(t) for t in bcd[i:i+batch_size]['text']]
trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]', '[SPACE]'], vocab_size=255)
trainer = BpeTrainer(
special_tokens=['[STOP]', '[UNK]', '[SPACE]'], vocab_size=255)
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(ttsd))#+len(bcd))
tokenizer.train_from_iterator(
batch_iterator(), trainer, length=len(ttsd)) # +len(bcd))
print(tokenizer.decode(tokenizer.encode("i was traveling throughhadslfghds the woods in 1235375t137{{}}").ids))
print(tokenizer.decode(tokenizer.encode(
"i was traveling throughhadslfghds the woods in 1235375t137{{}}").ids))
tokenizer.save('gpt_tts_tokenizer.json')
@ -171,5 +183,5 @@ if __name__ == '__main__':
('Y:\\clips\\books2-transcribed.tsv', 'tsv'),
('Y:\\clips\\podcasts-0-transcribed.tsv', 'tsv')], 'all_texts.txt')
'''
#train()
# train()
test()
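# Hedged usage sketch for VoiceBpeTokenizer defined above. The vocab path is the
# default used elsewhere in this repo and is assumed to exist; whether english_cleaners
# runs first depends on how the preprocess flag is resolved in __init__, which this
# hunk does not show.
tok = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
ids = tok.encode("i was traveling through the woods")  # spaces become literal [SPACE] tokens
text = tok.decode(ids)                                 # [SPACE] maps back to ' ', [STOP]/[UNK] are stripped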

@ -3,19 +3,19 @@ import random
import torch
import torchaudio.sox_effects
from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
from dlas.models.audio.tts.tacotron2.taco_utils import load_wav_to_torch
# Returns random double on [l,h] as a string
def rdstr(l=0,h=1):
def rdstr(l=0, h=1):
assert h > l
i=h-l
i = h-l
return str(random.random() * i + l)
# Returns a randint on [s,e] as a string
def rdi(e, s=0):
return str(random.randint(s,e))
return str(random.randint(s, e))
class WavAugmentor:
@ -43,12 +43,13 @@ class WavAugmentor:
band_effect = random.choice(band_effects)
'''
volume_effects = [
['loudness', rdi(10,-2)],
['overdrive', rdi(20,0), rdi(20,0)],
['loudness', rdi(10, -2)],
['overdrive', rdi(20, 0), rdi(20, 0)],
]
vol_effect = random.choice(volume_effects)
effects = [speed_effect, vol_effect]
out, sr = torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)
out, sr = torchaudio.sox_effects.apply_effects_tensor(
wav, sample_rate, effects)
# Add a variable amount of noise
out = out + torch.rand_like(out) * random.random() * .03
return out
@ -60,4 +61,4 @@ if __name__ == '__main__':
aug = WavAugmentor()
for j in range(10):
out = aug.augment(sample, 24000)
torchaudio.save(f'out{j}.wav', out, 24000)
torchaudio.save(f'out{j}.wav', out, 24000)
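# Minimal sketch of applying a sox effect chain the way WavAugmentor.augment() does
# above; the concrete effect values here are illustrative, not the augmentor's sampled
# ones. apply_effects_tensor expects a (channels, samples) float tensor plus a list of
# [effect, arg...] string lists and returns the processed tensor and its sample rate.
import torch
import torchaudio.sox_effects

wav = torch.rand(1, 24000) * 2 - 1                    # one second of fake audio at 24kHz
effects = [['loudness', '5'], ['overdrive', '10', '10']]
out, sr = torchaudio.sox_effects.apply_effects_tensor(wav, 24000, effects)
out = out + torch.rand_like(out) * 0.01               # small additive noise, as in augment()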

@ -1,5 +1,6 @@
import torch
from data import create_dataset
from dlas.data import create_dataset
# Simple composite dataset that combines multiple other datasets.
@ -31,4 +32,4 @@ class CombinedDataset(torch.utils.data.Dataset):
return output
def __len__(self):
return max(len(d) for d in self.datasets.values())
return max(len(d) for d in self.datasets.values())

@ -4,9 +4,10 @@ Support enlarging the dataset for *iteration-oriented* training, for saving time
dataloader after each epoch
"""
import math
import torch
from torch.utils.data.sampler import Sampler
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
class DistIterSampler(Sampler):
@ -30,17 +31,20 @@ class DistIterSampler(Sampler):
def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
raise RuntimeError(
"Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
raise RuntimeError(
"Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
self.num_samples = int(
math.ceil(len(self.dataset) * ratio / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
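# Hedged usage sketch for DistIterSampler above: passing num_replicas/rank explicitly
# sidesteps torch.distributed initialization, and ratio stretches each "epoch" to
# roughly ratio-times the dataset per replica, since
# num_samples = ceil(len(dataset) * ratio / num_replicas). The epoch attribute is
# assumed to seed the shuffled ordering inside __iter__ (its body is not shown here).
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(100).float())
sampler = DistIterSampler(ds, num_replicas=2, rank=0, ratio=10)
assert sampler.num_samples == math.ceil(len(ds) * 10 / 2)
loader = DataLoader(ds, batch_size=16, sampler=sampler, shuffle=False)
for epoch in range(2):
    sampler.epoch = epoch   # re-seed the per-epoch ordering
    for (batch,) in loader:
        pass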

@ -1,10 +1,13 @@
import torch
from torch.utils import data
from data.images.image_corruptor import ImageCorruptor
from data.images.chunk_with_reference import ChunkWithReference
import os
import cv2
import numpy as np
import torch
from torch.utils import data
from dlas.data.images.chunk_with_reference import ChunkWithReference
from dlas.data.images.image_corruptor import ImageCorruptor
# Class whose purpose is to hold as much logic as can possibly be shared between datasets that operate on raw image
# data and nothing else (which also have a very specific directory structure being used, as dictated by
@ -13,13 +16,17 @@ class BaseUnsupervisedImageDataset(data.Dataset):
def __init__(self, opt):
self.opt = opt
self.corruptor = ImageCorruptor(opt)
self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys(
) else None
self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys(
) else 1
self.for_eval = opt['eval'] if 'eval' in opt.keys() else False
self.scale = opt['scale'] if not self.for_eval else 1
self.paths = opt['paths']
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys(
) else False
# If we don't throw here, we get some really obscure errors.
assert (self.target_hq_size // self.scale) % self.multiple == 0
if not isinstance(self.paths, list):
self.paths = [self.paths]
self.weights = [1]
@ -34,8 +41,10 @@ class BaseUnsupervisedImageDataset(data.Dataset):
if os.path.exists(cache_path):
chunks = torch.load(cache_path)
else:
print("Building chunk cache, this can take some time for large datasets..")
chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
print(
"Building chunk cache, this can take some time for large datasets..")
chunks = [ChunkWithReference(opt, d) for d in sorted(
os.scandir(path), key=lambda e: e.name) if d.is_dir()]
# Prune out chunks that have no images
res = []
for c in chunks:
@ -60,7 +69,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
for c in self.chunks:
paths.extend(c.tiles)
return paths
# Utility method for translating a point when the dimensions of an image change.
def resize_point(self, point, orig_dim, new_dim):
oh, ow = orig_dim
@ -78,20 +87,27 @@ class BaseUnsupervisedImageDataset(data.Dataset):
for hq, hq_ref, hq_mask, hq_center in zip(imgs_hq, refs_hq, masks_hq, centers_hq):
# It is assumed that the target size is a square.
target_size = (self.target_hq_size, self.target_hq_size)
hqs_adjusted.append(cv2.resize(hq, target_size, interpolation=cv2.INTER_AREA))
hq_refs_adjusted.append(cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_AREA))
hq_masks_adjusted.append(cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_AREA))
hq_centers_adjusted.append(self.resize_point(hq_center, (h, w), target_size))
hqs_adjusted.append(cv2.resize(
hq, target_size, interpolation=cv2.INTER_AREA))
hq_refs_adjusted.append(cv2.resize(
hq_ref, target_size, interpolation=cv2.INTER_AREA))
hq_masks_adjusted.append(cv2.resize(
hq_mask, target_size, interpolation=cv2.INTER_AREA))
hq_centers_adjusted.append(
self.resize_point(hq_center, (h, w), target_size))
h, w = self.target_hq_size, self.target_hq_size
else:
hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted = imgs_hq, refs_hq, masks_hq, centers_hq
hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted] # This is done implicitly above..
hq_multiple = self.multiple * self.scale # Multiple must apply to LQ image.
# This is done implicitly above.
hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted]
# Multiple must apply to LQ image.
hq_multiple = self.multiple * self.scale
if h % hq_multiple != 0 or w % hq_multiple != 0:
hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], []
for hq, hq_ref, hq_mask, hq_center in zip(hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted):
h, w = (h - h % hq_multiple), (w - w % hq_multiple)
hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:2], (h, w)))
hq_centers_conformed.append(
self.resize_point(hq_center, hq.shape[:2], (h, w)))
hqs_conformed.append(hq[:h, :w, :])
hq_refs_conformed.append(hq_ref[:h, :w, :])
hq_masks_conformed.append(hq_mask[:h, :w, :])
@ -110,10 +126,14 @@ class BaseUnsupervisedImageDataset(data.Dataset):
lms.append(hq_mask)
lcs.append(hq_center)
else:
ls.append(cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
lrs.append(cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2]))
ls.append(cv2.resize(hq, (h // self.scale, w //
self.scale), interpolation=cv2.INTER_AREA))
lrs.append(cv2.resize(hq_ref, (h // self.scale, w //
self.scale), interpolation=cv2.INTER_AREA))
lms.append(cv2.resize(hq_mask, (h // self.scale, w //
self.scale), interpolation=cv2.INTER_AREA))
lcs.append(self.resize_point(
hq_center, (h, w), ls[0].shape[:2]))
# Corrupt the LQ image here unless it was already corrupted before downsizing; skipped in eval mode.
if not self.corrupt_before_downsize and not self.for_eval:
ls = self.corruptor.corrupt_images(ls)
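# The resize_point() helper above remaps a (row, col) coordinate when an image is
# rescaled. Its body is cut off by this hunk, so the standalone version below is an
# assumed-but-typical implementation for illustration only: each coordinate is scaled
# by the ratio of the new dimension to the old one.
import torch


def resize_point(point: torch.Tensor, orig_dim, new_dim) -> torch.Tensor:
    oh, ow = orig_dim
    nh, nw = new_dim
    dt = point.dtype
    point = point.float()
    point[0] = point[0] * nh / oh   # row scales with height
    point[1] = point[1] * nw / ow   # col scales with width
    return point.to(dt)


# A center at (128, 128) in a 512x512 tile lands at (64, 64) after resizing to 256x256.
assert resize_point(torch.tensor([128, 128]), (512, 512), (256, 256)).tolist() == [64, 64]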

@ -3,22 +3,20 @@ from time import time
import kornia
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from kornia import augmentation as augs, geometry
from kornia import filters
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from kornia import augmentation as augs
from kornia import filters, geometry
from torch.utils.data import Dataset
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
# inputs. These are then outputted as 'aug1' and 'aug2'.
from tqdm import tqdm
from data import create_dataset
from models.arch_util import PixelUnshuffle
from utils.util import opt_get
from dlas.data import create_dataset
from dlas.models.arch_util import PixelUnshuffle
from dlas.utils.util import opt_get
class RandomApply(nn.Module):
@ -26,6 +24,7 @@ class RandomApply(nn.Module):
super().__init__()
self.fn = fn
self.p = p
def forward(self, x):
if random.random() > self.p:
return x
@ -39,9 +38,10 @@ class ByolDatasetWrapper(Dataset):
self.cropped_img_size = opt['crop_size']
self.key1 = opt_get(opt, ['key1'], 'hq')
self.key2 = opt_get(opt, ['key2'], 'lq')
for_sr = opt_get(opt, ['for_sr'], False) # When set, color alterations and blurs are disabled.
# When set, color alterations and blurs are disabled.
for_sr = opt_get(opt, ['for_sr'], False)
augmentations = [ \
augmentations = [
augs.RandomHorizontalFlip(),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
if not for_sr:
@ -51,12 +51,14 @@ class ByolDatasetWrapper(Dataset):
if opt['normalize']:
# The paper calls for normalization. Most datasets/models in this repo don't use this.
# Recommend setting true if you want to train exactly like the paper.
augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
augmentations.append(augs.Normalize(mean=torch.tensor(
[0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))
self.aug = nn.Sequential(*augmentations)
def __getitem__(self, item):
item = self.wrapped_dataset[item]
item.update({'aug1': self.aug(item[self.key1]).squeeze(dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
item.update({'aug1': self.aug(item[self.key1]).squeeze(
dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)})
return item
def __len__(self):
@ -71,7 +73,7 @@ class DatasetRandomAugWrapper(Dataset):
self.wrapped_dataset = create_dataset(opt['dataset'])
self.cropped_img_size = opt['crop_size']
self.includes_labels = opt['includes_labels']
augmentations = [ \
augmentations = [
RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
@ -87,18 +89,19 @@ class DatasetRandomAugWrapper(Dataset):
dtypes = []
for k in item.keys():
if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3:
assert item[k].shape[0] == 1 # Only supports a channel dim of 1.
# Only supports a channel dim of 1.
assert item[k].shape[0] == 1
labels.append(k)
dtypes.append(item[k].dtype)
hq = torch.cat([hq, item[k].type(torch.float)], dim=0)
hq = self.rrc(hq.unsqueeze(0)).squeeze(0)
for i, k in enumerate(labels):
# Strip out any label values that are not whole numbers.
item[k] = hq[3+i:3+i+1,:,:]
item[k] = hq[3+i:3+i+1, :, :]
whole = (item[k].round() == item[k])
item[k] = item[k] * whole
item[k] = item[k].type(dtypes[i])
item['lq'] = hq[:3,:,:]
item['lq'] = hq[:3, :, :]
item['hq'] = item['lq']
return item
@ -137,14 +140,16 @@ def test_dataset_random_aug_wrapper():
for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
if k in ['hq']:
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
torchvision.utils.save_image(
v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
masked = v * (o['labels_mask'] * .5 + .5)
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
# torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
# Pick a random (non-zero) label and spit it out with the textual label.
if len(o['labels'].unique()) > 1:
randlbl = np.random.choice(o['labels'].unique()[1:])
moremask = v * ((1*(o['labels'] == randlbl))*.5+.5)
torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
torchvision.utils.save_image(moremask.unsqueeze(
0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl]))
def no_batch_interpolate(i, size, mode):
@ -165,10 +170,10 @@ def snap(ref, other):
# Pads a tensor with zeros so that it fits in a dxd square.
def pad_to(im, d):
if len(im.shape) == 3:
pd = torch.zeros((im.shape[0],d,d))
pd = torch.zeros((im.shape[0], d, d))
pd[:, :im.shape[1], :im.shape[2]] = im
else:
pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device)
pd = torch.zeros((im.shape[0], im.shape[1], d, d), device=im.device)
pd[:, :, :im.shape[2], :im.shape[3]] = im
return pd
@ -182,7 +187,8 @@ class RandomSharedRegionCrop(nn.Module):
def __init__(self, multiple, jitter_range=0):
super().__init__()
self.multiple = multiple
self.jitter_range = jitter_range # When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range
# When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range
self.jitter_range = jitter_range
def forward(self, i1, i2):
assert i1.shape[-1] == i2.shape[-1]
@ -218,19 +224,25 @@ class RandomSharedRegionCrop(nn.Module):
# Step 4
m = self.multiple
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds.
jl, jt = random.randint(-self.jitter_range,
self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
# If the top of a patch is zero, a negative jitter will cause it to go negative.
jt = jt if base_t != 0 else abs(jt)
# Likewise, jitter shouldn't allow the patch to go over-bounds.
jt = jt if (base_t+base_h)*m != i1.shape[1] else 0
jl = jl if base_l != 0 else abs(jl)
jl = jl if (base_l+base_w)*m != i1.shape[1] else 0
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt,
base_l*m+jl:(base_l+base_w)*m+jl]
p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
jl, jt = random.randint(-self.jitter_range,
self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
jt = jt if im2_t != 0 else abs(jt)
jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0
jl = jl if im2_l != 0 else abs(jl)
jl = jl if (im2_l+im2_w)*m != i2.shape[1] else 0
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt,
im2_l*m+jl:(im2_l+im2_w)*m+jl]
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
# Step 5
@ -246,14 +258,17 @@ class RandomSharedRegionCrop(nn.Module):
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
recompute_package = torch.tensor([d, base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)
recompute_package = torch.tensor([d, base_h, base_w, i1_shared_t, i1_shared_l, im2_h,
im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)
# Step 7
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m,
i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
masked1 = pad_to(p1 * mask1, d*m)
mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m,
i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
masked2 = pad_to(p2 * mask2, d*m)
mask = torch.full((1, d*m, d*m), fill_value=.33)
mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33
@ -262,10 +277,13 @@ class RandomSharedRegionCrop(nn.Module):
# Step 8 - Rebuild shared regions for testing purposes.
p1_shuf, p2_shuf = PixelUnshuffle(self.multiple)(p1_resized.unsqueeze(0)), \
PixelUnshuffle(self.multiple)(p2_resized.unsqueeze(0))
i1_shared, i2_shared = reconstructed_shared_regions(p1_shuf, p2_shuf, recompute_package.unsqueeze(0))
i1_shared = pad_to(nn.PixelShuffle(self.multiple)(i1_shared).squeeze(0), d * m)
i2_shared = pad_to(nn.PixelShuffle(self.multiple)(i2_shared).squeeze(0), d*m)
PixelUnshuffle(self.multiple)(p2_resized.unsqueeze(0))
i1_shared, i2_shared = reconstructed_shared_regions(
p1_shuf, p2_shuf, recompute_package.unsqueeze(0))
i1_shared = pad_to(nn.PixelShuffle(self.multiple)
(i1_shared).squeeze(0), d * m)
i2_shared = pad_to(nn.PixelShuffle(self.multiple)
(i2_shared).squeeze(0), d*m)
return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg, i1_shared, i2_shared
@ -280,7 +298,8 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
# of conforming the recompute_package across the entire batch.
for b in range(package.shape[0]):
expected_dim, f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(package[b].tolist())
expected_dim, f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(
package[b].tolist())
# If you are hitting this assert, you specified `latent_multiple` in your dataset config wrong.
assert expected_dim == fea1.shape[2] and expected_dim == fea2.shape[2]
@ -292,8 +311,10 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="nearest")
f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="nearest")
# Outputs must be padded so they can "get along" with each other.
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
res1.append(
pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
res2.append(
pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
return torch.cat(res1, dim=0), torch.cat(res2, dim=0)
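# The 12-element long tensor built by RandomSharedRegionCrop and consumed by
# reconstructed_shared_regions() above packs, in order: the latent dimension d,
# patch-1 height/width, patch-1 shared-region top/left, patch-2 height/width,
# patch-2 shared-region top/left, a flip flag, and the shared-region height/width.
# A small named wrapper (purely illustrative, not part of the repo) makes that explicit:
from collections import namedtuple

import torch

SharedRegionPackage = namedtuple('SharedRegionPackage', [
    'd', 'f1_h', 'f1_w', 'f1s_t', 'f1s_l',
    'f2_h', 'f2_w', 'f2s_t', 'f2s_l',
    'should_flip', 'ix_h', 'ix_w'])


def unpack_region_package(pkg: torch.Tensor) -> SharedRegionPackage:
    return SharedRegionPackage(*pkg.tolist())


pkg = torch.tensor([16, 12, 12, 2, 3, 10, 14, 0, 1, 0, 8, 9], dtype=torch.long)
print(unpack_region_package(pkg).ix_h)  # -> 8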
@ -308,10 +329,11 @@ class StructuredCropDatasetWrapper(Dataset):
super().__init__()
self.wrapped_dataset = create_dataset(opt['dataset'])
augmentations = [RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
augs.RandomGrayscale(p=0.2),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)]
self.aug = nn.Sequential(*augmentations)
self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0))
self.rrc = RandomSharedRegionCrop(
opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0))
def __getitem__(self, item):
item = self.wrapped_dataset[item]
@ -332,17 +354,17 @@ def test_structured_crop_dataset_wrapper():
opt = {
'dataset':
{
'mode': 'imagefolder',
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
'weights': [1],
'target_size': 256,
'force_multiple': 32,
'scale': 1,
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': True,
'mode': 'imagefolder',
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
'weights': [1],
'target_size': 256,
'force_multiple': 32,
'scale': 1,
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1,
'corrupt_before_downsize': True,
},
'latent_multiple': 16,
'jitter_range': 0,
@ -353,16 +375,17 @@ def test_structured_crop_dataset_wrapper():
os.makedirs("debug", exist_ok=True)
for i in tqdm(range(0, len(ds))):
o = ds[random.randint(0, len(ds)-1)]
#for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
#if k in [ 'aug_shared_view', 'masked1', 'masked2']:
#torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
# for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
# if k in [ 'aug_shared_view', 'masked1', 'masked2']:
# torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
rcpkg = o['similar_region_dimensions']
pixun = PixelUnshuffle(16)
pixsh = nn.PixelShuffle(16)
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(
0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0))
# torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
# torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
if __name__ == '__main__':

@ -1,17 +1,19 @@
import os.path as osp
from data import util
import torch
import numpy as np
import torch
from dlas.data import util
# Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
from utils.util import opt_get
from dlas.utils.util import opt_get
class ChunkWithReference:
def __init__(self, opt, path):
self.path = path.path
self.tiles, _ = util.find_files_of_type('img', self.path)
self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False)
self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(
opt, ['needs_metadata'], False)
self.need_ref = opt_get(opt, ['need_ref'], False)
if 'ignore_first' in opt.keys():
self.tiles = self.tiles[opt['ignore_first']:]
@ -41,12 +43,14 @@ class ChunkWithReference:
else:
center = torch.tensor([128, 128], dtype=torch.long)
tile_width = 256
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
mask = np.full(tile.shape[:2] + (1,),
fill_value=.1, dtype=tile.dtype)
mask[center[0] - tile_width // 2:center[0] + tile_width // 2,
center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
else:
ref = np.zeros_like(tile)
mask = np.zeros(tile.shape[:2] + (1,))
center = (0,0)
center = (0, 0)
return tile, ref, center, mask, self.tiles[item]

@ -1,14 +1,15 @@
# A copy of the cifar dataset from torch which also returns coarse labels.
from PIL import Image
import os
import os.path
import numpy as np
import pickle
from typing import Any, Callable, Optional, Tuple
import numpy as np
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.datasets.utils import (check_integrity,
download_and_extract_archive)
class CIFAR10(VisionDataset):
@ -104,7 +105,8 @@ class CIFAR10(VisionDataset):
with open(path, 'rb') as infile:
data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
self.class_to_idx = {_class: i for i,
_class in enumerate(self.classes)}
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
@ -147,7 +149,8 @@ class CIFAR10(VisionDataset):
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
download_and_extract_archive(
self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")

@ -1,12 +1,14 @@
import random
import numpy as np
from io import BytesIO
import cv2
import numpy as np
import torch
import torch.utils.data as data
import data.util as util
from PIL import Image, ImageOps
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, ImageOps
import dlas.data.util as util
# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to
@ -16,6 +18,7 @@ class FullImageDataset(data.Dataset):
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
If only GT images are provided, generate LQ images on-the-fly.
"""
def get_lq_path(self, i):
which_lq = random.randint(0, len(self.paths_LQ)-1)
return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
@ -27,19 +30,23 @@ class FullImageDataset(data.Dataset):
self.paths_LQ, self.paths_GT = None, None
self.sizes_LQ, self.sizes_GT = None, None
self.LQ_env, self.GT_env = None, None
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys(
) else 1
self.paths_GT, self.sizes_GT = util.find_files_of_type(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
self.paths_GT, self.sizes_GT = util.find_files_of_type(
self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
if 'dataroot_LQ' in opt.keys():
self.paths_LQ = []
if isinstance(opt['dataroot_LQ'], list):
# Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and
# we want the model to learn them all.
for dr_lq in opt['dataroot_LQ']:
lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, dr_lq)
lq_path, self.sizes_LQ = util.find_files_of_type(
self.data_type, dr_lq)
self.paths_LQ.append(lq_path)
else:
lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, opt['dataroot_LQ'])
lq_path, self.sizes_LQ = util.find_files_of_type(
self.data_type, opt['dataroot_LQ'])
self.paths_LQ.append(lq_path)
assert self.paths_GT, 'Error: GT path is empty.'
@ -48,7 +55,8 @@ class FullImageDataset(data.Dataset):
def motion_blur(self, image, size, angle):
k = np.zeros((size, size), dtype=np.float32)
k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
k = cv2.warpAffine(k, cv2.getRotationMatrix2D(
(size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
k = k * (1.0 / np.sum(k))
return cv2.filter2D(image, -1, k)
@ -100,8 +108,10 @@ class FullImageDataset(data.Dataset):
if 'fixed_size' in self.opt.keys() and self.opt['fixed_size']:
square_size = target_sz
else:
tile_expansion_dev = self.opt['tile_scale_normal_stddev'] if 'tile_scale_normal_stddev' in self.opt.keys() else .17
square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=tile_expansion_dev)), 1.0))
tile_expansion_dev = self.opt['tile_scale_normal_stddev'] if 'tile_scale_normal_stddev' in self.opt.keys(
) else .17
square_size = int(target_sz + possible_sizes_above_target *
min(np.abs(np.random.normal(scale=tile_expansion_dev)), 1.0))
# Pick the left,top coords to draw the patch from
left = self.pick_along_range(w, square_size, .3)
@ -110,11 +120,15 @@ class FullImageDataset(data.Dataset):
mask = np.zeros((h, w, 1), dtype=image.dtype)
mask[top:top+square_size, left:left+square_size] = 1
patch = image[top:top+square_size, left:left+square_size, :]
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
center = torch.tensor(
[top + square_size // 2, left + square_size // 2], dtype=torch.long)
patch = cv2.resize(patch, (target_sz, target_sz),
interpolation=cv2.INTER_LINEAR)
image = cv2.resize(image, (target_sz, target_sz),
interpolation=cv2.INTER_LINEAR)
mask = cv2.resize(mask, (target_sz, target_sz),
interpolation=cv2.INTER_LINEAR)
center = self.resize_point(center, (h, w), image.shape[:2])
return patch, image, mask, center
@ -127,19 +141,24 @@ class FullImageDataset(data.Dataset):
assert H >= GT_size and W >= GT_size
LQ_size = GT_size // scale
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size),
interpolation=cv2.INTER_LINEAR)
img_GT = cv2.resize(img_GT, (GT_size, GT_size),
interpolation=cv2.INTER_LINEAR)
if self.opt['use_blurring']:
# Pick randomly between gaussian, motion, or no blur.
blur_det = random.randint(0, 100)
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys(
) else self.opt['blur_magnitude']
blur_magnitude = max(1, int(blur_magnitude*strength))
if blur_det < 40:
blur_sig = int(random.randrange(0, int(blur_magnitude)))
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
img_LQ = cv2.GaussianBlur(
img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
elif blur_det < 70:
img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))
img_LQ = self.motion_blur(img_LQ, random.randrange(
1, int(blur_magnitude) * 3), random.randint(0, 360))
return img_GT, img_LQ
@ -174,13 +193,15 @@ class FullImageDataset(data.Dataset):
# Gaussian Blur (point or motion)
blur_magnitude = 3
blur_sig = int(random.randrange(0, int(blur_magnitude)))
image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig)
image = cv2.GaussianBlur(
image, (blur_magnitude, blur_magnitude), blur_sig)
elif 2 in aug_code:
# Median Blur
image = cv2.medianBlur(image, 3)
elif 3 in aug_code:
# Motion blur
image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360))
image = self.motion_blur(
image, random.randrange(1, 9), random.randint(0, 360))
elif 4 in aug_code:
# Smooth blur
image = cv2.blur(image, ksize=3)
@ -217,15 +238,19 @@ class FullImageDataset(data.Dataset):
full_path = self.paths_GT[index % len(self.paths_GT)]
LQ_path = full_path
img_full = util.read_img(None, full_path, None)
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
img_full = util.channel_convert(
img_full.shape[2], 'RGB', [img_full])[0]
if self.opt['phase'] == 'train':
img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_full = util.augment(
[img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_full = self.get_square_image(img_full)
img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(
img_full)
else:
img_GT, gt_fullsize_ref = img_full, img_full
gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype)
gt_center = torch.tensor([img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long)
gt_center = torch.tensor(
[img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long)
orig_gt_dim = gt_fullsize_ref.shape[:2]
# get LQ image
@ -233,13 +258,17 @@ class FullImageDataset(data.Dataset):
LQ_path = self.get_lq_path(index)
img_lq_full = util.read_img(None, LQ_path, None)
if self.opt['phase'] == 'train':
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_lq_full = util.augment(
[img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_lq_full = self.get_square_image(img_lq_full)
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(
img_lq_full, lq=True)
else:
img_LQ, lq_fullsize_ref = img_lq_full, img_lq_full
lq_mask = np.ones(img_lq_full.shape[:2], dtype=lq_fullsize_ref.dtype)
lq_center = torch.tensor([img_lq_full.shape[0] // 2, img_lq_full.shape[1] // 2], dtype=torch.long)
lq_mask = np.ones(
img_lq_full.shape[:2], dtype=lq_fullsize_ref.dtype)
lq_center = torch.tensor(
[img_lq_full.shape[0] // 2, img_lq_full.shape[1] // 2], dtype=torch.long)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
@ -258,7 +287,8 @@ class FullImageDataset(data.Dataset):
H_s = _mod(H_s, random_scale, scale, GT_size)
W_s = _mod(W_s, random_scale, scale, GT_size)
img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
img_GT = cv2.resize(img_GT, (W_s, H_s),
interpolation=cv2.INTER_LINEAR)
if img_GT.ndim == 2:
img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
@ -266,10 +296,12 @@ class FullImageDataset(data.Dataset):
# using matlab imresize
img_LQ = util.imresize_np(img_GT, 1 / scale, True)
lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True)
lq_fullsize_ref = util.imresize_np(
gt_fullsize_ref, 1 / scale, True)
if img_LQ.ndim == 2:
img_LQ = np.expand_dims(img_LQ, axis=2)
lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
lq_mask, lq_center = gt_mask, self.resize_point(
gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2])
orig_lq_dim = lq_fullsize_ref.shape[:2]
# Enforce force_resize constraints via clipping.
@ -285,15 +317,20 @@ class FullImageDataset(data.Dataset):
if self.opt['phase'] == 'train':
img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(
gt_fullsize_ref, lq_fullsize_ref, strength=.2)
# Scale masks.
lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
lq_mask = cv2.resize(
lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
gt_mask = cv2.resize(
gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
# Scale center coords
lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])
lq_center = self.resize_point(
lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2])
gt_center = self.resize_point(
gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2])
# BGR to RGB, HWC to CHW, numpy to tensor
if img_GT.shape[2] == 3:
@ -303,16 +340,20 @@ class FullImageDataset(data.Dataset):
gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
#if self.opt['phase'] == 'train':
#img_LQ = self.pil_augment(img_LQ)
#lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
# if self.opt['phase'] == 'train':
# img_LQ = self.pil_augment(img_LQ)
# lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
img_GT = torch.from_numpy(np.ascontiguousarray(
np.transpose(img_GT, (2, 0, 1)))).float()
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(
np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
img_LQ = F.to_tensor(img_LQ)
lq_fullsize_ref = F.to_tensor(lq_fullsize_ref)
lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)
lq_mask = torch.from_numpy(
np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
gt_mask = torch.from_numpy(
np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)
if 'lq_noise' in self.opt.keys():
lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
@ -331,6 +372,7 @@ class FullImageDataset(data.Dataset):
def __len__(self):
return len(self.paths_GT)
if __name__ == '__main__':
'''
opt = {
@ -365,10 +407,10 @@ if __name__ == '__main__':
o = ds[i]
for k, v in o.items():
if 'path' not in k:
#if 'full' in k:
#masked = v[:3, :, :] * v[3]
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
#v = v[:3, :, :]
#import torchvision
#torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
# if 'full' in k:
# masked = v[:3, :, :] * v[3]
# torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
# v = v[:3, :, :]
# import torchvision
# torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
pass
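# Annotated standalone version of the motion_blur() helper above: a single-row kernel
# is rotated to the requested angle and normalized, so convolving with it smears the
# image along that direction. The size/angle values below are illustrative.
import cv2
import numpy as np


def motion_blur(image: np.ndarray, size: int, angle: float) -> np.ndarray:
    k = np.zeros((size, size), dtype=np.float32)
    k[(size - 1) // 2, :] = 1.0                        # horizontal line kernel
    rot = cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0)
    k = cv2.warpAffine(k, rot, (size, size))           # rotate the line to `angle`
    k = k * (1.0 / np.sum(k))                          # keep overall brightness constant
    return cv2.filter2D(image, -1, k)


blurred = motion_blur(np.random.rand(64, 64, 3).astype(np.float32), size=7, angle=30)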

@ -1,20 +1,18 @@
import functools
import random
from io import BytesIO
from math import cos, pi
import cv2
import kornia
import numpy as np
import torch
from kornia.augmentation import ColorJitter
from data.util import read_img
from kornia.augmentation import ColorJitter
from PIL import Image
from io import BytesIO
# Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
from utils.util import opt_get
from dlas.utils.util import opt_get
'''
if __name__ == '__main__':
@ -29,9 +27,9 @@ if __name__ == '__main__':
def kornia_color_jitter_numpy(img, setting):
if setting * 255 > 1:
# I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format.
img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0)
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
img = ColorJitter(setting, setting, setting, setting)(img)
img = img.squeeze(0).permute(1,2,0).numpy()
img = img.squeeze(0).permute(1, 2, 0).numpy()
return img
@ -41,14 +39,18 @@ class ImageCorruptor:
def __init__(self, opt):
self.opt = opt
self.reset_random()
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else []
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0
self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys(
) else 1
self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else [
]
self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys(
) else 0
self.cosine_bias = opt_get(opt, ['cosine_bias'], True)
if self.num_corrupts == 0:
return
else:
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else [
]
def reset_random(self):
if 'random_seed' in self.opt.keys():
@ -75,7 +77,8 @@ class ImageCorruptor:
if self.num_corrupts == 0:
augmentations = []
else:
augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
augmentations = random.choices(
self.random_corruptions, k=self.num_corrupts)
# Sources of entropy
corrupted_imgs = []
@ -99,7 +102,6 @@ class ImageCorruptor:
img = ufn(img)
corrupted_imgs.append(img)
if return_entropy:
return corrupted_imgs, entropy
else:
@ -119,11 +121,11 @@ class ImageCorruptor:
setting = rand_val * (hi_end - lo_end) + lo_end
img = kornia_color_jitter_numpy(img, setting)
elif 'gaussian_blur' in aug:
img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5)
img = cv2.GaussianBlur(img, (0, 0), self.blur_scale*rand_val*1.5)
elif 'motion_blur' in aug:
# Motion blur
intensity = self.blur_scale*rand_val * 3 + 1
angle = random.randint(0,360)
angle = random.randint(0, 360)
k = np.zeros((intensity, intensity), dtype=np.float32)
k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0),
@ -145,10 +147,13 @@ class ImageCorruptor:
else:
scale = 4
if scale > 1:
interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
mode = random.randint(0,4) % len(interpolation_modes)
interpolation_modes = [
cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
mode = random.randint(0, 4) % len(interpolation_modes)
# Downsample first, then upsample using the random mode.
img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=mode)
img = cv2.resize(img, dsize=(
img.shape[1]//scale, img.shape[0]//scale), interpolation=mode)
def lq_resampling_undo_fn(scale, img):
return cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=cv2.INTER_LINEAR)
undo_fn = functools.partial(lq_resampling_undo_fn, scale)
@ -171,22 +176,23 @@ class ImageCorruptor:
elif 'jpeg' in aug:
if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
if aug == 'jpeg':
lo=10
range=20
lo = 10
range = 20
elif aug == 'jpeg-low':
lo=15
range=10
lo = 15
range = 10
elif aug == 'jpeg-medium':
lo=23
range=25
lo = 23
range = 25
elif aug == 'jpeg-broad':
lo=15
range=60
lo = 15
range = 60
elif aug == 'jpeg-normal':
lo=47
range=35
lo = 47
range = 35
else:
raise NotImplementedError("specified jpeg corruption doesn't exist")
raise NotImplementedError(
"specified jpeg corruption doesn't exist")
# JPEG compression
qf = (int((1-rand_val)*range) + lo)
# Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
@ -195,14 +201,15 @@ class ImageCorruptor:
buffer = BytesIO()
img.save(buffer, "JPEG", quality=qf, optimize=True)
buffer.seek(0)
jpeg_img_bytes = np.asarray(bytearray(buffer.read()), dtype="uint8")
jpeg_img_bytes = np.asarray(
bytearray(buffer.read()), dtype="uint8")
img = read_img("buffer", jpeg_img_bytes, rgb=True)
elif 'saturation' in aug:
# Lightening / saturation
saturation = rand_val * .3
img = np.clip(img + saturation, a_max=1, a_min=0)
elif 'greyscale' in aug:
img = np.tile(np.mean(img, axis=2, keepdims=True), [1,1,3])
img = np.tile(np.mean(img, axis=2, keepdims=True), [1, 1, 3])
elif 'none' not in aug:
raise NotImplementedError("Augmentation doesn't exist")
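# Rough usage sketch (illustrative values; only option keys read above are assumed):
# corruptor = ImageCorruptor({'fixed_corruptions': ['gaussian_blur', 'jpeg-medium'],
#                             'num_corrupts_per_image': 0,
#                             'corruption_blur_scale': 1})
# lq_img = corruptor.corrupt_images([hq_img])[0]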

@ -1,21 +1,21 @@
import functools
import os
import random
import cv2
import numpy as np
import torch
import os
import torchvision
from data import util
from torch.utils.data import DataLoader
from torchvision.transforms import Normalize
from tqdm import tqdm
from data import util
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
from data.images.image_corruptor import ImageCorruptor, kornia_color_jitter_numpy
from data.images.image_label_parser import VsNetImageLabeler
from utils.util import opt_get
from dlas.data.images.image_corruptor import (ImageCorruptor,
kornia_color_jitter_numpy)
from dlas.data.images.image_label_parser import VsNetImageLabeler
from dlas.utils.util import opt_get
def ndarray_center_crop(crop, img):

@ -50,15 +50,15 @@ class VsNetImageLabeler:
def get_labels_as_tensor(self, hq, img_key, resize_factor):
_, h, w = hq.shape
labels = torch.zeros((1,h,w), dtype=torch.long)
mask = torch.zeros((1,h,w), dtype=torch.float)
labels = torch.zeros((1, h, w), dtype=torch.long)
mask = torch.zeros((1, h, w), dtype=torch.float)
lbl_list = self.labeled_images[img_key]
for patch_lbl in lbl_list:
t, l, h, w = patch_lbl['patch_top'] // resize_factor, patch_lbl['patch_left'] // resize_factor, \
patch_lbl['patch_height'] // resize_factor, patch_lbl['patch_width'] // resize_factor
patch_lbl['patch_height'] // resize_factor, patch_lbl['patch_width'] // resize_factor
val = patch_lbl['labelValue']
labels[:,t:t+h,l:l+w] = val
mask[:,t:t+h,l:l+w] = 1.0
labels[:, t:t+h, l:l+w] = val
mask[:, t:t+h, l:l+w] = 1.0
return labels, mask, self.str_labels
def add_label(self, binding, img_name, top, left, dim):
@ -100,22 +100,23 @@ class CompactJsonLabeler:
assert self.config == parsed['config']
assert self.labels == parsed['labels']
assert self.label_map == parsed['label_map']
self.images.update(parsed['images']) # This will overwrite existing images, which is acceptable.
# This will overwrite existing images, which is acceptable.
self.images.update(parsed['images'])
def get_labeled_paths(self, base_path):
return [os.path.join(base_path, pth) for pth in self.images.keys()]
def get_labels_as_tensor(self, hq, img_key, resize_factor):
_, h, w = hq.shape
labels = torch.zeros((1,h,w), dtype=torch.long)
mask = torch.zeros((1,h,w), dtype=torch.float)
labels = torch.zeros((1, h, w), dtype=torch.long)
mask = torch.zeros((1, h, w), dtype=torch.float)
lbl_list = self.images[img_key]
for patch_lbl in lbl_list:
t, l, h, w = patch_lbl['top'] // resize_factor, patch_lbl['left'] // resize_factor, \
self.config['dim'] // resize_factor, self.config['dim'] // resize_factor
self.config['dim'] // resize_factor, self.config['dim'] // resize_factor
val = patch_lbl['labelValue']
labels[:,t:t+h,l:l+w] = val
mask[:,t:t+h,l:l+w] = 1.0
labels[:, t:t+h, l:l+w] = val
mask[:, t:t+h, l:l+w] = 1.0
return labels, mask, self.str_labels
def add_label(self, binding, img_name, top, left, dim):

@ -1,13 +1,12 @@
import torch
import os
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
# Builds a dataset created from a simple folder containing a list of training/test/validation images.
@ -15,33 +14,42 @@ class ImagePairWithCorrespondingPointsDataset(Dataset):
def __init__(self, opt):
self.opt = opt
self.path = opt['path']
self.pairs = list(filter(lambda f: not os.path.isdir(f), os.listdir(self.path)))
self.pairs = list(
filter(lambda f: not os.path.isdir(f), os.listdir(self.path)))
self.transforms = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
transforms.Normalize(
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
self.size = opt['size']
def __getitem__(self, item):
dir = self.pairs[item]
img1 = self.transforms(Image.open(os.path.join(self.path, dir, "1.jpg")))
img2 = self.transforms(Image.open(os.path.join(self.path, dir, "2.jpg")))
coords1, coords2 = torch.load(os.path.join(self.path, dir, "coords.pth"))
img1 = self.transforms(Image.open(
os.path.join(self.path, dir, "1.jpg")))
img2 = self.transforms(Image.open(
os.path.join(self.path, dir, "2.jpg")))
coords1, coords2 = torch.load(
os.path.join(self.path, dir, "coords.pth"))
assert img1.shape[-2] == img1.shape[-1]
assert img2.shape[-2] == img2.shape[-1]
if img1.shape[-1] != self.size:
scale = img1.shape[-1] / self.size
assert(int(scale) == scale) # We will only downsample to even resolutions.
# We will only downsample to even resolutions.
assert (int(scale) == scale)
scale = 1 / scale
img1 = torch.nn.functional.interpolate(img1.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
img1 = torch.nn.functional.interpolate(img1.unsqueeze(
0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
coords1 = [int(c * scale) for c in coords1]
if img2.shape[-1] != self.size:
scale = img2.shape[-1] / self.size
assert(int(scale) == scale) # We will only downsample to even resolutions.
# We will only downsample to even resolutions.
assert (int(scale) == scale)
scale = 1 / scale
img2 = torch.nn.functional.interpolate(img2.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
img2 = torch.nn.functional.interpolate(img2.unsqueeze(
0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0)
coords2 = [int(c * scale) for c in coords2]
coords1 = (coords1[1], coords1[0]) # The UI puts these out backwards (x,y). Flip them.
# The UI puts these out backwards (x,y). Flip them.
coords1 = (coords1[1], coords1[0])
coords2 = (coords2[1], coords2[0])
return {
'img1': img1,
@ -53,6 +61,7 @@ class ImagePairWithCorrespondingPointsDataset(Dataset):
def __len__(self):
return len(self.pairs)
if __name__ == '__main__':
opt = {
'path': 'F:\\dlas\\codes\\scripts\\ui\\image_pair_labeler\\results',
@ -60,13 +69,14 @@ if __name__ == '__main__':
}
output_path = '..'
ds = DataLoader(ImagePairWithCorrespondingPointsDataset(opt), shuffle=True, num_workers=0)
ds = DataLoader(ImagePairWithCorrespondingPointsDataset(
opt), shuffle=True, num_workers=0)
for i, d in tqdm(enumerate(ds)):
i1 = d['img1']
i2 = d['img2']
c1 = d['coords1']
c2 = d['coords2']
i1[:,:,c1[0]-3:c1[0]+3,c1[1]-3:c1[1]+3] = 0
i2[:,:,c2[0]-3:c2[0]+3,c2[1]-3:c2[1]+3] = 0
i1[:, :, c1[0]-3:c1[0]+3, c1[1]-3:c1[1]+3] = 0
i2[:, :, c2[0]-3:c2[0]+3, c2[1]-3:c2[1]+3] = 0
torchvision.utils.save_image(i1, f'{output_path}\\{i}_1.png')
torchvision.utils.save_image(i2, f'{output_path}\\{i}_2.png')
torchvision.utils.save_image(i2, f'{output_path}\\{i}_2.png')

@ -1,8 +1,12 @@
from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset
import os.path as osp
from bisect import bisect_left
import numpy as np
import torch
from bisect import bisect_left
import os.path as osp
from dlas.data.images.base_unsupervised_image_dataset import \
BaseUnsupervisedImageDataset
class MultiFrameDataset(BaseUnsupervisedImageDataset):
def __init__(self, opt):
@ -30,7 +34,8 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
for i in range(self.num_frames):
idx = search_idx + i
if idx < 0 or idx >= len(self.chunks) or chunk_offset < 0 or chunk_offset >= len(self.chunks[idx]):
print("Chunk reference indexing failed for %s." % (im_name,), search_idx, i, chunk_offset, self.num_frames)
print("Chunk reference indexing failed for %s." %
(im_name,), search_idx, i, chunk_offset, self.num_frames)
h, r, c, m, p = self.chunks[search_idx + i][chunk_offset]
hqs.append(h)
refs.append(r)
@ -41,24 +46,32 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
def __getitem__(self, item):
chunk_ind = bisect_left(self.starting_indices, item)
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
hqs, refs, masks, centers, path = self.get_sequential_image_paths_from(chunk_ind, item-self.starting_indices[chunk_ind])
chunk_ind = chunk_ind if chunk_ind < len(
self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
hqs, refs, masks, centers, path = self.get_sequential_image_paths_from(
chunk_ind, item-self.starting_indices[chunk_ind])
hs, hrs, hms, hcs = self.resize_hq(hqs, refs, masks, centers)
ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs)
# Convert to torch tensor
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1)
hq = torch.from_numpy(np.ascontiguousarray(
np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
hq_ref = torch.from_numpy(np.ascontiguousarray(
np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
hq_mask = torch.from_numpy(
np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1)
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
lq = torch.from_numpy(np.ascontiguousarray(
np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
lq_ref = torch.from_numpy(np.ascontiguousarray(
np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
lq_mask = torch.from_numpy(
np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
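# Shapes for reference: 'hq'/'lq' are (frames, 3, H, W) float stacks, while the
# *_fullsize_ref entries carry a fourth mask channel concatenated along dim=1.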
if __name__ == '__main__':
@ -84,7 +97,7 @@ if __name__ == '__main__':
for i in range(len(ds)):
import random
k = 'lq'
element = ds[random.randint(0,len(ds))]
element = ds[random.randint(0, len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)
if bs < 32:
@ -99,8 +112,9 @@ if __name__ == '__main__':
b, fr, f, h, w = batch.shape
for j in range(fr):
import torchvision
base=osp.basename(base_file)
torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
base = osp.basename(base_file)
torchvision.utils.save_image(
batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
bs = 0
batch = None

@ -1,12 +1,13 @@
import random
import numpy as np
import cv2
import data.util as util
import numpy as np
import torch
import torch.utils.data as data
import data.util as util
# Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes.
from data.images.image_corruptor import ImageCorruptor
from dlas.data.images.image_corruptor import ImageCorruptor
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
@ -27,6 +28,7 @@ def get_square_image(image):
left = max(int(center + offset * (center - 2)), 0)
return image[:, left:left + h, :]
class MultiScaleDataset(data.Dataset):
def __init__(self, opt):
super(MultiScaleDataset, self).__init__()
@ -36,10 +38,10 @@ class MultiScaleDataset(data.Dataset):
self.num_scales = self.opt['num_scales']
self.hq_size_cap = self.tile_size * 2 ** self.num_scales
self.scale = self.opt['scale']
self.paths_hq, self.sizes_hq = util.find_files_of_type(self.data_type, opt['paths'], [1 for _ in opt['paths']])
self.paths_hq, self.sizes_hq = util.find_files_of_type(
self.data_type, opt['paths'], [1 for _ in opt['paths']])
self.corruptor = ImageCorruptor(opt)
def recursively_extract_patches(self, input_img, result_list, depth):
if depth >= self.num_scales:
return
@ -49,7 +51,8 @@ class MultiScaleDataset(data.Dataset):
input_img[:patch_size, patch_size:],
input_img[patch_size:, :patch_size],
input_img[patch_size:, patch_size:]]
result_list.extend([cv2.resize(p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA) for p in patches])
result_list.extend([cv2.resize(
p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA) for p in patches])
for p in patches:
self.recursively_extract_patches(p, result_list, depth+1)
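# For reference, this recursion yields 1 + 4 + 16 + ... tiles per image: with
# num_scales == 3, a single source image produces 21 patches, each resized to tile_size.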
@ -57,30 +60,41 @@ class MultiScaleDataset(data.Dataset):
# get full size image
full_path = self.paths_hq[index % len(self.paths_hq)]
loaded_img = util.read_img(None, full_path, None)
img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0]
img_full1 = util.channel_convert(
loaded_img.shape[2], 'RGB', [loaded_img])[0]
img_full2 = util.augment([img_full1], True, True)[0]
img_full3 = get_square_image(img_full2)
# This error crops up from time to time. I suspect an issue with util.read_img.
if img_full3.shape[0] == 0 or img_full3.shape[1] == 0:
print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))
print("Error with image: %s. Loaded image shape: %s" % (full_path, str(
loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))
# Attempt to recover by just using a fixed array of zeros, which the downstream networks should be fine training against, within reason.
img_full3 = np.zeros((1024,1024,3), dtype=np.int)
img_full = cv2.resize(img_full3, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
img_full3 = np.zeros((1024, 1024, 3), dtype=np.int)
img_full = cv2.resize(
img_full3, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
patches_hq = [cv2.resize(
img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
self.recursively_extract_patches(img_full, patches_hq, 1)
# Image corruption is applied against the full size image for this dataset.
img_corrupted = self.corruptor.corrupt_images([img_full])[0]
patches_hq_corrupted = [cv2.resize(img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
self.recursively_extract_patches(img_corrupted, patches_hq_corrupted, 1)
patches_hq_corrupted = [cv2.resize(
img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
self.recursively_extract_patches(
img_corrupted, patches_hq_corrupted, 1)
# BGR to RGB, HWC to CHW, numpy to tensor
if patches_hq[0].shape[2] == 3:
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq]
patches_hq_corrupted = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted]
patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB)
for p in patches_hq]
patches_hq_corrupted = [cv2.cvtColor(
p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted]
patches_hq = [torch.from_numpy(np.ascontiguousarray(
np.transpose(p, (2, 0, 1)))).float() for p in patches_hq]
patches_hq = torch.stack(patches_hq, dim=0)
patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted]
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(
np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted]
patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(
0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted]
patches_lq = torch.stack(patches_lq, dim=0)
d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path}
@ -89,6 +103,7 @@ class MultiScaleDataset(data.Dataset):
def __len__(self):
return len(self.paths_hq)
class MultiscaleTreeNode:
def __init__(self, index, parent, i):
self.index = index
@ -117,7 +132,8 @@ def build_multiscale_patch_index_map(depth):
def _build_multiscale_patch_index_map(depth, ind, node, leaves):
subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i)) for i in range(4)]
subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i))
for i in range(4)]
ind += 4
if depth == 1:
leaves.extend(subnodes)
@ -146,7 +162,7 @@ if __name__ == '__main__':
os.makedirs("debug", exist_ok=True)
multiscale_tree = build_multiscale_patch_index_map(4)
for i in range(500, len(ds)):
quadrant=2
quadrant = 2
print(i)
o = ds[random.randint(0, len(ds))]
tree_ind = random.randint(0, len(multiscale_tree))
@ -155,9 +171,10 @@ if __name__ == '__main__':
continue
depth = 0
node = multiscale_tree[tree_ind]
#for j, img in enumerate(v):
# for j, img in enumerate(v):
# torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
while node is not None:
torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth))
torchvision.utils.save_image(v[node.index].unsqueeze(
0), "debug/%i_%s_%i.png" % (i, k, depth))
depth += 1
node = node.parent
node = node.parent

@ -1,8 +1,12 @@
from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset
import os.path as osp
from bisect import bisect_left
import numpy as np
import torch
from bisect import bisect_left
import os.path as osp
from dlas.data.images.base_unsupervised_image_dataset import \
BaseUnsupervisedImageDataset
class PairedFrameDataset(BaseUnsupervisedImageDataset):
def __init__(self, opt):
@ -26,24 +30,32 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
def __getitem__(self, item):
chunk_ind = bisect_left(self.starting_indices, item)
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
hqs, refs, masks, centers, path = self.get_pair(chunk_ind, item-self.starting_indices[chunk_ind])
chunk_ind = chunk_ind if chunk_ind < len(
self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
hqs, refs, masks, centers, path = self.get_pair(
chunk_ind, item-self.starting_indices[chunk_ind])
hs, hrs, hms, hcs = self.resize_hq(hqs, refs, masks, centers)
ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs)
# Convert to torch tensor
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).squeeze().unsqueeze(dim=1)
hq = torch.from_numpy(np.ascontiguousarray(
np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
hq_ref = torch.from_numpy(np.ascontiguousarray(
np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
hq_mask = torch.from_numpy(np.ascontiguousarray(
np.stack(hms))).squeeze().unsqueeze(dim=1)
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1)
lq = torch.from_numpy(np.ascontiguousarray(
np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
lq_ref = torch.from_numpy(np.ascontiguousarray(
np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
lq_mask = torch.from_numpy(np.ascontiguousarray(
np.stack(lms))).squeeze().unsqueeze(dim=1)
lq_ref = torch.cat([lq_ref, lq_mask], dim=1)
return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)}
if __name__ == '__main__':
@ -51,7 +63,7 @@ if __name__ == '__main__':
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\validation'],
'weights': [1],
#'target_size': 128,
# 'target_size': 128,
'force_multiple': 32,
'scale': 2,
'eval': False,
@ -69,7 +81,7 @@ if __name__ == '__main__':
for i in range(len(ds)):
import random
k = 'lq'
element = ds[random.randint(0,len(ds))]
element = ds[random.randint(0, len(ds))]
base_file = osp.basename(element["GT_path"])
o = element[k].unsqueeze(0)
if bs < 2:
@ -84,8 +96,9 @@ if __name__ == '__main__':
b, fr, f, h, w = batch.shape
for j in range(fr):
import torchvision
base=osp.basename(base_file)
torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
base = osp.basename(base_file)
torchvision.utils.save_image(
batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base))
bs = 0
batch = None

@ -1,8 +1,11 @@
import random
from bisect import bisect_left
import numpy as np
import torch
from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset
from dlas.data.images.base_unsupervised_image_dataset import \
BaseUnsupervisedImageDataset
# Builds a dataset composed of a set of folders. Each folder represents a single high resolution image that has been
@ -14,30 +17,40 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
def get_paths(self):
for i in range(len(self)):
chunk_ind = bisect_left(self.starting_indices, i)
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1
chunk_ind = chunk_ind if chunk_ind < len(
self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1
yield self.chunks[chunk_ind].tiles[i-self.starting_indices[chunk_ind]]
def __getitem__(self, item):
chunk_ind = bisect_left(self.starting_indices, item)
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item-self.starting_indices[chunk_ind]]
chunk_ind = chunk_ind if chunk_ind < len(
self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item -
self.starting_indices[chunk_ind]]
hs, hrs, hms, hcs = self.resize_hq([hq], [hq_ref], [hq_mask], [hq_center])
hs, hrs, hms, hcs = self.resize_hq(
[hq], [hq_ref], [hq_mask], [hq_center])
ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs)
# Convert to torch tensor
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float()
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(hrs[0], (2, 0, 1)))).float()
hq_mask = torch.from_numpy(np.ascontiguousarray(hms[0])).unsqueeze(dim=0)
hq = torch.from_numpy(np.ascontiguousarray(
np.transpose(hs[0], (2, 0, 1)))).float()
hq_ref = torch.from_numpy(np.ascontiguousarray(
np.transpose(hrs[0], (2, 0, 1)))).float()
hq_mask = torch.from_numpy(
np.ascontiguousarray(hms[0])).unsqueeze(dim=0)
hq_ref = torch.cat([hq_ref, hq_mask], dim=0)
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float()
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(lrs[0], (2, 0, 1)))).float()
lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
lq = torch.from_numpy(np.ascontiguousarray(
np.transpose(ls[0], (2, 0, 1)))).float()
lq_ref = torch.from_numpy(np.ascontiguousarray(
np.transpose(lrs[0], (2, 0, 1)))).float()
lq_mask = torch.from_numpy(
np.ascontiguousarray(lms[0])).unsqueeze(dim=0)
lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
'LQ_path': path, 'GT_path': path}
'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long),
'LQ_path': path, 'GT_path': path}
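# Shapes for reference: 'hq'/'lq' are single (3, H, W) images, while the
# *_fullsize_ref entries are (4, H, W) -- RGB plus the mask channel concatenated above.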
if __name__ == '__main__':
@ -60,13 +73,14 @@ if __name__ == '__main__':
os.makedirs("debug", exist_ok=True)
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
# for k, v in o.items():
k = 'lq'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:
#masked = v[:3, :, :] * v[3]
#torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
#v = v[:3, :, :]
# if 'LQ' in k and 'path' not in k and 'center' not in k:
# if 'full' in k:
# masked = v[:3, :, :] * v[3]
# torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
# v = v[:3, :, :]
import torchvision
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
torchvision.utils.save_image(
v.unsqueeze(0), "debug/%i_%s.png" % (i, k))

@ -1,15 +1,15 @@
from functools import partial
from pathlib import Path
from random import random
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils import data
from torchvision import transforms
import torch.nn as nn
from pathlib import Path
import models.image_generation.stylegan.stylegan2_lucidrains as sg2
import dlas.models.image_generation.stylegan.stylegan2_lucidrains as sg2
def convert_transparent_to_rgb(image):
@ -29,6 +29,7 @@ def resize_to_minimum_size(min_size, image):
return torchvision.transforms.functional.resize(image, min_size)
return image
class RandomApply(nn.Module):
def __init__(self, prob, fn, fn_else=lambda x: x):
super().__init__()
@ -59,7 +60,8 @@ class expand_greyscale(object):
color = tensor[:1].expand(3, -1, -1)
alpha = tensor[1:]
else:
raise Exception(f'image with invalid number of channels given {channels}')
raise Exception(
f'image with invalid number of channels given {channels}')
if not sg2.exists(alpha) and self.transparent:
alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
@ -73,17 +75,21 @@ class Stylegan2Dataset(data.Dataset):
EXTS = ['jpg', 'jpeg', 'png', 'webp']
self.folder = opt['path']
self.image_size = opt['target_size']
self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')]
self.paths = [p for ext in EXTS for p in Path(
f'{self.folder}').glob(f'**/*.{ext}')]
aug_prob = opt['aug_prob']
transparent = opt['transparent'] if 'transparent' in opt.keys() else False
assert len(self.paths) > 0, f'No images were found in {self.folder} for training'
transparent = opt['transparent'] if 'transparent' in opt.keys(
) else False
assert len(
self.paths) > 0, f'No images were found in {self.folder} for training'
convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent
num_channels = 3 if not transparent else 4
self.transform = transforms.Compose([
transforms.Lambda(convert_image_fn),
transforms.Lambda(partial(resize_to_minimum_size, self.image_size)),
transforms.Lambda(
partial(resize_to_minimum_size, self.image_size)),
transforms.Resize(self.image_size),
RandomApply(aug_prob, transforms.RandomResizedCrop(self.image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
transforms.CenterCrop(self.image_size)),

@ -1,9 +1,10 @@
import PIL.Image
import zipfile
import PIL.Image
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
class ZipFileDataset(torch.utils.data.Dataset):
@ -14,9 +15,10 @@ class ZipFileDataset(torch.utils.data.Dataset):
self.resolution = opt['resolution']
self.paired_mode = opt['paired_mode']
self.transforms = Compose([ToTensor(),
Resize(self.resolution),
Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
Resize(self.resolution),
Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
])
self.zip = None
def __len__(self):
@ -49,10 +51,12 @@ class ZipFileDataset(torch.utils.data.Dataset):
aname = fname.replace('1.jpg', '0.jpg')
out['alt_hq'] = self.load_image(aname)
except:
print(f"Error loading {fname} from zipfile. Attempting to recover by loading next element.")
print(
f"Error loading {fname} from zipfile. Attempting to recover by loading next element.")
return self[i+1]
return out
if __name__ == '__main__':
opt = {
'path': 'E:\\4k6k\\datasets\\images\\youtube-imagenet-paired\\output.zip',
@ -65,4 +69,3 @@ if __name__ == '__main__':
for i, d in enumerate(loader):
torchvision.utils.save_image(d['hq'], f'{i}_hq.png')
torchvision.utils.save_image(d['alt_hq'], f'{i}_althq.png')

@ -1,18 +1,20 @@
from torch.utils.data import Dataset
import datasets
from torch.utils.data import Dataset
class HfDataset(Dataset):
"""
Simple wrapper for a HuggingFace dataset that can re-map keys if desired.
"""
def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'):
self.hfd = []
for corpus in corpi:
dataset_name, config = corpus
if config == '' or config == 'None':
config = None
self.hfd.append(datasets.load_dataset(dataset_name, config, cache_dir=cache_path)[dataset_spec_key])
self.hfd.append(datasets.load_dataset(
dataset_name, config, cache_dir=cache_path)[dataset_spec_key])
self.key_maps = key_maps
def __getitem__(self, item):
@ -32,5 +34,6 @@ class HfDataset(Dataset):
if __name__ == '__main__':
d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']], dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache')
d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']],
dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache')
print(d[5])

@ -1,10 +1,10 @@
from torch.utils.data import Dataset
import torchvision.transforms as T
from torch.utils.data import Dataset
from torchvision import datasets
# Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer.
from data.images.cifar import CIFAR100, CIFAR10
from utils.util import opt_get
from dlas.data.images.cifar import CIFAR10, CIFAR100
from dlas.utils.util import opt_get
class TorchDataset(Dataset):
@ -17,7 +17,8 @@ class TorchDataset(Dataset):
"imagenet": datasets.ImageNet,
"imagefolder": datasets.ImageFolder
}
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[
0.229, 0.224, 0.225])
if opt_get(opt, ['random_crop'], False):
transforms = [
T.RandomResizedCrop(opt['image_size']),
@ -34,7 +35,8 @@ class TorchDataset(Dataset):
normalize,
]
transforms = T.Compose(transforms)
self.dataset = DATASET_MAP[opt['dataset']](transform=transforms, **opt['kwargs'])
self.dataset = DATASET_MAP[opt['dataset']](
transform=transforms, **opt['kwargs'])
self.len = opt_get(opt, ['fixed_len'], len(self.dataset))
self.offset = opt_get(opt, ['offset'], 0)
@ -53,6 +55,7 @@ class TorchDataset(Dataset):
def __len__(self):
return self.len-self.offset
if __name__ == '__main__':
opt = {
'flip': True,

@ -1,21 +1,22 @@
import os
import glob
import math
import os
import pickle
import random
import cv2
import numpy
import numpy as np
import glob
import torch
import torchvision
import cv2
####################
# Files & IO
####################
###################### get image path list ######################
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.webp', '.WEBP']
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png',
'.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.webp', '.WEBP']
def torch2cv(tensor):
@ -30,7 +31,8 @@ def torch2cv(tensor):
def cv2torch(cv, batchify=True):
cv = cv2.cvtColor(cv, cv2.COLOR_BGR2RGB)
tens = torch.from_numpy(np.ascontiguousarray(np.transpose(cv, (2, 0, 1)))).float()
tens = torch.from_numpy(np.ascontiguousarray(
np.transpose(cv, (2, 0, 1)))).float()
if batchify:
tens = tens.unsqueeze(0)
return tens
@ -65,7 +67,8 @@ def _get_paths_from_images(path, qualifier=is_image_file):
def _get_paths_from_lmdb(dataroot):
"""get image path list from lmdb meta info"""
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
meta_info = pickle.load(
open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
paths = meta_info['keys']
sizes = meta_info['resolution']
if len(sizes) == 1:
@ -123,7 +126,8 @@ def read_img(env, path, size=None, rgb=False):
# Indirect open then process to support unicode files.
stream = open(path, "rb")
bytes = bytearray(stream.read())
img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8),
cv2.IMREAD_UNCHANGED)
elif env == 'lmdb':
img = _read_img_lmdb(env, path, size)
elif env == 'buffer':
@ -161,7 +165,8 @@ def read_img_seq(path):
# stack to Torch tensor
imgs = np.stack(img_l, axis=0)
imgs = imgs[:, :, :, [2, 1, 0]]
imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
imgs = torch.from_numpy(np.ascontiguousarray(
np.transpose(imgs, (0, 3, 1, 2)))).float()
return imgs
@ -478,9 +483,12 @@ def imresize(img, scale, antialiasing=True):
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
out_1[0, i, :] = img_aug[0, idx:idx + kernel_width,
:].transpose(0, 1).mv(weights_H[i])
out_1[1, i, :] = img_aug[1, idx:idx + kernel_width,
:].transpose(0, 1).mv(weights_H[i])
out_1[2, i, :] = img_aug[2, idx:idx + kernel_width,
:].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
@ -501,9 +509,12 @@ def imresize(img, scale, antialiasing=True):
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
out_2[0, :, i] = out_1_aug[0, :, idx:idx +
kernel_width].mv(weights_W[i])
out_2[1, :, i] = out_1_aug[1, :, idx:idx +
kernel_width].mv(weights_W[i])
out_2[2, :, i] = out_1_aug[2, :, idx:idx +
kernel_width].mv(weights_W[i])
return out_2
@ -548,9 +559,12 @@ def imresize_np(img, scale, antialiasing=True):
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
out_1[i, :, 0] = img_aug[idx:idx + kernel_width,
:, 0].transpose(0, 1).mv(weights_H[i])
out_1[i, :, 1] = img_aug[idx:idx + kernel_width,
:, 1].transpose(0, 1).mv(weights_H[i])
out_1[i, :, 2] = img_aug[idx:idx + kernel_width,
:, 2].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
@ -571,9 +585,12 @@ def imresize_np(img, scale, antialiasing=True):
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
out_2[:, i, 0] = out_1_aug[:, idx:idx +
kernel_width, 0].mv(weights_W[i])
out_2[:, i, 1] = out_1_aug[:, idx:idx +
kernel_width, 1].mv(weights_W[i])
out_2[:, i, 2] = out_1_aug[:, idx:idx +
kernel_width, 2].mv(weights_W[i])
return out_2.numpy()
@ -587,7 +604,8 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not
print(f"Building cache for contents of {paths}..")
output = []
for p in paths:
output.extend(find_files_of_type('img', p, qualifier=is_audio_file)[0])
output.extend(find_files_of_type(
'img', p, qualifier=is_audio_file)[0])
if exclusion_list is not None and len(exclusion_list) > 0:
print(f"Removing exclusion lists..")
before = len(output)
@ -597,6 +615,7 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not
print(f"Excluded {before-len(output)} files.")
if endswith is not None:
before = len(output)
def filter_fn(p):
for e in endswith:
if not p.endswith(e):
@ -606,7 +625,8 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not
return False
return True
output = list(filter(filter_fn, output))
print(f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
print(
f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files")
print("Done.")
torch.save(output, cache_path)
return output
@ -617,7 +637,8 @@ if __name__ == '__main__':
# read images
img = cv2.imread('test.png')
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img = torch.from_numpy(np.transpose(
img[:, :, [2, 1, 0]], (2, 0, 1))).float()
# imresize
scale = 1 / 4
import time

@ -7,12 +7,14 @@ class ZeroPadDictCollate():
Given a list of dictionary outputs with torch.Tensors from a Dataset, iterates through them, finds the largest
size along each dimension, and zero-pads every tensor to that size so the batch can be stacked.
"""
def collate_tensors(self, batch, key):
result = []
largest_dims = [0 for _ in range(len(batch[0][key].shape))]
for elem in batch:
result.append(elem[key])
largest_dims = [max(current_largest, new_consideration) for current_largest, new_consideration in zip(largest_dims, elem[key].shape)]
largest_dims = [max(current_largest, new_consideration)
for current_largest, new_consideration in zip(largest_dims, elem[key].shape)]
# Now pad each tensor by the largest dimension.
for i in range(len(result)):
padding_tuple = ()
@ -24,7 +26,6 @@ class ZeroPadDictCollate():
return torch.stack(result, dim=0)
def collate_into_list(self, batch, key):
result = []
for elem in batch:
@ -42,4 +43,4 @@ class ZeroPadDictCollate():
collated[key] = torch.stack([b[key] for b in batch])
else:
collated[key] = self.collate_into_list(batch, key)
return collated
return collated
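# Typical wiring (sketch): pass an instance as a DataLoader's collate_fn, e.g.
# DataLoader(dataset, batch_size=8, collate_fn=ZeroPadDictCollate()), so samples whose
# tensors differ in length are zero-padded to a common shape before stacking.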

@ -3,13 +3,11 @@ from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
from math import sqrt
import torch.nn.init as init
from utils.util import checkpoint
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.utils.util import checkpoint
def exists(val):
@ -21,14 +19,14 @@ def default(val, d):
def l2norm(t):
return F.normalize(t, p = 2, dim = -1)
return F.normalize(t, p=2, dim=-1)
def ema_inplace(moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories, eps = 1e-5):
def laplace_smoothing(x, n_categories, eps=1e-5):
return (x + eps) / (x.sum() + n_categories * eps)
@ -36,9 +34,9 @@ def sample_vectors(samples, num):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device = device)[:num]
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device = device)
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
@ -239,7 +237,8 @@ class AttentionPool2d(nn.Module):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype) # NC(HW+1)
x = x + self.positional_embedding[None,
:, :x.shape[-1]].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
@ -296,7 +295,8 @@ class Upsample(nn.Module):
if dims == 1:
ksize = 5
pad = 2
self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)
self.conv = conv_nd(dims, self.channels,
self.out_channels, ksize, padding=pad)
def forward(self, x):
assert x.shape[1] == self.channels
@ -346,6 +346,7 @@ class cGLU(nn.Module):
"""
Gated GELU for channel-first architectures.
"""
def __init__(self, dim_in, dim_out=None):
super().__init__()
dim_out = dim_in if dim_out is None else dim_out
@ -395,7 +396,8 @@ class ResBlock(nn.Module):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
conv_nd(dims, channels, self.out_channels,
kernel_size, padding=padding),
)
self.updown = up or down
@ -414,7 +416,8 @@ class ResBlock(nn.Module):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
conv_nd(dims, self.out_channels, self.out_channels,
kernel_size, padding=padding)
),
)
@ -425,7 +428,8 @@ class ResBlock(nn.Module):
dims, channels, self.out_channels, kernel_size, padding=padding
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1)
def forward(self, x):
"""
@ -466,10 +470,10 @@ def build_local_attention_mask(n, l, fixed_region=0):
A mask that can be applied to AttentionBlock to achieve local attention.
"""
assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
o = torch.arange(0,n)
c = o.unsqueeze(-1).repeat(1,n)
r = o.unsqueeze(0).repeat(n,1)
localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
o = torch.arange(0, n)
c = o.unsqueeze(-1).repeat(1, n)
r = o.unsqueeze(0).repeat(n, 1)
localized = ((-(r-c).abs())+l).clamp(0, l-1) / (l-1)
localized[:fixed_region] = 1
localized[:, :fixed_region] = 1
mask = localized > 0
@ -477,7 +481,7 @@ def build_local_attention_mask(n, l, fixed_region=0):
def test_local_attention_mask():
print(build_local_attention_mask(9,4,1))
print(build_local_attention_mask(9, 4, 1))
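# For n=9, l=4, fixed_region=1 this prints a 9x9 boolean mask where each position may
# attend to itself and up to three neighbours on either side, with the first row and
# column forced to True by the fixed region.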
class RelativeQKBias(nn.Module):
@ -487,17 +491,18 @@ class RelativeQKBias(nn.Module):
If symmetric=False, a different bias is applied to each side of the input element, otherwise the bias is symmetric.
"""
def __init__(self, l, max_positions=4000, symmetric=True):
super().__init__()
if symmetric:
self.emb = nn.Parameter(torch.randn(l+1) * .01)
o = torch.arange(0,max_positions)
c = o.unsqueeze(-1).repeat(1,max_positions)
r = o.unsqueeze(0).repeat(max_positions,1)
M = ((-(r-c).abs())+l).clamp(0,l)
o = torch.arange(0, max_positions)
c = o.unsqueeze(-1).repeat(1, max_positions)
r = o.unsqueeze(0).repeat(max_positions, 1)
M = ((-(r-c).abs())+l).clamp(0, l)
else:
self.emb = nn.Parameter(torch.randn(l*2+2) * .01)
a = torch.arange(0,max_positions)
a = torch.arange(0, max_positions)
c = a.unsqueeze(-1) - a
m = (c >= -l).logical_and(c <= l)
M = (l+c+1)*m
@ -508,7 +513,7 @@ class RelativeQKBias(nn.Module):
# return self.emb[self.M[:n, :n]].view(1,n,n)
# However, indexing operations like this have horrible efficiency on GPUs: https://github.com/pytorch/pytorch/issues/15245
# So, enter this horrible, equivalent mess:
return torch.gather(self.emb.unsqueeze(-1).repeat(1,n), 0, self.M[:n,:n]).view(1,n,n)
return torch.gather(self.emb.unsqueeze(-1).repeat(1, n), 0, self.M[:n, :n]).view(1, n, n)
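# The result is a (1, n, n) additive bias that depends only on the clamped relative
# offset between query and key positions, equivalent to the commented-out indexing above.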
class AttentionBlock(nn.Module):
@ -550,7 +555,8 @@ class AttentionBlock(nn.Module):
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(
1, channels, out_channels, 1)
self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
def forward(self, x, mask=None, qk_bias=None):
@ -572,7 +578,7 @@ class AttentionBlock(nn.Module):
b, c, *spatial = x.shape
if mask is not None:
if len(mask.shape) == 2:
mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1)
if mask.shape[1] != x.shape[-1]:
mask = mask[:, :x.shape[-1], :x.shape[-1]]
@ -606,7 +612,8 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3,
length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale
@ -651,7 +658,8 @@ class QKVAttention(nn.Module):
mask = mask.repeat(self.n_heads, 1, 1)
weight[mask.logical_not()] = -torch.inf
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
a = torch.einsum("bts,bcs->bct", weight,
v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@ -678,7 +686,8 @@ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
output = F.grid_sample(
x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
return output
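# The vgrid_x/vgrid_y rescaling above maps pixel coordinates into [-1, 1], the
# normalized coordinate range expected by F.grid_sample.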
@ -690,7 +699,8 @@ class PixelUnshuffle(nn.Module):
def forward(self, x):
(b, f, w, h) = x.shape
x = x.contiguous().view(b, f, w // self.r, self.r, h // self.r, self.r)
x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, f * (self.r ** 2), w // self.r, h // self.r)
x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(
b, f * (self.r ** 2), w // self.r, h // self.r)
return x
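# Inverse of nn.PixelShuffle: folds each r x r spatial block into the channel axis,
# turning (b, f, w, h) into (b, f * r**2, w // r, h // r).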
@ -704,6 +714,8 @@ def silu(input):
# create a class wrapper from PyTorch nn.Module, so
# the function can now be used easily in models
class SiLU(nn.Module):
'''
Applies the Sigmoid Linear Unit (SiLU) function element-wise:
@ -720,11 +732,12 @@ class SiLU(nn.Module):
>>> input = torch.randn(2)
>>> output = m(input)
'''
def __init__(self):
'''
Init method.
'''
super().__init__() # init the base class
super().__init__() # init the base class
def forward(self, input):
'''
@ -735,12 +748,15 @@ class SiLU(nn.Module):
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnRelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True):
super(ConvBnRelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
stride, padding_map[kernel_size], bias=bias)
if norm:
self.bn = nn.BatchNorm2d(filters_out)
else:
@ -753,7 +769,8 @@ class ConvBnRelu(nn.Module):
# Init params.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@ -770,12 +787,15 @@ class ConvBnRelu(nn.Module):
''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnSilu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
super(ConvBnSilu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
stride, padding_map[kernel_size], bias=bias)
if norm:
self.bn = nn.BatchNorm2d(filters_out)
else:
@ -788,7 +808,8 @@ class ConvBnSilu(nn.Module):
# Init params.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
m.weight.data *= weight_init_factor
if m.bias is not None:
m.bias.data.zero_()
@ -808,12 +829,15 @@ class ConvBnSilu(nn.Module):
''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnLelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
super(ConvBnLelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
stride, padding_map[kernel_size], bias=bias)
if norm:
self.bn = nn.BatchNorm2d(filters_out)
else:
@ -847,12 +871,15 @@ class ConvBnLelu(nn.Module):
''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvGnLelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1):
super(ConvGnLelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
stride, padding_map[kernel_size], bias=bias)
if norm:
self.gn = nn.GroupNorm(num_groups, filters_out)
else:
@ -886,12 +913,15 @@ class ConvGnLelu(nn.Module):
''' Convenience class with Conv->GroupNorm->SiLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvGnSilu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d):
super(ConvGnSilu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
self.conv = convnd(filters_in, filters_out, kernel_size,
stride, padding_map[kernel_size], bias=bias)
if norm:
self.gn = nn.GroupNorm(num_groups, filters_out)
else:
@ -904,7 +934,8 @@ class ConvGnSilu(nn.Module):
# Init params.
for m in self.modules():
if isinstance(m, convnd):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear')
m.weight.data *= weight_init_factor
if m.bias is not None:
m.bias.data.zero_()
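# e.g. ConvGnSilu(64, 128) wraps Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
# GroupNorm(8, 128) and (with activation=True) a SiLU nonlinearity, using the
# Kaiming-normal initialization above.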
@ -924,12 +955,15 @@ class ConvGnSilu(nn.Module):
''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard
kernel sizes. '''
class ConvBnRelu(nn.Module):
def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1):
super(ConvBnRelu, self).__init__()
padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
assert kernel_size in padding_map.keys()
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
self.conv = nn.Conv2d(filters_in, filters_out, kernel_size,
stride, padding_map[kernel_size], bias=bias)
if norm:
self.bn = nn.BatchNorm2d(filters_out)
else:
@ -942,7 +976,8 @@ class ConvBnRelu(nn.Module):
# Init params.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear')
m.weight.data *= weight_init_factor
if m.bias is not None:
m.bias.data.zero_()
@ -969,7 +1004,8 @@ class MultiConvBlock(nn.Module):
self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] +
[ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] +
[ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)])
self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float))
self.scale = nn.Parameter(torch.full(
(1,), fill_value=scale_init, dtype=torch.float))
self.bias = nn.Parameter(torch.zeros(1))
def forward(self, x, noise=None):
@ -988,10 +1024,14 @@ class ExpansionBlock(nn.Module):
super(ExpansionBlock, self).__init__()
if filters_out is None:
filters_out = filters_in // 2
self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
self.conjoin = block(filters_out*2, filters_out, kernel_size=3, bias=False, activation=True, norm=False)
self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True)
self.decimate = block(
filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
self.process_passthrough = block(
filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
self.conjoin = block(filters_out*2, filters_out,
kernel_size=3, bias=False, activation=True, norm=False)
self.process = block(filters_out, filters_out,
kernel_size=3, bias=False, activation=True, norm=True)
# input is the feature signal with shape (b, f, w, h)
# passthrough is the structure signal with shape (b, f/2, w*2, h*2)
@ -1012,10 +1052,14 @@ class ExpansionBlock2(nn.Module):
super(ExpansionBlock2, self).__init__()
if filters_out is None:
filters_out = filters_in // 2
self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
self.conjoin = block(filters_out*2, filters_out*2, kernel_size=3, bias=False, activation=True, norm=False)
self.reduce = block(filters_out*2, filters_out, kernel_size=3, bias=False, activation=True, norm=True)
self.decimate = block(
filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True)
self.process_passthrough = block(
filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True)
self.conjoin = block(filters_out*2, filters_out*2,
kernel_size=3, bias=False, activation=True, norm=False)
self.reduce = block(filters_out*2, filters_out,
kernel_size=3, bias=False, activation=True, norm=True)
# input is the feature signal with shape (b, f, w, h)
# passthrough is the structure signal with shape (b, f/2, w*2, h*2)
@ -1036,8 +1080,10 @@ class ConjoinBlock(nn.Module):
filters_out = filters_in
if filters_pt is None:
filters_pt = filters_in
self.process = block(filters_in + filters_pt, filters_in + filters_pt, kernel_size=3, bias=False, activation=True, norm=norm)
self.decimate = block(filters_in + filters_pt, filters_out, kernel_size=1, bias=False, activation=False, norm=norm)
self.process = block(filters_in + filters_pt, filters_in + filters_pt,
kernel_size=3, bias=False, activation=True, norm=norm)
self.decimate = block(filters_in + filters_pt, filters_out,
kernel_size=1, bias=False, activation=False, norm=norm)
def forward(self, input, passthrough):
x = torch.cat([input, passthrough], dim=1)
@ -1053,7 +1099,8 @@ class ReferenceJoinBlock(nn.Module):
scale_init=residual_weight_init_factor, norm=False,
weight_init_factor=residual_weight_init_factor)
if join:
self.join_conv = block(nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
self.join_conv = block(
nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
else:
self.join_conv = None
@ -1070,7 +1117,8 @@ class ReferenceJoinBlock(nn.Module):
class UpconvBlock(nn.Module):
def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False):
super(UpconvBlock, self).__init__()
self.process = block(filters_in, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm)
self.process = block(filters_in, filters_out, kernel_size=3,
bias=bias, activation=activation, norm=norm)
def forward(self, x):
x = F.interpolate(x, scale_factor=2, mode="nearest")
@ -1083,21 +1131,29 @@ class FinalUpsampleBlock2x(nn.Module):
super(FinalUpsampleBlock2x, self).__init__()
if scale == 2:
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
UpconvBlock(
nf, nf // 2, block=block, norm=False, activation=True, bias=True),
block(nf // 2, nf // 2, kernel_size=3,
norm=False, activation=False, bias=True),
block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
else:
self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True),
UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True),
block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True),
UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True),
block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True),
UpconvBlock(
nf, nf, block=block, norm=False, activation=True, bias=True),
block(nf, nf, kernel_size=3, norm=False,
activation=False, bias=True),
UpconvBlock(
nf, nf // 2, block=block, norm=False, activation=True, bias=True),
block(nf // 2, nf // 2, kernel_size=3,
norm=False, activation=False, bias=True),
block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False))
def forward(self, x):
return self.chain(x)
# torch.gather() variant that behaves the way one would always expect it to: it pulls indexes from the input.
def gather_2d(input, index):
b, c, h, w = input.shape
nodim = input.view(b, c, h * w)
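The body of gather_2d is cut off by the diff context here. Purely as an illustrative sketch of the idea (not the repository's actual implementation), pulling one (y, x) location per batch element out of a (b, c, h, w) tensor with torch.gather looks roughly like this:

import torch

def gather_2d_sketch(feature_map, yx_index):
    # feature_map: (b, c, h, w); yx_index: (b, 2) integer (y, x) locations.
    b, c, h, w = feature_map.shape
    flat = feature_map.view(b, c, h * w)                          # (b, c, h*w)
    flat_idx = yx_index[:, 0] * w + yx_index[:, 1]                # (b,)
    flat_idx = flat_idx.view(b, 1, 1).expand(-1, c, -1)           # (b, c, 1)
    return torch.gather(flat, dim=2, index=flat_idx).squeeze(-1)  # (b, c)

print(gather_2d_sketch(torch.randn(4, 16, 8, 8), torch.randint(0, 8, (4, 2))).shape)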

@ -3,13 +3,14 @@ from itertools import groupby
import torch
import torch.nn as nn
from transformers import Wav2Vec2ForCTC
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention, Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Attention,
Wav2Vec2Model)
from data.audio.unsupervised_audio_dataset import load_audio
from models.audio.tts.tacotron2.text import sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.data.audio.unsupervised_audio_dataset import load_audio
from dlas.models.audio.tts.tacotron2.text import sequence_to_text
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
def only_letters(string):
@ -22,6 +23,7 @@ class Wav2VecFeatureExtractor(nn.Module):
Basic wrapper that only does feature extraction. Useful to build out this portion of the model so it can be
operated through DDP.
"""
def __init__(self, basis_model='facebook/wav2vec2-large'):
super().__init__()
w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
@ -34,7 +36,8 @@ class Wav2VecFeatureExtractor(nn.Module):
def forward(self, audio, wav_lengths):
with torch.no_grad():
audio = audio[:, :, :wav_lengths.max()]
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
audio_norm = (audio - audio.mean()) / \
torch.sqrt(audio.var() + 1e-7)
return self.extractor(audio_norm.squeeze(1))
@ -42,13 +45,14 @@ class Wav2VecWrapper(nn.Module):
"""
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
"""
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1,
ramp_dropout_max=.5, layer_drop_pct=.1):
super().__init__()
self.provide_attention_mask = provide_attention_mask
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
# Perform some surgery to get the model we actually want.
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
@ -99,7 +103,8 @@ class Wav2VecWrapper(nn.Module):
unaligned_tokens[b, text_lengths[b]:] = -100
model_inp = fea_extractor if self.remove_feature_extractor else audio
outputs = self.w2v(input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)
outputs = self.w2v(
input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)
if self.output_wer:
self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
@ -126,11 +131,15 @@ class Wav2VecWrapper(nn.Module):
pred_strings = []
for last_labels, last_pred in zip(self.last_labels, self.last_pred):
last_labels[last_labels == -100] = 0
label_strings.extend([only_letters(sequence_to_text(lbl)) for lbl in last_labels])
pred_strings.extend([only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in last_pred])
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
label_strings.extend(
[only_letters(sequence_to_text(lbl)) for lbl in last_labels])
pred_strings.extend([only_letters(sequence_to_text(
self.decode_ctc(pred))) for pred in last_pred])
wer = wer_metric.compute(
predictions=pred_strings, references=label_strings)
res['wer'] = wer
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
print(
f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
if self.ramp_dropout_mode:
res['dropout_rate'] = self.current_dropout_rate
return res
@ -149,7 +158,8 @@ class Wav2VecWrapper(nn.Module):
def update_for_step(self, step, *args):
if self.ramp_dropout_mode and step % 10 == 0:
dropout_gap = self.ramp_dropout_max - self.ramp_dropout_min
new_dropout_rate = self.ramp_dropout_min + dropout_gap * min(step / self.ramp_dropout_end, 1)
new_dropout_rate = self.ramp_dropout_min + \
dropout_gap * min(step / self.ramp_dropout_end, 1)
self.current_dropout_rate = new_dropout_rate
for name, module in self.w2v.named_modules():
if isinstance(module, nn.Dropout):
@ -187,14 +197,18 @@ def register_wav2vec2(opt_net, opt):
if __name__ == '__main__':
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h',
freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True)
w2v.update_for_step(8000)
fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
w2v.get_debug_values(0,"")
fea = fe(torch.randn(2, 1, 50000), torch.tensor([20000, 30000]))
loss = w2v(torch.randn(2, 1, 50000), torch.randint(0, 40, (2, 70)),
torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
w2v.get_debug_values(0, "")
sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
sd = torch.load(
'../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
w2v.load_state_dict(sd)
pred = w2v.inference(load_audio('Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000).unsqueeze(0))
pred = w2v.inference(load_audio(
'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000).unsqueeze(0))
res = sequence_to_text(pred[0])
print(res)
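The dropout ramp in update_for_step above is just a linear interpolation from ramp_dropout_min to ramp_dropout_max over ramp_dropout_end steps; a minimal sketch of the same schedule, using the defaults from the constructor signature shown above:

def ramped_dropout(step, dropout_min=0.1, dropout_max=0.5, ramp_end=20000):
    # Linearly interpolate between the two rates, clamping once ramp_end is reached.
    gap = dropout_max - dropout_min
    return dropout_min + gap * min(step / ramp_end, 1)

print(ramped_dropout(0), ramped_dropout(8000), ramped_dropout(40000))  # approximately 0.1, 0.26, 0.5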

@ -1,12 +1,12 @@
from typing import Any, Callable, List, Optional, Type, Union
import torch
from torch import Tensor
import torch.nn as nn
from torch import Tensor
from trainer.networks import register_model
from utils.util import opt_get
from typing import Type, Any, Callable, Union, List, Optional
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
@ -42,9 +42,11 @@ class BasicBlock(nn.Module):
if norm_layer is None:
norm_layer = nn.BatchNorm1d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
@ -177,7 +179,8 @@ class ResNet(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@ -188,9 +191,11 @@ class ResNet(nn.Module):
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
# type: ignore[arg-type]
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
# type: ignore[arg-type]
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
stride: int = 1, dilate: bool = False) -> nn.Sequential:
@ -386,5 +391,5 @@ def register_audio_resnet(opt_net, opt):
if __name__ == '__main__':
m = resnet34()
o = m(torch.randn((1,1,48000)))
print(o.shape)
o = m(torch.randn((1, 1, 48000)))
print(o.shape)

@ -9,13 +9,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.models.wav2vec2.modeling_wav2vec2 import (
_compute_mask_indices, _sample_negative_indices)
from models.arch_util import ResBlock
from trainer.networks import register_model
from utils.util import checkpoint
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.models.arch_util import ResBlock
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
class Mel2Vec2FeatureProjection(nn.Module):
@ -243,6 +244,8 @@ class Wav2Vec2SamePadLayer(nn.Module):
from torch.nn.utils.weight_norm import WeightNorm
def __deepcopy__(self, memo):
# save and delete all weightnorm weights on self
weights = {}

@ -2,13 +2,13 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
from utils.util import opt_get, ceil_multiple, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import AttentionBlock, ResBlock
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
from dlas.models.lucidrains.x_transformers import Encoder
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, opt_get, print_network
class ConditioningEncoder(nn.Module):
@ -22,23 +22,23 @@ class ConditioningEncoder(nn.Module):
super().__init__()
self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1)
self.attn = Encoder(
dim=embedding_dim,
depth=attn_blocks,
heads=num_attn_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=2,
do_checkpointing=do_checkpointing
)
dim=embedding_dim,
depth=attn_blocks,
heads=num_attn_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=2,
do_checkpointing=do_checkpointing
)
self.dim = embedding_dim
def forward(self, x):
h = self.init(x).permute(0,2,1)
h = self.attn(h).permute(0,2,1)
h = self.init(x).permute(0, 2, 1)
h = self.attn(h).permute(0, 2, 1)
return h.mean(-1)
@ -46,7 +46,7 @@ class ConditioningAR(nn.Module):
def __init__(self, dim, layers, dropout=0, num_vectors=8192, cond_free_percent=.15, fp16=False):
super().__init__()
self.cond_encoder = ConditioningEncoder(256, dim)
self.cond_free_emb = nn.Parameter(torch.randn(1,dim))
self.cond_free_emb = nn.Parameter(torch.randn(1, dim))
self.unconditioned_percentage = cond_free_percent
self.fp16 = fp16
@ -65,13 +65,16 @@ class ConditioningAR(nn.Module):
cond = self.cond_encoder(conditioning)
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((cond.shape[0],1), device=cond.device) < self.unconditioned_percentage
cond = torch.where(unconditioned_batches, self.cond_free_emb.repeat(cond.shape[0],1), cond)
unconditioned_batches = torch.rand(
(cond.shape[0], 1), device=cond.device) < self.unconditioned_percentage
cond = torch.where(unconditioned_batches,
self.cond_free_emb.repeat(cond.shape[0], 1), cond)
unused_params.append(self.cond_free_emb)
h = self.embeddings(cheater_codes)
h = torch.cat([cond.unsqueeze(1), h], dim=1)
targets = cheater_codes # Since we padded above by 1, the input alignment works.
# Since we padded above by 1, the input alignment works.
targets = cheater_codes
with torch.autocast(cheater_codes.device.type, enabled=self.fp16):
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
@ -79,12 +82,13 @@ class ConditioningAR(nn.Module):
if return_latent:
return h.float()
logits = self.head(h[:,:-1]).permute(0,2,1)
logits = self.head(h[:, :-1]).permute(0, 2, 1)
loss = F.cross_entropy(logits, targets, reduction="none")
# Perform masking
if code_lengths is not None:
mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(
0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1)
loss = loss * mask
loss = loss.mean()
@ -114,14 +118,13 @@ def test_ar():
model = ConditioningAR(512, 8, cond_free_percent=.5)
print_network(model)
codes = torch.randint(0,8192, (2,400))
cond = torch.randn(2,256,400)
cl = torch.tensor([200,10])
codes[1,10:] = 2
codes = torch.randint(0, 8192, (2, 400))
cond = torch.randn(2, 256, 400)
cl = torch.tensor([200, 10])
codes[1, 10:] = 2
model(codes, cond, cl)
pg = model.get_grad_norm_parameter_groups()
if __name__ == '__main__':
test_ar()
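The masking applied to the cross-entropy above is the standard arange-versus-lengths pattern; a self-contained sketch with made-up shapes:

import torch

loss = torch.rand(2, 400)                        # per-token losses, (batch, seq)
code_lengths = torch.tensor([200, 10])           # valid length of each sequence

positions = torch.arange(loss.shape[1]).unsqueeze(0)   # (1, seq)
mask = positions < code_lengths.unsqueeze(1)           # (batch, seq) boolean
masked_loss = (loss * mask).mean()
print(mask.sum(dim=1))                           # tensor([200, 10])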

@ -13,17 +13,16 @@
# limitations under the License.
# ==============================================================================
from math import sqrt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml
from math import sqrt
from torch.utils.checkpoint import checkpoint
from trainer.networks import register_model
import dlas.torch_intermediary as ml
from dlas.trainer.networks import register_model
Linear = ml.Linear
ConvTranspose2d = nn.ConvTranspose2d
@ -43,7 +42,8 @@ def silu(x):
class DiffusionEmbedding(nn.Module):
def __init__(self, max_steps):
super().__init__()
self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False)
self.register_buffer('embedding', self._build_embedding(
max_steps), persistent=False)
self.projection1 = Linear(128, 512)
self.projection2 = Linear(512, 512)
@ -76,8 +76,10 @@ class DiffusionEmbedding(nn.Module):
class SpectrogramUpsampler(nn.Module):
def __init__(self, n_mels):
super().__init__()
self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[
1, 16], padding=[1, 8])
self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[
1, 16], padding=[1, 8])
def forward(self, x):
x = torch.unsqueeze(x, 1)
@ -98,27 +100,32 @@ class ResidualBlock(nn.Module):
:param uncond: disable spectrogram conditional
'''
super().__init__()
self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
self.dilated_conv = Conv1d(
residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
self.diffusion_projection = Linear(512, residual_channels)
if not uncond: # conditional model
self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
self.conditioner_projection = Conv1d(
n_mels, 2 * residual_channels, 1)
else: # unconditional model
self.conditioner_projection = None
self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
self.output_projection = Conv1d(
residual_channels, 2 * residual_channels, 1)
def forward(self, x, diffusion_step, conditioner=None):
assert (conditioner is None and self.conditioner_projection is None) or \
(conditioner is not None and self.conditioner_projection is not None)
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
diffusion_step = self.diffusion_projection(
diffusion_step).unsqueeze(-1)
y = x + diffusion_step
if self.conditioner_projection is None: # using an unconditional model
y = self.dilated_conv(y)
else:
y = self.dilated_conv(y)
conditioner = self.conditioner_projection(conditioner)
conditioner = F.interpolate(conditioner, size=y.shape[-1], mode='nearest')
conditioner = F.interpolate(
conditioner, size=y.shape[-1], mode='nearest')
y = y + conditioner
gate, filter = torch.chunk(y, 2, dim=1)
@ -141,7 +148,8 @@ class DiffWave(nn.Module):
self.spectrogram_upsampler = SpectrogramUpsampler(n_mels)
self.residual_layers = nn.ModuleList([
ResidualBlock(n_mels, residual_channels, 2 ** (i % dilation_cycle_length), uncond=unconditional)
ResidualBlock(n_mels, residual_channels, 2 ** (i %
dilation_cycle_length), uncond=unconditional)
for i in range(residual_layers)
])
self.skip_projection = Conv1d(residual_channels, residual_channels, 1)
@ -177,4 +185,5 @@ def register_diffwave(opt_net, opt):
if __name__ == '__main__':
model = DiffWave()
model(torch.randn(2,1,65536), torch.tensor([500,3999]), torch.randn(2,128,256))
model(torch.randn(2, 1, 65536), torch.tensor(
[500, 3999]), torch.randn(2, 128, 256))
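Each transposed convolution in SpectrogramUpsampler stretches the mel time axis by 16x, so the pair expands it by 256x to line mel frames up with waveform samples, which matches the 65536-sample / 256-frame test above. A shape-only check of that, omitting the activations the real forward pass applies:

import torch
import torch.nn as nn

conv1 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])

mel = torch.randn(2, 128, 256)        # (batch, n_mels, frames)
x = conv2(conv1(mel.unsqueeze(1))).squeeze(1)
print(x.shape)                        # torch.Size([2, 128, 65536]) -- 256 frames * 256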

@ -3,10 +3,10 @@ import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from trainer.networks import register_model
from utils.util import opt_get, ceil_multiple, print_network
from dlas.models.arch_util import AttentionBlock, ResBlock
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, opt_get, print_network
class ResEncoder16x(nn.Module):
@ -18,20 +18,29 @@ class ResEncoder16x(nn.Module):
):
super().__init__()
attn = []
def edim(m):
dd = min(spec_dim + m * 128, hidden_dim)
return ceil_multiple(dd, 8)
self.downsampler = nn.Sequential(
ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(4), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1,
down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1,
down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(3), out_channels=edim(3), use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1,
down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(4), out_channels=edim(4), use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
self.encoder = nn.Sequential(
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
nn.GroupNorm(8, hidden_dim),
nn.SiLU(),
nn.Conv1d(hidden_dim, embedding_dim, 1),

@ -4,18 +4,22 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import torch_intermediary as ml
from models.arch_util import ResBlock
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock
from trainer.networks import register_model
from utils.util import checkpoint
import dlas.torch_intermediary as ml
from dlas.models.arch_util import ResBlock
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (AttentionBlock,
TimestepBlock,
TimestepEmbedSequential)
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
def is_latent(t):
return t.dtype == torch.float
def is_sequence(t):
return t.dtype == torch.long
@ -24,7 +28,8 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
# nn.Embedding
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
self.m = nn.ModuleList(
[ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -56,7 +61,8 @@ class TimestepResBlock(TimestepBlock):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
conv_nd(dims, channels, self.out_channels,
eff_kernel, padding=eff_padding),
)
self.emb_layers = nn.Sequential(
@ -71,14 +77,16 @@ class TimestepResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
conv_nd(dims, self.out_channels, self.out_channels,
kernel_size, padding=padding)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
def forward(self, x, emb):
"""
@ -111,8 +119,10 @@ class TimestepResBlock(TimestepBlock):
class DiffusionLayer(TimestepBlock):
def __init__(self, model_channels, dropout, num_heads):
super().__init__()
self.resblk = TimestepResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
self.resblk = TimestepResBlock(
model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
self.attn = AttentionBlock(
model_channels, num_heads, relative_pos_embeddings=True)
def forward(self, x, time_emb):
y = self.resblk(x, time_emb)
@ -134,7 +144,8 @@ class FlatDiffusion(nn.Module):
num_heads=8,
# Parameters for regularization.
layer_drop=.1,
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
train_mel_head=False,
):
super().__init__()
@ -163,37 +174,52 @@ class FlatDiffusion(nn.Module):
# nn.Embedding
self.embeddings = ml.Embedding(token_count, model_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.embeddings = MultiGroupEmbedding(
token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential(
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
)
self.code_converter = nn.Sequential(
ResBlock(dims=1, channels=model_channels, dropout=dropout),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
ResBlock(dims=1, channels=model_channels, dropout=dropout),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
ResBlock(dims=1, channels=model_channels, dropout=dropout),
)
self.code_norm = normalization(model_channels)
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
nn.Conv1d(
model_channels, model_channels*2, 3, padding=1, stride=2),
AttentionBlock(
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, model_channels, 1))
self.conditioning_timestep_integrator = TimestepEmbedSequential(
DiffusionLayer(model_channels, dropout, num_heads),
DiffusionLayer(model_channels, dropout, num_heads),
DiffusionLayer(model_channels, dropout, num_heads),
)
self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.integrating_conv = nn.Conv1d(
model_channels*2, model_channels, kernel_size=1)
self.mel_head = nn.Conv1d(
model_channels, in_channels, kernel_size=3, padding=1)
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
[TimestepResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
@ -201,7 +227,8 @@ class FlatDiffusion(nn.Module):
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
zero_module(conv_nd(1, model_channels,
out_channels, 3, padding=1)),
)
if train_mel_head:
@ -231,7 +258,8 @@ class FlatDiffusion(nn.Module):
conditioning_input.shape) == 3 else conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
conds.append(self.contextual_embedder(
speech_conditioning_input[:, j]))
conds = torch.cat(conds, dim=-1)
cond_emb = conds.mean(dim=-1)
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
@ -239,18 +267,21 @@ class FlatDiffusion(nn.Module):
code_emb = self.latent_conditioner(aligned_conditioning)
else:
code_emb = self.embeddings(aligned_conditioning)
code_emb = code_emb.permute(0,2,1)
code_emb = code_emb.permute(0, 2, 1)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
unconditioned_batches = torch.zeros(
(code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
code_emb)
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
expanded_code_emb = F.interpolate(
code_emb, size=expected_seq_len, mode='nearest')
expanded_code_emb = self.code_converter(expanded_code_emb)
expanded_code_emb = self.code_norm(expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
expanded_code_emb = self.code_norm(
expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
if not return_code_pred:
return expanded_code_emb
@ -260,7 +291,6 @@ class FlatDiffusion(nn.Module):
mel_pred = mel_pred * unconditioned_batches.logical_not()
return expanded_code_emb, mel_pred
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
"""
Apply the model to an input batch.
@ -273,27 +303,36 @@ class FlatDiffusion(nn.Module):
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs.
"""
assert precomputed_aligned_embeddings is not None or (codes is not None and conditioning_input is not None)
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
assert precomputed_aligned_embeddings is not None or (
codes is not None and conditioning_input is not None)
# These two are mutually exclusive.
assert not (
return_code_pred and precomputed_aligned_embeddings is not None)
unused_params = []
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
code_emb = self.unconditioned_embedding.repeat(
x.shape[0], 1, x.shape[-1])
unused_params.extend(
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
unused_params.extend(list(self.latent_conditioner.parameters()))
else:
if precomputed_aligned_embeddings is not None:
code_emb = precomputed_aligned_embeddings
else:
code_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], True)
code_emb, mel_pred = self.timestep_independent(
codes, conditioning_input, x.shape[-1], True)
if is_latent(codes):
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
unused_params.extend(
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
else:
unused_params.extend(list(self.latent_conditioner.parameters()))
unused_params.extend(
list(self.latent_conditioner.parameters()))
unused_params.append(self.unconditioned_embedding)
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
time_emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels))
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
x = self.inp_block(x)
x = torch.cat([x, code_emb], dim=1)
@ -325,10 +364,12 @@ class FlatDiffusion(nn.Module):
conditioning_input.shape) == 3 else conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
conds.append(self.contextual_embedder(
speech_conditioning_input[:, j]))
conds = torch.cat(conds, dim=-1)
return conds.mean(dim=-1)
@register_model
def register_flat_diffusion(opt_net, opt):
return FlatDiffusion(**opt_net['kwargs'])
@ -336,13 +377,13 @@ def register_flat_diffusion(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 256, 400)
aligned_latent = torch.randn(2,388,512)
aligned_sequence = torch.randint(0,8,(2,100,8))
aligned_latent = torch.randn(2, 388, 512)
aligned_sequence = torch.randint(0, 8, (2, 100, 8))
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True)
model = FlatDiffusion(
512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True)
# Test with latent aligned conditioning
#o = model(clip, ts, aligned_latent, cond)
# o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence, cond, return_code_pred=True)
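The unconditioned_percentage mechanism above drops the conditioning branch for a random subset of batch elements during training, the same trick used for classifier-free guidance; a standalone sketch of that masking with illustrative shapes:

import torch

batch, channels, seq = 4, 512, 100
code_emb = torch.randn(batch, channels, seq)
uncond_emb = torch.randn(1, channels, 1)      # stands in for the learned unconditioned embedding
p_uncond = 0.1

drop = torch.rand(batch, 1, 1) < p_uncond     # choose whole batch elements to un-condition
code_emb = torch.where(drop, uncond_emb.expand(batch, channels, seq), code_emb)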

@ -1,17 +1,17 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock
from models.audio.music.music_quantizer import MusicQuantizer
from models.audio.music.music_quantizer2 import MusicQuantizer2
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.lucidrains.x_transformers import Encoder
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import opt_get, checkpoint, ceil_multiple, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import AttentionBlock, ResBlock
from dlas.models.audio.music.music_quantizer import MusicQuantizer
from dlas.models.audio.music.music_quantizer2 import MusicQuantizer2
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
from dlas.models.lucidrains.x_transformers import Encoder
from dlas.models.vqvae.vqvae import Quantize
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, checkpoint, opt_get, print_network
class ConditioningEncoder(nn.Module):
@ -22,9 +22,11 @@ class ConditioningEncoder(nn.Module):
num_attn_heads=4):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=3, stride=2, padding=1)
self.init = nn.Conv1d(spec_dim, embedding_dim,
kernel_size=3, stride=2, padding=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_activation=True))
attn.append(AttentionBlock(embedding_dim,
num_attn_heads, do_activation=True))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
@ -43,14 +45,20 @@ class UpperConditioningEncoder(nn.Module):
super().__init__()
attn = []
self.init = nn.Sequential(nn.Conv1d(spec_dim, min(spec_dim+128, embedding_dim), kernel_size=3, stride=2, padding=1),
nn.Conv1d(min(spec_dim+128, embedding_dim), min(spec_dim+256, embedding_dim), kernel_size=3, stride=2, padding=1),
nn.Conv1d(min(spec_dim+256, embedding_dim), min(spec_dim+384, embedding_dim), kernel_size=3, stride=2, padding=1),
nn.Conv1d(min(spec_dim+384, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
ResBlock(min(spec_dim+512, embedding_dim), dims=1),
nn.Conv1d(min(spec_dim+512, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
nn.Conv1d(min(spec_dim+128, embedding_dim), min(
spec_dim+256, embedding_dim), kernel_size=3, stride=2, padding=1),
nn.Conv1d(min(spec_dim+256, embedding_dim), min(
spec_dim+384, embedding_dim), kernel_size=3, stride=2, padding=1),
nn.Conv1d(min(spec_dim+384, embedding_dim), min(
spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
ResBlock(
min(spec_dim+512, embedding_dim), dims=1),
nn.Conv1d(min(spec_dim+512, embedding_dim), min(
spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1),
ResBlock(min(spec_dim+512, embedding_dim), dims=1))
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_activation=True))
attn.append(AttentionBlock(embedding_dim,
num_attn_heads, do_activation=True))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
@ -67,21 +75,31 @@ class UpperQuantizer(nn.Module):
num_tokens):
super().__init__()
attn = []
def edim(m):
dd = max(embedding_dim//m, 128, spec_dim)
return ceil_multiple(dd, 8)
self.encoder = nn.Sequential(
ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True),
ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True),
ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True),
ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True),
ResBlock(spec_dim, out_channels=edim(6),
use_conv=True, dims=1, down=True),
ResBlock(edim(6), out_channels=edim(5),
use_conv=True, dims=1, down=True),
ResBlock(edim(5), out_channels=edim(4),
use_conv=True, dims=1, down=True),
ResBlock(edim(4), out_channels=edim(3),
use_conv=True, dims=1, down=True),
ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1),
ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True),
ResBlock(edim(3), out_channels=edim(2),
use_conv=True, dims=1, down=True),
ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1),
ResBlock(edim(2), out_channels=embedding_dim, use_conv=True, dims=1, down=True),
ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
ResBlock(edim(2), out_channels=embedding_dim,
use_conv=True, dims=1, down=True),
ResBlock(embedding_dim, out_channels=embedding_dim,
use_conv=True, dims=1),
ResBlock(embedding_dim, out_channels=embedding_dim,
use_conv=True, dims=1),
ResBlock(embedding_dim, out_channels=embedding_dim,
use_conv=True, dims=1),
nn.GroupNorm(8, embedding_dim)
)
self.quantizer = Quantize(embedding_dim, num_tokens)
@ -95,7 +113,7 @@ class UpperQuantizer(nn.Module):
h = x
for lyr in self.encoder:
h = lyr(h)
h = h.permute(0,2,1)
h = h.permute(0, 2, 1)
h_quant, commitment_loss, codes = self.quantizer(h)
self.log_codes(codes)
return h_quant, commitment_loss
@ -105,7 +123,8 @@ class UpperQuantizer(nn.Module):
if self.internal_step % 10 == 0:
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
i = self.code_ind if (
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
@ -122,7 +141,8 @@ class GptMusicLower(nn.Module):
self.freeze_upper_until = freeze_upper_until
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
self.target_quantizers = nn.ModuleList(
[DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
self.upper_quantizer = UpperQuantizer(256, dim, num_upper_vectors)
self.fp16 = fp16
self.internal_step = 0
@ -132,14 +152,17 @@ class GptMusicLower(nn.Module):
p.DO_NOT_TRAIN = True
p.requires_grad = False
self.conditioning_encoder = ConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)
self.conditioning_encoder = ConditioningEncoder(
256, dim, attn_blocks=4, num_attn_heads=dim//64)
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
self.embeddings = nn.ModuleList(
[ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList(
[ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
def forward(self, mel, conditioning, return_latent=False):
unused_params = []
@ -159,14 +182,17 @@ class GptMusicLower(nn.Module):
unused_params.extend(list(self.upper_quantizer.parameters()))
else:
self.upper_quantizer = self.upper_quantizer.train()
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
upper_vector = F.interpolate(upper_vector.permute(0,2,1), size=codes.shape[1], mode='linear')
upper_vector = upper_vector.permute(0,2,1)
upper_vector, upper_diversity = self.upper_quantizer(
mel, return_decoder_latent=True)
upper_vector = F.interpolate(upper_vector.permute(
0, 2, 1), size=codes.shape[1], mode='linear')
upper_vector = upper_vector.permute(0, 2, 1)
inputs = codes[:, :-1]
targets = codes
upper_vector = upper_vector[:, :-1]
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
h = [embedding(inputs[:, :, i])
for i, embedding in enumerate(self.embeddings)]
h = torch.cat(h, dim=-1) + upper_vector
with torch.autocast(mel.device.type, enabled=self.fp16):
@ -183,8 +209,8 @@ class GptMusicLower(nn.Module):
losses = 0
for i, head in enumerate(self.heads):
logits = head(h).permute(0,2,1)
loss = F.cross_entropy(logits, targets[:,:,i])
logits = head(h).permute(0, 2, 1)
loss = F.cross_entropy(logits, targets[:, :, i])
losses = losses + loss
unused_adder = 0
@ -221,11 +247,15 @@ class GptMusicUpper(nn.Module):
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
use_cache=False)
self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim,
max(512,dim-128),
max(512,dim-256),
max(512,dim-384),
max(512,dim-512),
max(512,dim-512)], codevector_dim=dim,
max(512,
dim-128),
max(512,
dim-256),
max(512,
dim-384),
max(512,
dim-512),
max(512, dim-512)], codevector_dim=dim,
codebook_size=num_upper_vectors, codebook_groups=num_upper_groups,
expressive_downsamples=True)
# The following are unused quantizer constructs; we delete them to avoid DDP errors (and, of course, for efficiency).
@ -235,15 +265,17 @@ class GptMusicUpper(nn.Module):
p.DO_NOT_TRAIN = True
p.requires_grad = False
self.conditioning_encoder = UpperConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)
self.conditioning_encoder = UpperConditioningEncoder(
256, dim, attn_blocks=4, num_attn_heads=dim//64)
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
self.embeddings = nn.ModuleList([ml.Embedding(
num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)])
self.heads = nn.ModuleList(
[ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])
def forward(self, mel, conditioning, return_latent=False):
with torch.no_grad():
@ -252,7 +284,8 @@ class GptMusicUpper(nn.Module):
inputs = codes[:, :-1]
targets = codes
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
h = [embedding(inputs[:, :, i])
for i, embedding in enumerate(self.embeddings)]
h = torch.cat(h, dim=-1)
with torch.autocast(mel.device.type, enabled=self.fp16):
@ -269,8 +302,8 @@ class GptMusicUpper(nn.Module):
losses = 0
for i, head in enumerate(self.heads):
logits = head(h).permute(0,2,1)
loss = F.cross_entropy(logits, targets[:,:,i])
logits = head(h).permute(0, 2, 1)
loss = F.cross_entropy(logits, targets[:, :, i])
losses = losses + loss
return losses / self.num_groups
@ -293,6 +326,7 @@ class GptMusicUpper(nn.Module):
def register_music_gpt_lower(opt_net, opt):
return GptMusicLower(**opt_get(opt_net, ['kwargs'], {}))
@register_model
def register_music_gpt_upper(opt_net, opt):
return GptMusicUpper(**opt_get(opt_net, ['kwargs'], {}))
@ -301,11 +335,11 @@ def register_music_gpt_upper(opt_net, opt):
def test_lower():
model = GptMusicLower(dim=512, layers=12, fp16=False, freeze_upper_until=1000,
num_target_vectors=8192, num_upper_vectors=8192, num_vaes=4,
vqargs= {
'positional_dims': 1, 'channels': 64,
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
})
vqargs={
'positional_dims': 1, 'channels': 64,
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
})
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
@ -316,7 +350,7 @@ def test_lower():
torch.save(model.state_dict(), 'sample.pth')
print_network(model)
mel = torch.randn(2,256,400)
mel = torch.randn(2, 256, 400)
model(mel, mel)
pg = model.get_grad_norm_parameter_groups()
@ -335,14 +369,15 @@ def test_lower():
def test_upper():
lower = GptMusicLower(512, 12)
lower.load_state_dict(torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth'))
lower.load_state_dict(torch.load(
'D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth'))
model = GptMusicUpper(512, 12)
model.upper_quantizer.load_state_dict(lower.upper_quantizer.state_dict())
torch.save(model.state_dict(), 'sample.pth')
mel = torch.randn(2,256,2500)
mel = torch.randn(2, 256, 2500)
model(mel, mel)
model.get_grad_norm_parameter_groups()
if __name__ == '__main__':
test_lower()
test_lower()
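The grouped-code pattern in GptMusicLower/GptMusicUpper (one embedding and one head per code group, the model width split evenly between them, per-group cross-entropies averaged) can be exercised in isolation; plain nn layers stand in for the ml.* wrappers and every shape below is made up:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_groups, vocab = 512, 4, 8192
embeddings = nn.ModuleList([nn.Embedding(vocab, dim // num_groups) for _ in range(num_groups)])
heads = nn.ModuleList([nn.Linear(dim, vocab) for _ in range(num_groups)])

codes = torch.randint(0, vocab, (2, 100, num_groups))         # (batch, seq, groups)
h = torch.cat([emb(codes[:, :, i]) for i, emb in enumerate(embeddings)], dim=-1)

# Pretend h already went through the transformer; score each group independently.
losses = 0
for i, head in enumerate(heads):
    logits = head(h).permute(0, 2, 1)                         # (batch, vocab, seq)
    losses = losses + F.cross_entropy(logits, codes[:, :, i])
print(losses / num_groups)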

@ -2,12 +2,12 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Config, GPT2Model
import torch_intermediary as ml
from models.arch_util import AttentionBlock, ResBlock
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from trainer.networks import register_model
from utils.util import opt_get, ceil_multiple, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import AttentionBlock, ResBlock
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, opt_get, print_network
class UpperEncoder(nn.Module):
@ -19,22 +19,30 @@ class UpperEncoder(nn.Module):
):
super().__init__()
attn = []
def edim(m):
dd = min(spec_dim + m * 128, hidden_dim)
return ceil_multiple(dd, 8)
self.downsampler = nn.Sequential(
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1,
down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1,
down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1,
down=True, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(3), out_channels=edim(4), use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
self.encoder = nn.Sequential(
AttentionBlock(hidden_dim, 4, do_activation=True),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
AttentionBlock(hidden_dim, 4, do_activation=True),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
AttentionBlock(hidden_dim, 4, do_activation=True),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True,
dims=1, checkpointing_enabled=checkpointing_enabled),
nn.GroupNorm(8, hidden_dim),
nn.SiLU(),
nn.Conv1d(hidden_dim, embedding_dim, 1),
@ -47,8 +55,6 @@ class UpperEncoder(nn.Module):
return h
class GptMusicLower(nn.Module):
def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}):
super().__init__()
@ -58,7 +64,8 @@ class GptMusicLower(nn.Module):
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
use_cache=False)
self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
self.target_quantizers = nn.ModuleList(
[DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
self.upper_encoder = UpperEncoder(256, dim, encoder_out_dim)
self.encoder_projector = nn.Conv1d(encoder_out_dim, dim, 1)
self.fp16 = fp16
@ -75,8 +82,10 @@ class GptMusicLower(nn.Module):
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
self.embeddings = nn.ModuleList(
[ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
self.heads = nn.ModuleList(
[ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)])
def forward(self, mel, return_latent=False):
unused_params = []
@ -92,20 +101,23 @@ class GptMusicLower(nn.Module):
upper_vector = self.upper_encoder(mel)
upper_vector = self.encoder_projector(upper_vector)
# A spherical interpolation (slerp) would be preferable to linear here.
upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
upper_vector = upper_vector.permute(0,2,1)
upper_vector = F.interpolate(
upper_vector, size=codes.shape[1], mode='linear')
upper_vector = upper_vector.permute(0, 2, 1)
inputs = codes[:, :-1]
targets = codes
upper_vector = upper_vector[:, :-1]
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
h = [embedding(inputs[:, :, i])
for i, embedding in enumerate(self.embeddings)]
h = torch.cat(h, dim=-1) + upper_vector
with torch.autocast(mel.device.type, enabled=self.fp16):
# Stick the conditioning embedding on the front of the input sequence.
# The transformer will learn how to integrate it.
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
h = torch.cat([self.start_token.repeat(h.shape[0], 1, 1), h], dim=1)
h = torch.cat(
[self.start_token.repeat(h.shape[0], 1, 1), h], dim=1)
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
@ -114,8 +126,8 @@ class GptMusicLower(nn.Module):
losses = 0
for i, head in enumerate(self.heads):
logits = head(h).permute(0,2,1)
loss = F.cross_entropy(logits, targets[:,:,i])
logits = head(h).permute(0, 2, 1)
loss = F.cross_entropy(logits, targets[:, :, i])
losses = losses + loss
unused_adder = 0
@ -143,10 +155,10 @@ def register_music_gpt_lower2(opt_net, opt):
def test_lower():
model = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
vqargs= {'positional_dims': 1, 'channels': 64,
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
})
vqargs={'positional_dims': 1, 'channels': 64,
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
})
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
@ -157,7 +169,7 @@ def test_lower():
torch.save(model.state_dict(), 'sample.pth')
print_network(model)
mel = torch.randn(2,256,400)
mel = torch.randn(2, 256, 400)
model(mel)
pg = model.get_grad_norm_parameter_groups()
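The comment above about the "START" token describes the usual next-token alignment: the inputs are the codes shifted right by one (codes[:, :-1] plus a learned start embedding prepended), while the targets are the unshifted codes. A tiny sketch of just that alignment:

import torch

batch, seq, dim = 2, 6, 8
code_emb = torch.randn(batch, seq - 1, dim)   # embeddings of codes[:, :-1]
start_token = torch.randn(1, 1, dim)          # learned "START" embedding

h = torch.cat([start_token.repeat(batch, 1, 1), code_emb], dim=1)
print(h.shape)  # (2, 6, 8): the hidden state at position i is used to predict code i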

@ -3,13 +3,15 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding
from models.lucidrains.vq import VectorQuantize
from models.lucidrains.x_transformers import FeedForward, Attention, Decoder, RMSScaleShiftNorm
from trainer.networks import register_model
from utils.util import checkpoint
import dlas.torch_intermediary as ml
from dlas.models.diffusion.nn import timestep_embedding
from dlas.models.lucidrains.vq import VectorQuantize
from dlas.models.lucidrains.x_transformers import (Attention, Decoder,
FeedForward,
RMSScaleShiftNorm)
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
class SelfClassifyingHead(nn.Module):
@ -19,7 +21,7 @@ class SelfClassifyingHead(nn.Module):
self.num_classes = classes
self.temperature = init_temperature
self.dec = Decoder(dim=dim, depth=head_depth, heads=4, ff_dropout=dropout, ff_mult=2, attn_dropout=dropout,
use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2,
sample_codebook_temp=init_temperature)
self.to_output = ml.Linear(dim, out_dim)
@ -39,7 +41,8 @@ class SelfClassifyingHead(nn.Module):
codes = []
q_reg = 0
for i in range(self.seq_len):
q, c = checkpoint(functools.partial(self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1))
q, c = checkpoint(functools.partial(
self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1))
q_reg = q_reg + (q ** 2).mean()
s = torch.sigmoid(q)
@ -48,10 +51,13 @@ class SelfClassifyingHead(nn.Module):
# If the addition would strictly make the result worse, set it to 0. Sometimes.
if len(results) > 0:
worsen = (F.mse_loss(outputs[-1], target, reduction='none').sum(-1) < F.mse_loss(output, target, reduction='none').sum(-1)).float()
worsen = (F.mse_loss(outputs[-1], target, reduction='none').sum(-1) < F.mse_loss(
output, target, reduction='none').sum(-1)).float()
probabilistic_worsen = torch.rand_like(worsen) * worsen > .5
output = output * probabilistic_worsen.unsqueeze(-1) # This is non-differentiable, but still deterministic.
c[probabilistic_worsen] = -1 # Code of -1 means the code was unused.
# This is non-differentiable, but still deterministic.
output = output * probabilistic_worsen.unsqueeze(-1)
# Code of -1 means the code was unused.
c[probabilistic_worsen] = -1
s = s * probabilistic_worsen.unsqueeze(-1)
outputs[-1] = s
@ -65,7 +71,8 @@ class VectorResBlock(nn.Module):
def __init__(self, dim, dropout):
super().__init__()
self.norm = nn.BatchNorm1d(dim)
self.ff = FeedForward(dim, mult=2, glu=True, dropout=dropout, zero_init_output=True)
self.ff = FeedForward(dim, mult=2, glu=True,
dropout=dropout, zero_init_output=True)
def forward(self, x):
h = self.norm(x.unsqueeze(-1)).squeeze(-1)
@ -92,8 +99,10 @@ class InstrumentQuantizer(nn.Module):
super().__init__()
self.op_dim = op_dim
self.proj = ml.Linear(op_dim, dim)
self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
self.encoder = nn.ModuleList(
[VectorResBlock(dim, dropout) for _ in range(enc_depth)])
self.heads = SelfClassifyingHead(
dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
self.min_gumbel_temperature = min_temp
self.max_gumbel_temperature = max_temp
self.gumbel_temperature_decay = temp_decay
@ -109,16 +118,19 @@ class InstrumentQuantizer(nn.Module):
x = (x + 1) / 2
b, c, s = x.shape
px = x.permute(0,2,1) # B,S,C shape
px = x.permute(0, 2, 1) # B,S,C shape
f = px.reshape(-1, self.op_dim)
h = self.proj(f)
for lyr in self.encoder:
h = lyr(h)
reconstructions, codes, q_reg = self.heads(h, f)
reconstruction_losses = torch.stack([F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions])
r_follow = torch.arange(1, reconstruction_losses.shape[0]+1, device=x.device)
reconstruction_losses = (reconstruction_losses * r_follow / r_follow.shape[0])
reconstruction_losses = torch.stack(
[F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions])
r_follow = torch.arange(
1, reconstruction_losses.shape[0]+1, device=x.device)
reconstruction_losses = (
reconstruction_losses * r_follow / r_follow.shape[0])
self.log_codes(codes)
return reconstruction_losses, q_reg
@ -126,7 +138,8 @@ class InstrumentQuantizer(nn.Module):
def log_codes(self, codes):
if self.internal_step % 5 == 0:
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
i = self.code_ind if (
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
@ -143,9 +156,9 @@ class InstrumentQuantizer(nn.Module):
def update_for_step(self, step, *args):
self.internal_step = step
self.heads.quantizer._codebook.sample_codebook_temp = max(
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
self.min_gumbel_temperature,
)
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
self.min_gumbel_temperature,
)
def get_grad_norm_parameter_groups(self):
groups = {
@ -162,6 +175,6 @@ def register_instrument_quantizer(opt_net, opt):
if __name__ == '__main__':
inp = torch.randn((4,256,200)).clamp(-1,1)
inp = torch.randn((4, 256, 200)).clamp(-1, 1)
model = InstrumentQuantizer(256, 512, 4096, 8, 3)
model(inp)
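# A minimal sketch (not taken from the repository) of the codebook-temperature annealing
# performed by update_for_step() above: the sampling temperature decays geometrically with
# the training step and is clamped to a floor. The default values below are assumptions.
def anneal_temperature(step: int, max_temp: float = 2.0, min_temp: float = 0.5,
                       decay: float = 0.9995) -> float:
    """Exponential decay from max_temp toward min_temp, never dropping below the floor."""
    return max(max_temp * decay ** step, min_temp)

# Example: anneal_temperature(0) == 2.0; for very large steps the value settles at 0.5.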
@ -2,10 +2,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from models.arch_util import ResBlock, AttentionBlock
from models.audio.music.flat_diffusion import MultiGroupEmbedding
from trainer.networks import register_model
from utils.util import checkpoint
from dlas.models.arch_util import AttentionBlock, ResBlock
from dlas.models.audio.music.flat_diffusion import MultiGroupEmbedding
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
class Code2Mel(nn.Module):
@ -13,27 +13,33 @@ class Code2Mel(nn.Module):
super().__init__()
self.emb = MultiGroupEmbedding(num_tokens, num_groups, base_dim)
self.base_blocks = nn.Sequential(ResBlock(base_dim, dropout, dims=1),
AttentionBlock(base_dim, num_heads=base_dim//64),
AttentionBlock(
base_dim, num_heads=base_dim//64),
ResBlock(base_dim, dropout, dims=1))
l2dim = base_dim-256
self.l2_up_block = nn.Conv1d(base_dim, l2dim, kernel_size=5, padding=2)
self.l2_blocks = nn.Sequential(ResBlock(l2dim, dropout, kernel_size=5, dims=1),
AttentionBlock(l2dim, num_heads=base_dim//64),
ResBlock(l2dim, dropout, kernel_size=5, dims=1),
AttentionBlock(l2dim, num_heads=base_dim//64),
ResBlock(l2dim, dropout, dims=1),
ResBlock(l2dim, dropout, dims=1))
AttentionBlock(
l2dim, num_heads=base_dim//64),
ResBlock(l2dim, dropout,
kernel_size=5, dims=1),
AttentionBlock(
l2dim, num_heads=base_dim//64),
ResBlock(l2dim, dropout, dims=1),
ResBlock(l2dim, dropout, dims=1))
l3dim = l2dim-256
self.l3_up_block = nn.Conv1d(l2dim, l3dim, kernel_size=5, padding=2)
self.l3_blocks = nn.Sequential(ResBlock(l3dim, dropout, kernel_size=5, dims=1),
AttentionBlock(l3dim, num_heads=base_dim//64),
ResBlock(l3dim, dropout, kernel_size=5, dims=1),
AttentionBlock(
l3dim, num_heads=base_dim//64),
ResBlock(l3dim, dropout,
kernel_size=5, dims=1),
ResBlock(l3dim, dropout, dims=1))
self.final_block = nn.Conv1d(l3dim, out_dim, kernel_size=3, padding=1)
def forward(self, codes, target):
with torch.autocast(codes.device.type):
h = self.emb(codes).permute(0,2,1)
h = self.emb(codes).permute(0, 2, 1)
h = checkpoint(self.base_blocks, h)
h = F.interpolate(h, scale_factor=2, mode='linear')
h = self.l2_up_block(h)
@ -52,6 +58,6 @@ def register_code2mel(opt_net, opt):
if __name__ == '__main__':
model = Code2Mel()
codes = torch.randint(0,16, (2,200,4))
target = torch.randn(2,256,804)
model(codes, target)
codes = torch.randint(0, 16, (2, 200, 4))
target = torch.randn(2, 256, 804)
model(codes, target)
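# A minimal sketch of the decode pattern used in Code2Mel.forward above: each stage doubles
# the temporal resolution with linear interpolation, reduces the channel count with a 1-D
# convolution, and refines the result with a residual/attention stack. Names are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

def upsample_stage(h: torch.Tensor, up_conv: nn.Conv1d, blocks: nn.Module) -> torch.Tensor:
    h = F.interpolate(h, scale_factor=2, mode='linear')  # double the time axis
    h = up_conv(h)                                        # reduce the channel count
    return blocks(h)                                      # ResBlock/AttentionBlock refinement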
@ -1,11 +1,11 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch_intermediary as ml
from torch import nn
from transformers import GPT2Config, GPT2Model
from trainer.networks import register_model
from utils.util import opt_get
import dlas.torch_intermediary as ml
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
class Mel2VecCodesGpt(nn.Module):
@ -19,8 +19,10 @@ class Mel2VecCodesGpt(nn.Module):
self.gpt = GPT2Model(self.config)
del self.gpt.wte # Unused, we'll do our own embeddings.
# nn.Embedding
self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)])
self.embeddings = nn.ModuleList(
[ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
self.heads = nn.ModuleList(
[ml.Linear(dim, num_vectors) for _ in range(num_groups)])
def forward(self, codes):
assert codes.shape[-1] == self.num_groups
@ -28,14 +30,15 @@ class Mel2VecCodesGpt(nn.Module):
inputs = codes[:, :-1]
targets = codes[:, 1:]
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
h = [embedding(inputs[:, :, i])
for i, embedding in enumerate(self.embeddings)]
h = torch.cat(h, dim=-1)
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
losses = 0
for i, head in enumerate(self.heads):
logits = head(h).permute(0,2,1)
loss = F.cross_entropy(logits, targets[:,:,i])
logits = head(h).permute(0, 2, 1)
loss = F.cross_entropy(logits, targets[:, :, i])
losses = losses + loss
return losses / self.num_groups
@ -48,5 +51,5 @@ def register_music_gpt(opt_net, opt):
if __name__ == '__main__':
model = Mel2VecCodesGpt(512, 8)
codes = torch.randint(0,8, (2,300,8))
model(codes)
codes = torch.randint(0, 8, (2, 300, 8))
model(codes)
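# A minimal sketch of the multi-group code loss computed in Mel2VecCodesGpt.forward above:
# one linear head per code group produces logits, cross-entropy is taken per group against
# that group's target codes, and the losses are averaged. Shapes here are assumptions.
import torch.nn.functional as F

def grouped_code_loss(hidden, heads, targets):
    # hidden: (B, S, D) transformer outputs; heads: one Linear(D, num_vectors) per group;
    # targets: (B, S, G) integer codes, one column per group.
    losses = 0
    for i, head in enumerate(heads):
        logits = head(hidden).permute(0, 2, 1)  # (B, num_vectors, S) for cross_entropy
        losses = losses + F.cross_entropy(logits, targets[:, :, i])
    return losses / len(heads)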
@ -1,14 +1,14 @@
import functools
import torch
from torch import nn
import torch.nn.functional as F
import torch_intermediary as ml
from torch import nn
from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import checkpoint, ceil_multiple, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import zero_module
from dlas.models.vqvae.vqvae import Quantize
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, checkpoint, print_network
class Downsample(nn.Module):
@ -37,13 +37,13 @@ class ResBlock(nn.Module):
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, 3, padding = 1),
nn.Conv1d(chan, chan, 3, padding=1),
nn.GroupNorm(8, chan),
nn.SiLU(),
nn.Conv1d(chan, chan, 3, padding = 1),
nn.Conv1d(chan, chan, 3, padding=1),
nn.GroupNorm(8, chan),
nn.SiLU(),
zero_module(nn.Conv1d(chan, chan, 3, padding = 1)),
zero_module(nn.Conv1d(chan, chan, 3, padding=1)),
)
def forward(self, x):
@ -74,7 +74,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
# storage for codebook variables (codewords)
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
torch.FloatTensor(1, self.num_groups * self.num_vars,
codevector_dim // self.num_groups)
)
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
@ -95,7 +96,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
else:
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
perplexity = torch.exp(-torch.sum(marginal_probs *
torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
def get_codes(self, hidden_states):
@ -103,9 +105,11 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
hidden_states = hidden_states.view(
batch_size * sequence_length * self.num_groups, -1)
codevector_idx = hidden_states.argmax(dim=-1)
idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups)
idxs = codevector_idx.view(
batch_size, sequence_length, self.num_groups)
return idxs
def forward(self, hidden_states, mask_time_indices=None, return_probs=False):
@ -113,7 +117,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
hidden_states = hidden_states.view(
batch_size * sequence_length * self.num_groups, -1)
if self.training:
# sample code vector probs via gumbel in differentiable way
@ -125,7 +130,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
codevector_soft_dist = torch.softmax(
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
perplexity = self._compute_perplexity(
codevector_soft_dist, mask_time_indices)
else:
# take argmax in non-differentiable way
# compute hard codevector distribution (one hot)
@ -133,15 +139,20 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0
)
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
codevector_probs = codevector_probs.view(
batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
perplexity = self._compute_perplexity(
codevector_probs, mask_time_indices)
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
codevector_probs = codevector_probs.view(
batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors_per_group = codevector_probs.unsqueeze(
-1) * self.codevectors
codevectors = (
codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors_per_group.view(
batch_size * sequence_length, self.num_groups, self.num_vars, -1)
.sum(-2)
.view(batch_size, sequence_length, -1)
)
@ -164,7 +175,8 @@ class MusicQuantizer(nn.Module):
self.use_vqvae_quantizer = use_vqvae_quantizer
if use_vqvae_quantizer:
self.quantizer = Quantize(inner_dim[0], codebook_size)
assert codevector_dim == inner_dim[0] # Because this quantizer doesn't support different sizes.
# Because this quantizer doesn't support different sizes.
assert codevector_dim == inner_dim[0]
else:
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim[0], codevector_dim=codevector_dim,
num_codevector_groups=codebook_groups,
@ -174,8 +186,10 @@ class MusicQuantizer(nn.Module):
self.num_losses_record = []
if down_steps == 0:
self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1)
self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1)
self.down = nn.Conv1d(
inp_channels, inner_dim[0], kernel_size=3, padding=1)
self.up = nn.Conv1d(
inner_dim[0], inp_channels, kernel_size=3, padding=1)
elif down_steps == 2:
self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1),
Downsample(inner_dim[-1], inner_dim[-2]),
@ -201,25 +215,27 @@ class MusicQuantizer(nn.Module):
def get_codes(self, mel):
h = self.down(mel)
h = self.encoder(h)
h = self.enc_norm(h.permute(0,2,1))
h = self.enc_norm(h.permute(0, 2, 1))
return self.quantizer.get_codes(h)
def forward(self, mel, return_decoder_latent=False):
orig_mel = mel
cm = ceil_multiple(mel.shape[-1], 4)
if cm != 0:
mel = F.pad(mel, (0,cm-mel.shape[-1]))
mel = F.pad(mel, (0, cm-mel.shape[-1]))
h = self.down(mel)
h = self.encoder(h)
h = self.enc_norm(h.permute(0,2,1))
h = self.enc_norm(h.permute(0, 2, 1))
if self.use_vqvae_quantizer:
codevectors, diversity, codes = self.quantizer(h)
else:
codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
codevectors, perplexity, codes = self.quantizer(
h, return_probs=True)
diversity = (self.quantizer.num_codevectors -
perplexity) / self.quantizer.num_codevectors
self.log_codes(codes)
h = self.decoder(codevectors.permute(0,2,1))
h = self.decoder(codevectors.permute(0, 2, 1))
if return_decoder_latent:
return h, diversity
@ -233,13 +249,14 @@ class MusicQuantizer(nn.Module):
if self.internal_step % 5 == 0:
if not self.use_vqvae_quantizer:
codes = torch.argmax(codes, dim=-1)
ccodes = codes[:,:,0]
for j in range(1,codes.shape[-1]):
ccodes += codes[:,:,j] * self.codebook_size ** j
ccodes = codes[:, :, 0]
for j in range(1, codes.shape[-1]):
ccodes += codes[:, :, j] * self.codebook_size ** j
codes = ccodes
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
i = self.code_ind if (
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
@ -259,7 +276,8 @@ def register_music_quantizer(opt_net, opt):
if __name__ == '__main__':
model = MusicQuantizer(inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=8192, codebook_groups=0, use_vqvae_quantizer=True)
model = MusicQuantizer(inner_dim=[1024, 1024, 512], codevector_dim=1024,
codebook_size=8192, codebook_groups=0, use_vqvae_quantizer=True)
print_network(model)
mel = torch.randn((2,256,782))
model(mel)
mel = torch.randn((2, 256, 782))
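# A minimal sketch of the diversity term used by MusicQuantizer above: the perplexity of the
# codevector distribution is compared against the total number of codevectors, giving a loss
# that shrinks as codebook usage becomes more uniform. The function name is illustrative.
import torch

def diversity_from_probs(probs):
    # probs: (N, groups, num_vars) soft or one-hot distributions over codevectors.
    marginal = probs.mean(dim=0)                                        # mean usage per group
    perplexity = torch.exp(-(marginal * (marginal + 1e-7).log()).sum(-1)).sum()
    num_codevectors = probs.shape[1] * probs.shape[2]                   # groups * num_vars
    return (num_codevectors - perplexity) / num_codevectors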

@ -1,14 +1,14 @@
import functools
import torch
from torch import nn
import torch.nn.functional as F
import torch_intermediary as ml
from torch import nn
from models.arch_util import zero_module
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import checkpoint, ceil_multiple, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import zero_module
from dlas.models.vqvae.vqvae import Quantize
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, checkpoint, print_network
class Downsample(nn.Module):
@ -16,7 +16,8 @@ class Downsample(nn.Module):
super().__init__()
self.interpolate = not stride_down
if stride_down:
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1, stride=2)
self.conv = nn.Conv1d(
chan_in, chan_out, kernel_size=3, padding=1, stride=2)
else:
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1)
if norm:
@ -49,13 +50,13 @@ class ResBlock(nn.Module):
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, 3, padding = 1),
nn.Conv1d(chan, chan, 3, padding=1),
nn.GroupNorm(8, chan),
nn.SiLU(),
nn.Conv1d(chan, chan, 3, padding = 1),
nn.Conv1d(chan, chan, 3, padding=1),
nn.GroupNorm(8, chan),
nn.SiLU(),
zero_module(nn.Conv1d(chan, chan, 3, padding = 1)),
zero_module(nn.Conv1d(chan, chan, 3, padding=1)),
)
def forward(self, x):
@ -86,7 +87,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
# storage for codebook variables (codewords)
self.codevectors = nn.Parameter(
torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
torch.FloatTensor(1, self.num_groups * self.num_vars,
codevector_dim // self.num_groups)
)
self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars)
@ -107,7 +109,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
else:
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
perplexity = torch.exp(-torch.sum(marginal_probs *
torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
def get_codes(self, hidden_states):
@ -115,9 +118,11 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
hidden_states = hidden_states.view(
batch_size * sequence_length * self.num_groups, -1)
codevector_idx = hidden_states.argmax(dim=-1)
idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups)
idxs = codevector_idx.view(
batch_size, sequence_length, self.num_groups)
return idxs
def forward(self, hidden_states, mask_time_indices=None, return_probs=False):
@ -125,7 +130,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
hidden_states = hidden_states.view(
batch_size * sequence_length * self.num_groups, -1)
if self.training:
# sample code vector probs via gumbel in differentiable way
@ -137,7 +143,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
codevector_soft_dist = torch.softmax(
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
perplexity = self._compute_perplexity(
codevector_soft_dist, mask_time_indices)
else:
# take argmax in non-differentiable way
# compute hard codevector distribution (one hot)
@ -145,15 +152,20 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0
)
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
codevector_probs = codevector_probs.view(
batch_size * sequence_length, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
perplexity = self._compute_perplexity(
codevector_probs, mask_time_indices)
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
codevector_probs = codevector_probs.view(
batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors_per_group = codevector_probs.unsqueeze(
-1) * self.codevectors
codevectors = (
codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
codevectors_per_group.view(
batch_size * sequence_length, self.num_groups, self.num_vars, -1)
.sum(-2)
.view(batch_size, sequence_length, -1)
)
@ -183,12 +195,14 @@ class MusicQuantizer2(nn.Module):
self.num_losses_record = []
if down_steps == 0:
self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1)
self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1)
self.down = nn.Conv1d(
inp_channels, inner_dim[0], kernel_size=3, padding=1)
self.up = nn.Conv1d(
inner_dim[0], inp_channels, kernel_size=3, padding=1)
elif down_steps == 2:
self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1),
*[Downsample(inner_dim[-i], inner_dim[-i-1], norm=expressive_downsamples, act=expressive_downsamples,
stride_down=expressive_downsamples) for i in range(1,len(inner_dim))])
stride_down=expressive_downsamples) for i in range(1, len(inner_dim))])
self.up = nn.Sequential(*[Upsample(inner_dim[i], inner_dim[i+1]) for i in range(len(inner_dim)-1)] +
[nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1)])
@ -209,22 +223,23 @@ class MusicQuantizer2(nn.Module):
def get_codes(self, mel):
h = self.down(mel)
h = self.encoder(h)
h = self.enc_norm(h.permute(0,2,1))
h = self.enc_norm(h.permute(0, 2, 1))
return self.quantizer.get_codes(h)
def forward(self, mel, return_decoder_latent=False):
orig_mel = mel
cm = ceil_multiple(mel.shape[-1], 2 ** (len(self.down)-1))
if cm != 0:
mel = F.pad(mel, (0,cm-mel.shape[-1]))
mel = F.pad(mel, (0, cm-mel.shape[-1]))
h = self.down(mel)
h = self.encoder(h)
h = self.enc_norm(h.permute(0,2,1))
h = self.enc_norm(h.permute(0, 2, 1))
codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
diversity = (self.quantizer.num_codevectors -
perplexity) / self.quantizer.num_codevectors
self.log_codes(codes)
h = self.decoder(codevectors.permute(0,2,1))
h = self.decoder(codevectors.permute(0, 2, 1))
if return_decoder_latent:
return h, diversity
@ -237,13 +252,14 @@ class MusicQuantizer2(nn.Module):
def log_codes(self, codes):
if self.internal_step % 5 == 0:
codes = torch.argmax(codes, dim=-1)
ccodes = codes[:,:,0]
for j in range(1,codes.shape[-1]):
ccodes += codes[:,:,j] * self.codebook_size ** j
ccodes = codes[:, :, 0]
for j in range(1, codes.shape[-1]):
ccodes += codes[:, :, j] * self.codebook_size ** j
codes = ccodes
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
i = self.code_ind if (
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
@ -258,9 +274,9 @@ class MusicQuantizer2(nn.Module):
def update_for_step(self, step, *args):
self.quantizer.temperature = max(
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
self.min_gumbel_temperature,
)
self.max_gumbel_temperature * self.gumbel_temperature_decay**step,
self.min_gumbel_temperature,
)
@register_model
@ -269,7 +285,8 @@ def register_music_quantizer2(opt_net, opt):
if __name__ == '__main__':
model = MusicQuantizer2(inner_dim=[1024], codevector_dim=1024, codebook_size=256, codebook_groups=2)
model = MusicQuantizer2(
inner_dim=[1024], codevector_dim=1024, codebook_size=256, codebook_groups=2)
print_network(model)
mel = torch.randn((2,256,782))
model(mel)
mel = torch.randn((2, 256, 782))
model(mel)
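# A minimal sketch of the ring-buffer code logging used by log_codes() above: observed
# codebook indices are written into a fixed-size buffer and the write index wraps around,
# so debug statistics always reflect the most recently seen codes. The class is illustrative.
import torch

class CodeRingBuffer:
    def __init__(self, size=3000):
        self.codes = torch.zeros(size, dtype=torch.long)
        self.ind = 0

    def push(self, batch_codes: torch.Tensor):
        batch_codes = batch_codes.flatten()
        l = batch_codes.shape[0]
        # If the write would run past the end, shift it back so it finishes exactly at the tail.
        i = self.ind if (self.codes.shape[0] - self.ind) > l else self.codes.shape[0] - l
        self.codes[i:i + l] = batch_codes.cpu()
        self.ind = self.ind + l
        if self.ind >= self.codes.shape[0]:
            self.ind = 0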
@ -8,14 +8,17 @@ import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchvision
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \
FeedForward
from trainer.networks import register_model
from utils.util import checkpoint, print_network, load_audio
import dlas.torch_intermediary as ml
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import TimestepBlock
from dlas.models.lucidrains.x_transformers import (Attention, Encoder,
FeedForward,
RMSScaleShiftNorm,
RotaryEmbedding)
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint, load_audio, print_network
class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
@ -389,17 +392,20 @@ def inference_tfdpc5_with_cheater():
use_fp16=False, unconditioned_percentage=0).eval().cuda()
model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/59000_generator_ema.pth'))
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
from trainer.injectors.audio_injectors import \
TorchMelSpectrogramInjector
spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'true_normalization': True,
'normalize': True, 'in': 'in', 'out': 'out'}, {}).cuda()
ref_mel = spec_fn({'in': sample.unsqueeze(0)})['out']
from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
from trainer.injectors.audio_injectors import \
MusicCheaterLatentInjector
cheater_encoder = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {}).cuda()
ref_cheater = cheater_encoder({'in': ref_mel})['out']
from models.diffusion.respace import SpacedDiffusion
from models.diffusion.respace import space_timesteps
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.gaussian_diffusion import \
get_named_beta_schedule
from models.diffusion.respace import (SpacedDiffusion,
space_timesteps)
diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [128]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
conditioning_free=True, conditioning_free_k=1)
@ -415,7 +421,8 @@ def inference_tfdpc5_with_cheater():
# Just decode the ref.
#gen_cheater = ref_cheater
from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent
from models.audio.music.transformer_diffusion12 import \
TransformerDiffusionWithCheaterLatent
diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000),
conditioning_free=True, conditioning_free_k=1)
@ -4,18 +4,21 @@ from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml
from models.arch_util import ResBlock
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
from models.audio.music.music_quantizer2 import MusicQuantizer2
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \
FeedForward
from trainer.networks import register_model
from utils.util import checkpoint, print_network
import dlas.torch_intermediary as ml
from dlas.models.arch_util import ResBlock
from dlas.models.audio.music.gpt_music2 import GptMusicLower, UpperEncoder
from dlas.models.audio.music.music_quantizer2 import MusicQuantizer2
from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import TimestepBlock
from dlas.models.lucidrains.x_transformers import (Attention, Encoder,
FeedForward,
RMSScaleShiftNorm,
RotaryEmbedding)
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint, print_network
def is_latent(t):
@ -1,26 +1,29 @@
import itertools
import random
from random import randrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \
RelativeQKBias
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
from trainer.networks import register_model
from utils.util import checkpoint
import dlas.torch_intermediary as ml
from dlas.models.arch_util import (AttentionBlock, RelativeQKBias, ResBlock,
TimestepEmbedSequential,
build_local_attention_mask, cGLU)
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import TimestepBlock
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
class SubBlock(nn.Module):
def __init__(self, inp_dim, contraction_dim, heads, dropout):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
self.register_buffer('mask', build_local_attention_mask(n=6000, l=64), persistent=False)
self.attn = AttentionBlock(
inp_dim, out_channels=contraction_dim, num_heads=heads)
self.register_buffer('mask', build_local_attention_mask(
n=6000, l=64), persistent=False)
self.pos_bias = RelativeQKBias(l=64, max_positions=6000)
ff_contract = contraction_dim//2
self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
@ -31,7 +34,8 @@ class SubBlock(nn.Module):
cGLU(ff_contract))
def forward(self, x):
ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
ah = self.dropout(self.attn(x, mask=self.mask,
qk_bias=self.pos_bias(x.shape[-1])))
h = torch.cat([ah, x], dim=1)
hf = self.dropout(checkpoint(self.ff1, h))
h = torch.cat([h, hf], dim=1)
@ -44,17 +48,21 @@ class ConcatAttentionBlock(TimestepBlock):
super().__init__()
self.contraction_dim = contraction_dim
self.prenorm = nn.GroupNorm(8, trunk_dim)
self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
self.block1 = SubBlock(
trunk_dim+blk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(
trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = nn.Conv1d(contraction_dim*4, trunk_dim,
kernel_size=1, bias=False)
self.out.weight.data.zero_()
def forward(self, x, blk_emb):
h = self.prenorm(x)
h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
h = torch.cat(
[h, blk_emb.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
h = self.block1(h)
h = self.block2(h)
h = self.out(h[:,-self.contraction_dim*4:])
h = self.out(h[:, -self.contraction_dim*4:])
return h + x
@ -72,10 +80,13 @@ class ConditioningEncoder(nn.Module):
self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2)
# nn.Embedding
self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim)
self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start.
# Reduces the relative influence of this embedding from the start.
self.resolution_embedding.weight.data.mul(.1)
for a in range(attn_blocks):
attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing))
attn.append(AttentionBlock(hidden_dim, num_attn_heads,
do_checkpoint=do_checkpointing))
attn.append(ResBlock(hidden_dim, dims=1,
checkpointing_enabled=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.out = ml.Linear(hidden_dim, out_dim, bias=False)
self.dim = hidden_dim
@ -91,6 +102,7 @@ class TransformerDiffusion(nn.Module):
"""
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
"""
def __init__(
self,
resolution_steps=8,
@ -108,7 +120,8 @@ class TransformerDiffusion(nn.Module):
dropout=0,
use_fp16=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
):
super().__init__()
@ -135,17 +148,21 @@ class TransformerDiffusion(nn.Module):
)
# nn.Embedding
self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim)
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
self.conditioning_encoder = ConditioningEncoder(
in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64)
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, cond_proj_dim))
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
self.inp_block = conv_nd(
1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim,
num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
zero_module(conv_nd(1, model_channels,
out_channels, 3, padding=1)),
)
self.debug_codes = {}
@ -163,17 +180,20 @@ class TransformerDiffusion(nn.Module):
s_diff = s.shape[-1] - self.max_window
if s_diff > 1:
start = randrange(0, s_diff)
s = s[:,:,start:start+self.max_window]
s = s[:, :, start:start+self.max_window]
s_prior = F.interpolate(s, scale_factor=.25, mode='nearest')
s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True)
s_prior = F.interpolate(s_prior, size=(
s.shape[-1],), mode='linear', align_corners=True)
# Now diffuse the prior randomly between the x timestep and 0.
adv = torch.rand_like(ts.float())
t_prior = (adv * ts).long() - 1
# The t_prior-1 below is an important detail: it forces s_prior to be unmodified for ts=0. It also means that t_prior is not on the same timescale as ts (instead it is shifted by 1).
s_prior_diffused = diffuser.q_sample(s_prior, t_prior-1, torch.randn_like(s_prior), allow_negatives=True)
s_prior_diffused = diffuser.q_sample(
s_prior, t_prior-1, torch.randn_like(s_prior), allow_negatives=True)
self.preprocessed = (s_prior_diffused, t_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device))
self.preprocessed = (s_prior_diffused, t_prior, torch.tensor(
[resolution] * x.shape[0], dtype=torch.long, device=x.device))
return s
def forward(self, x, timesteps, prior_timesteps=None, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
@ -202,12 +222,16 @@ class TransformerDiffusion(nn.Module):
x_prior, prior_timesteps, resolution = self.preprocessed
self.preprocessed = None
else:
assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}'
assert x.shape[-1] > x_prior.shape[-1] * \
3.9, f'{x.shape} {x_prior.shape}'
if prior_timesteps is None:
# This is taken to mean a fully diffused prior was given.
prior_timesteps = torch.tensor([0], device=x.device) # Assuming batch_size=1 for inference.
x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True)
assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
# Assuming batch_size=1 for inference.
prior_timesteps = torch.tensor([0], device=x.device)
x_prior = F.interpolate(x_prior, size=(
x.shape[-1],), mode='linear', align_corners=True)
assert torch.all(timesteps - prior_timesteps >=
0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}'
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1)
@ -218,19 +242,26 @@ class TransformerDiffusion(nn.Module):
clen = randrange(MIN_COND_LEN, MAX_COND_LEN)
gap = conditioning_input.shape[-1] - clen
cstart = randrange(0, gap)
conditioning_input = conditioning_input[:,:,cstart:cstart+clen]
code_emb = self.conditioning_encoder(conditioning_input, resolution)
conditioning_input = conditioning_input[:,
:, cstart:cstart+clen]
code_emb = self.conditioning_encoder(
conditioning_input, resolution)
# Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((x.shape[0], 1), device=x.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1), code_emb)
unconditioned_batches = torch.rand(
(x.shape[0], 1), device=x.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(
code_emb.shape[0], 1), code_emb)
with torch.autocast(x.device.type, enabled=self.enable_fp16):
time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim))
time_emb = self.time_embed(
timestep_embedding(timesteps, self.time_embed_dim))
prior_time_emb = self.prior_time_embed(
timestep_embedding(prior_timesteps, self.time_embed_dim))
res_emb = self.resolution_embed(resolution)
blk_emb = torch.cat([time_emb, prior_time_emb, res_emb, code_emb], dim=1)
blk_emb = torch.cat(
[time_emb, prior_time_emb, res_emb, code_emb], dim=1)
h = torch.cat([x, x_prior], dim=1)
h = self.inp_block(h)
@ -250,13 +281,16 @@ class TransformerDiffusion(nn.Module):
return out
def get_grad_norm_parameter_groups(self):
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
attn1 = list(itertools.chain.from_iterable(
[lyr.block1.attn.parameters() for lyr in self.layers]))
attn2 = list(itertools.chain.from_iterable(
[lyr.block2.attn.parameters() for lyr in self.layers]))
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
[lyr.block1.ff2.parameters() for lyr in self.layers]))
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
[lyr.block2.ff2.parameters() for lyr in self.layers]))
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
blkout_layers = list(itertools.chain.from_iterable(
[lyr.out.parameters() for lyr in self.layers]))
groups = {
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
'blk1_attention_layers': attn1,
@ -276,7 +310,8 @@ class TransformerDiffusion(nn.Module):
return groups
def before_step(self, step):
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
scaled_grad_parameters = list(itertools.chain.from_iterable(
[lyr.out.parameters() for lyr in self.layers]))
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitude
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
# directly fiddling with the gradients.
@ -291,14 +326,13 @@ def register_transformer_diffusion13(opt_net, opt):
def test_tfd():
from models.diffusion.respace import SpacedDiffusion
from models.diffusion.respace import space_timesteps
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import SpacedDiffusion, space_timesteps
diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse',
betas=get_named_beta_schedule('linear', 4000))
clip = torch.randn(2,256,10336)
cond = torch.randn(2,256,10336)
model_var_type='learned_range', loss_type='mse',
betas=get_named_beta_schedule('linear', 4000))
clip = torch.randn(2, 256, 10336)
cond = torch.randn(2, 256, 10336)
ts = torch.LongTensor([0, 0])
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
@ -316,5 +350,5 @@ def remove_conditioning(sd_path):
if __name__ == '__main__':
#remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
# remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
test_tfd()
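# A minimal sketch of the classifier-free-guidance-style conditioning dropout used in
# TransformerDiffusion.forward above: with probability p, a batch element's conditioning
# embedding is replaced by a learned "unconditioned" vector, so the model also learns an
# unconditional mode. The function name and default p are assumptions.
import torch

def drop_conditioning(code_emb, unconditioned_embedding, p=0.1):
    # code_emb: (B, D); unconditioned_embedding: (1, D) learned parameter.
    mask = torch.rand(code_emb.shape[0], 1, device=code_emb.device) < p
    return torch.where(mask, unconditioned_embedding.repeat(code_emb.shape[0], 1), code_emb)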
@ -4,18 +4,21 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from models.arch_util import TimestepEmbedSequential
from models.audio.music.encoders import ResEncoder16x
from models.audio.music.transformer_diffusion13 import ConcatAttentionBlock
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from trainer.networks import register_model
from utils.util import checkpoint, print_network
from dlas.models.arch_util import TimestepEmbedSequential
from dlas.models.audio.music.encoders import ResEncoder16x
from dlas.models.audio.music.transformer_diffusion13 import \
ConcatAttentionBlock
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint, print_network
class TransformerDiffusion(nn.Module):
"""
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
"""
def __init__(
self,
time_embed_dim=256,
@ -27,11 +30,13 @@ class TransformerDiffusion(nn.Module):
out_channels=512, # mean and variance
num_heads=4,
dropout=0,
use_corner_alignment=False, # This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE.
# This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE.
use_corner_alignment=False,
use_fp16=False,
new_code_expansion=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
# Parameters for re-training head
freeze_except_code_converters=False,
):
@ -55,7 +60,8 @@ class TransformerDiffusion(nn.Module):
)
self.input_converter = nn.Conv1d(input_vec_dim, model_channels, 1)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, model_channels, 1))
self.intg = nn.Conv1d(model_channels*2, model_channels, 1)
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim//4,
num_heads, dropout) for _ in range(num_layers)])
@ -63,7 +69,8 @@ class TransformerDiffusion(nn.Module):
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
zero_module(conv_nd(1, model_channels,
out_channels, 3, padding=1)),
)
if freeze_except_code_converters:
@ -76,13 +83,16 @@ class TransformerDiffusion(nn.Module):
p.requires_grad = True
def get_grad_norm_parameter_groups(self):
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
attn1 = list(itertools.chain.from_iterable(
[lyr.block1.attn.parameters() for lyr in self.layers]))
attn2 = list(itertools.chain.from_iterable(
[lyr.block2.attn.parameters() for lyr in self.layers]))
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
[lyr.block1.ff2.parameters() for lyr in self.layers]))
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
[lyr.block2.ff2.parameters() for lyr in self.layers]))
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
blkout_layers = list(itertools.chain.from_iterable(
[lyr.out.parameters() for lyr in self.layers]))
groups = {
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
'blk1_attention_layers': attn1,
@ -101,7 +111,8 @@ class TransformerDiffusion(nn.Module):
def forward(self, x, timesteps, prior=None, conditioning_free=False):
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
code_emb = self.unconditioned_embedding.repeat(
x.shape[0], 1, x.shape[-1])
else:
code_emb = self.input_converter(prior)
@ -112,10 +123,12 @@ class TransformerDiffusion(nn.Module):
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
code_emb)
code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
code_emb = F.interpolate(
code_emb, size=x.shape[-1], mode='nearest')
with torch.autocast(x.device.type, enabled=self.enable_fp16):
blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
blk_emb = self.time_embed(
timestep_embedding(timesteps, self.time_embed_dim))
x = self.inp_block(x)
x = self.intg(torch.cat([x, code_emb], dim=1))
@ -141,7 +154,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
self.internal_step = 0
self.freeze_encoder_until = freeze_encoder_until
self.diff = TransformerDiffusion(**kwargs)
self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
self.encoder = ResEncoder16x(
256, 1024, 256, checkpointing_enabled=checkpoint_encoder)
def forward(self, x, timesteps, truth_mel, conditioning_free=False, cheater=None):
unused_parameters = []
@ -158,7 +172,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
for p in unused_parameters:
proj = proj + p.mean() * 0
diff = self.diff(x, timesteps, prior=proj, conditioning_free=conditioning_free)
diff = self.diff(x, timesteps, prior=proj,
conditioning_free=conditioning_free)
return diff
def get_debug_values(self, step, __):
@ -172,7 +187,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
def before_step(self, step):
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
list(itertools.chain.from_iterable(
[lyr.prenorm.parameters() for lyr in self.diff.layers]))
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitude
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
# directly fiddling with the gradients.
@ -196,10 +212,10 @@ def register_transformer_diffusion_14_with_cheater_latent(opt_net, opt):
def test_tfd():
clip = torch.randn(2,256,400)
clip = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
num_heads=3, input_vec_dim=256, num_layers=12, dropout=.1)
num_heads=3, input_vec_dim=256, num_layers=12, dropout=.1)
model(clip, ts, clip)
@ -209,14 +225,14 @@ def test_cheater_model():
# For music:
model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
model_channels=1024, contraction_dim=512, num_heads=8,
input_vec_dim=256, num_layers=16,
dropout=.1, new_code_expansion=True,
)
#diff_weights = torch.load('extracted_diff.pth')
#model.diff.load_state_dict(diff_weights, strict=False)
#model.encoder.load_state_dict(torch.load('../experiments/music_cheater_encoder_256.pth', map_location=torch.device('cpu')), strict=True)
#torch.save(model.state_dict(), 'sample.pth')
model_channels=1024, contraction_dim=512, num_heads=8,
input_vec_dim=256, num_layers=16,
dropout=.1, new_code_expansion=True,
)
# diff_weights = torch.load('extracted_diff.pth')
# model.diff.load_state_dict(diff_weights, strict=False)
# model.encoder.load_state_dict(torch.load('../experiments/music_cheater_encoder_256.pth', map_location=torch.device('cpu')), strict=True)
# torch.save(model.state_dict(), 'sample.pth')
print_network(model)
o = model(clip, ts, clip)
@ -234,7 +250,8 @@ def extract_cheater_encoder(in_f, out_f):
if __name__ == '__main__':
#test_local_attention_mask()
extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth', 'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth')
#test_cheater_model()
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
# test_local_attention_mask()
extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth',
'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth')
# test_cheater_model()
# extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
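# A minimal sketch of the gradient rescaling performed in before_step() above: the prenorm
# and block-output parameters accumulate much larger gradients than the rest of the model,
# so they are scaled down in place just before the optimizer step. The scale factor here is
# an assumption, not the value used by the repository.
import itertools

def rescale_selected_grads(layers, factor=0.01):
    params = itertools.chain.from_iterable(
        list(lyr.out.parameters()) + list(lyr.prenorm.parameters()) for lyr in layers)
    for p in params:
        if p.grad is not None:
            p.grad.data.mul_(factor)  # only these parameters are touched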
@ -1,29 +1,23 @@
from abc import abstractmethod
import math
from abc import abstractmethod
import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torchvision # For debugging, not actually used.
import torch_intermediary as ml
from models.audio.music.gpt_music import GptMusicLower
from models.audio.music.music_quantizer import MusicQuantizer
from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32
from models.diffusion.nn import (
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
)
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
from utils.util import checkpoint, print_network, ceil_multiple
import dlas.torch_intermediary as ml
from dlas.models.audio.music.gpt_music import GptMusicLower
from dlas.models.audio.music.music_quantizer import MusicQuantizer
from dlas.models.diffusion.fp16_util import (convert_module_to_f16,
convert_module_to_f32)
from dlas.models.diffusion.nn import (avg_pool_nd, conv_nd, linear,
normalization, timestep_embedding,
zero_module)
from dlas.models.lucidrains.x_transformers import Encoder
from dlas.trainer.networks import register_model
from dlas.utils.util import ceil_multiple, checkpoint, print_network
class TimestepBlock(nn.Module):
@ -80,7 +74,8 @@ class Upsample(nn.Module):
if dims == 1:
ksize = 5
pad = 2
self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad)
self.conv = conv_nd(dims, self.channels,
self.out_channels, ksize, padding=pad)
def forward(self, x):
assert x.shape[1] == self.channels
@ -123,7 +118,7 @@ class Downsample(nn.Module):
elif dims == 2:
stride = 2
else:
stride = (1,2,2)
stride = (1, 2, 2)
if factor is not None:
stride = factor
if use_conv:
@ -180,7 +175,8 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
conv_nd(dims, channels, self.out_channels,
kernel_size, padding=padding),
)
self.updown = up or down
@ -206,7 +202,8 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
conv_nd(dims, self.out_channels, self.out_channels,
kernel_size, padding=padding)
),
)
@ -217,7 +214,8 @@ class ResBlock(TimestepBlock):
dims, channels, self.out_channels, kernel_size, padding=padding
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
@ -346,13 +344,15 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3,
length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
if rel_pos is not None:
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
bs * self.n_heads, weight.shape[-2], weight.shape[-1])
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
@ -400,7 +400,8 @@ class QKVAttention(nn.Module):
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
weight = weight * mask
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
a = th.einsum("bts,bcs->bct", weight,
v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
@ -460,7 +461,8 @@ class UNetMusicModel(nn.Module):
resblock_updown=False,
use_new_attention_order=False,
use_raw_y_as_embedding=False,
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
):
super().__init__()
@ -493,44 +495,48 @@ class UNetMusicModel(nn.Module):
if self.ar_prior:
self.ar_input = ml.Linear(input_vec_dim, model_channels)
self.ar_prior_intg = Encoder(
dim=model_channels,
depth=4,
heads=num_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=1,
)
dim=model_channels,
depth=4,
heads=num_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=1,
)
else:
self.input_converter = ml.Linear(input_vec_dim, model_channels)
self.code_converter = Encoder(
dim=model_channels,
depth=4,
heads=num_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=1,
)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1)
dim=model_channels,
depth=4,
heads=num_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
zero_init_branch_output=True,
ff_mult=1,
)
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, 1, model_channels))
self.x_processor = conv_nd(
dims, in_channels, model_channels, 3, padding=1)
if self.num_classes is not None:
# nn.Embedding
self.label_emb = ml.Embedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
# These are mutually-exclusive.
assert not ((self.num_classes is not None) and use_raw_y_as_embedding)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, model_channels*2, model_channels, 1, padding=0)
conv_nd(dims, model_channels*2,
model_channels, 1, padding=0)
)
]
)
@ -657,7 +663,8 @@ class UNetMusicModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
zero_module(conv_nd(dims, model_channels,
out_channels, 3, padding=1)),
)
def forward(self, x, timesteps, y, conditioning_free=False):
@ -666,27 +673,35 @@ class UNetMusicModel(nn.Module):
if cm != 0:
pc = (cm - x.shape[-1]) / x.shape[-1]
x = F.pad(x, (0, cm - x.shape[-1]))
y = F.pad(y.permute(0,2,1), (0, int(pc * y.shape[-1]))).permute(0,2,1)
y = F.pad(y.permute(0, 2, 1), (0, int(
pc * y.shape[-1]))).permute(0, 2, 1)
unused_params = []
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(timestep_embedding(
timesteps, self.model_channels))
if conditioning_free:
expanded_code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1).permute(0,2,1)
expanded_code_emb = self.unconditioned_embedding.repeat(
x.shape[0], x.shape[-1], 1).permute(0, 2, 1)
if self.ar_prior:
unused_params.extend(list(self.ar_input.parameters()) + list(self.ar_prior_intg.parameters()))
unused_params.extend(
list(self.ar_input.parameters()) + list(self.ar_prior_intg.parameters()))
else:
unused_params.extend(list(self.input_converter.parameters()) + list(self.code_converter.parameters()))
unused_params.extend(
list(self.input_converter.parameters()) + list(self.code_converter.parameters()))
else:
code_emb = self.ar_input(y) if self.ar_prior else self.input_converter(y)
code_emb = self.ar_input(
y) if self.ar_prior else self.input_converter(y)
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(y.shape[0], 1, 1),
code_emb)
code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=x.shape[-1], mode='nearest')
code_emb = self.ar_prior_intg(
code_emb) if self.ar_prior else self.code_converter(code_emb)
expanded_code_emb = F.interpolate(code_emb.permute(
0, 2, 1), size=x.shape[-1], mode='nearest')
h = x.type(self.dtype)
expanded_code_emb = expanded_code_emb.type(self.dtype)
@ -720,7 +735,8 @@ class UNetMusicModelWithQuantizer(nn.Module):
self.internal_step = 0
self.freeze_quantizer_until = freeze_quantizer_until
self.diff = UNetMusicModel(**kwargs)
self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=512, codebook_groups=2)
self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[
1024, 1024, 512], codevector_dim=1024, codebook_size=512, codebook_groups=2)
self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
del self.m2v.up
@ -728,15 +744,16 @@ class UNetMusicModelWithQuantizer(nn.Module):
self.internal_step = step
qstep = max(0, self.internal_step - self.freeze_quantizer_until)
self.m2v.quantizer.temperature = max(
self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**qstep,
self.m2v.min_gumbel_temperature,
)
self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**qstep,
self.m2v.min_gumbel_temperature,
)
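The temperature update above is the usual exponentially decayed Gumbel-softmax schedule, held at its maximum while the quantizer is frozen and clamped at a floor afterwards. A standalone sketch (the constants are illustrative defaults, not the values configured in this repository):

def gumbel_temperature(step, freeze_until=0, max_temp=2.0, min_temp=0.5, decay=0.999995):
    # Hold at max_temp while the quantizer is frozen, then decay exponentially toward min_temp.
    qstep = max(0, step - freeze_until)
    return max(max_temp * decay ** qstep, min_temp)

for s in (0, 10_000, 200_000, 2_000_000):
    print(s, round(gumbel_temperature(s, freeze_until=10_000), 4))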
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
with torch.set_grad_enabled(quant_grad_enabled):
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
proj = proj.permute(0,2,1)
proj, diversity_loss = self.m2v(
truth_mel, return_decoder_latent=True)
proj = proj.permute(0, 2, 1)
# Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
if not quant_grad_enabled:
@ -746,7 +763,8 @@ class UNetMusicModelWithQuantizer(nn.Module):
proj = proj + unused
diversity_loss = diversity_loss * 0
diff = self.diff(x, timesteps, proj, conditioning_free=conditioning_free)
diff = self.diff(x, timesteps, proj,
conditioning_free=conditioning_free)
if disable_diversity:
return diff
return diff, diversity_loss
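The unused_params bookkeeping here (and the extraneous_addition block further down) is the usual workaround for DistributedDataParallel, which expects every parameter to receive a gradient on every step: parameters a given forward pass would otherwise skip are folded into the output with zero weight. A sketch of the idiom:

import torch
import torch.nn as nn

def touch_unused(out, params):
    # Add mean(p) * 0 for each skipped parameter so DDP sees a (zero) gradient
    # for it without changing the value of `out`.
    extraneous = 0
    for p in params:
        extraneous = extraneous + p.mean()
    return out + extraneous * 0

lin_a, lin_b = nn.Linear(8, 8), nn.Linear(8, 8)
x = torch.randn(2, 8)
out = lin_a(x)                                   # lin_b is unused on this step
out = touch_unused(out, lin_b.parameters())
out.sum().backward()                             # lin_b.weight.grad is zeros rather than None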
@ -790,7 +808,8 @@ class UNetMusicModelARPrior(nn.Module):
with torch.no_grad():
prior = self.ar(truth_mel, conditioning_input, return_latent=True)
diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
diff = self.diff(x, timesteps, prior,
conditioning_free=conditioning_free)
return diff
@ -798,6 +817,7 @@ class UNetMusicModelARPrior(nn.Module):
def register_unet_diffusion_music_codes(opt_net, opt):
return UNetMusicModelWithQuantizer(**opt_net['args'])
@register_model
def register_unet_diffusion_music_ar_prior(opt_net, opt):
return UNetMusicModelARPrior(**opt_net['args'])
@ -808,20 +828,21 @@ if __name__ == '__main__':
cond = torch.randn(2, 256, 300)
ts = torch.LongTensor([600, 600])
model = UNetMusicModelARPrior(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=512,
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
attention_resolutions=(2, 4), channel_mult=(1, 2, 3), dims=1,
use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4, freeze_unet=True)
print_network(model)
model.get_grad_norm_parameter_groups()
ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
ar_weights = torch.load(
'D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
model.ar.load_state_dict(ar_weights, strict=True)
diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\55500_generator_ema.pth')
diff_weights = torch.load(
'X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\55500_generator_ema.pth')
pruned_diff_weights = {}
for k,v in diff_weights.items():
for k, v in diff_weights.items():
if k.startswith('diff.'):
pruned_diff_weights[k.replace('diff.', '')] = v
model.diff.load_state_dict(pruned_diff_weights, strict=False)
torch.save(model.state_dict(), 'sample.pth')
model(clip, ts, cond, cond)

@ -6,14 +6,19 @@ import torch.nn.functional as F
from torch import autocast
from x_transformers import Encoder
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
Downsample, Upsample, TimestepBlock
from models.audio.tts.mini_encoder import AudioMiniEncoder
from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
from dlas.models.audio.tts.unet_diffusion_tts7 import \
CheckpointedXTransformerEncoder
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
TimestepBlock,
TimestepEmbedSequential,
Upsample)
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
def is_sequence(t):
return t.dtype == torch.long
@ -44,7 +49,8 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
conv_nd(dims, channels, self.out_channels,
eff_kernel, padding=eff_padding),
)
self.emb_layers = nn.Sequential(
@ -59,14 +65,16 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
conv_nd(dims, self.out_channels, self.out_channels,
kernel_size, padding=padding)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
def forward(self, x, emb):
"""
@ -95,6 +103,7 @@ class ResBlock(TimestepBlock):
h = self.out_layers(h)
return self.skip_connection(x) + h
class DiffusionWaveformGen(nn.Module):
"""
The full UNet model with attention and timestep embedding.
@ -138,12 +147,12 @@ class DiffusionWaveformGen(nn.Module):
out_channels=2, # mean and variance
dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
token_conditioning_resolutions=(1,16,),
attention_resolutions=(512,1024,2048),
token_conditioning_resolutions=(1, 16,),
attention_resolutions=(512, 1024, 2048),
conv_resample=True,
dims=1,
use_fp16=False,
@ -154,10 +163,12 @@ class DiffusionWaveformGen(nn.Module):
scale_factor=2,
time_embed_dim_multiplier=4,
freeze_main_net=False,
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
# Uses kernels with width of 1 in several places rather than 3.
efficient_convs=True,
use_scale_shift_norm=True,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
# Parameters for super-sampling.
super_sampling=False,
super_sampling_max_noising_factor=.1,
@ -168,7 +179,8 @@ class DiffusionWaveformGen(nn.Module):
num_heads_upsample = num_heads
if super_sampling:
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
# In super-sampling mode, the LR input is concatenated directly onto the input.
in_channels *= 2
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
@ -218,22 +230,31 @@ class DiffusionWaveformGen(nn.Module):
rotary_pos_emb=True,
)
))
self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
self.latent_converter = nn.Conv1d(
in_latent_channels, conditioning_dim, 1)
self.aligned_latent_padding_embedding = nn.Parameter(
torch.randn(1, in_latent_channels, 1))
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, conditioning_dim, 1))
self.conditioning_timestep_integrator = TimestepEmbedSequential(
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads,
num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads,
num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
)
self.conditioning_expansion = conditioning_expansion
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
conv_nd(dims, in_channels, model_channels,
kernel_size, padding=padding)
)
]
)
@ -344,7 +365,8 @@ class DiffusionWaveformGen(nn.Module):
if level and i == num_blocks:
out_ch = ch
layers.append(
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
Upsample(ch, conv_resample, dims=dims,
out_channels=out_ch, factor=scale_factor)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -353,7 +375,8 @@ class DiffusionWaveformGen(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
zero_module(conv_nd(dims, model_channels, out_channels,
kernel_size, padding=padding)),
)
if self.freeze_main_net:
@ -385,13 +408,14 @@ class DiffusionWaveformGen(nn.Module):
cm = ceil_multiple(x.shape[-1], self.alignment_size)
if cm != 0:
pc = (cm-x.shape[-1])/x.shape[-1]
x = F.pad(x, (0,cm-x.shape[-1]))
x = F.pad(x, (0, cm-x.shape[-1]))
# Also fix aligned_latent, which is aligned to x.
if self.is_latent(aligned_conditioning):
aligned_conditioning = torch.cat([aligned_conditioning,
self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
else:
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
aligned_conditioning = F.pad(
aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
return x, aligned_conditioning
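fix_alignment pads the waveform out to a multiple of the UNet's total downsampling factor and stretches the conditioning by the same proportion. A minimal sketch of the arithmetic; ceil_multiple is re-derived here from its call sites rather than imported, and the alignment value is illustrative:

import torch
import torch.nn.functional as F

def ceil_multiple(base, multiple):
    # 0 when already aligned, otherwise the next multiple of `multiple` above `base`.
    res = base % multiple
    return 0 if res == 0 else base + (multiple - res)

x = torch.randn(2, 1, 3000)         # waveform-like input
cond = torch.randn(2, 120, 100)     # conditioning aligned to x
alignment = 2 ** 7                  # assumed: product of the per-level downsampling factors

cm = ceil_multiple(x.shape[-1], alignment)
if cm != 0:
    pc = (cm - x.shape[-1]) / x.shape[-1]
    x = F.pad(x, (0, cm - x.shape[-1]))
    cond = F.pad(cond, (0, int(pc * cond.shape[-1])))
print(x.shape, cond.shape)          # both lengths grow by the same proportion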
def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
@ -416,11 +440,13 @@ class DiffusionWaveformGen(nn.Module):
with autocast(x.device.type, enabled=self.enable_fp16):
hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
time_emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels))
# Note: this block does not need to be repeated at inference, since it is not timestep-dependent.
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
code_emb = self.unconditioned_embedding.repeat(
x.shape[0], 1, 1)
else:
if self.is_latent(aligned_conditioning):
code_emb = self.latent_converter(aligned_conditioning)
@ -428,15 +454,18 @@ class DiffusionWaveformGen(nn.Module):
code_emb = self.mel_converter(aligned_conditioning)
# Everything after this comment is timestep dependent.
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
code_emb = torch.repeat_interleave(
code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(
code_emb, time_emb)
first = True
time_emb = time_emb.float()
h = x
for k, module in enumerate(self.input_blocks):
if isinstance(module, nn.Conv1d):
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
h_tok = F.interpolate(module(code_emb), size=(
h.shape[-1]), mode='nearest')
h = h + h_tok
else:
with autocast(x.device.type, enabled=self.enable_fp16 and not first):
@ -455,7 +484,8 @@ class DiffusionWaveformGen(nn.Module):
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
extraneous_addition = 0
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
params = [self.aligned_latent_padding_embedding,
self.unconditioned_embedding] + list(self.latent_converter.parameters())
for p in params:
extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0
@ -470,13 +500,13 @@ def register_unet_diffusion_waveform_gen(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 1, 32868)
aligned_latent = torch.randn(2,388,1024)
aligned_sequence = torch.randn(2,120,220)
aligned_latent = torch.randn(2, 388, 1024)
aligned_sequence = torch.randn(2, 120, 220)
ts = torch.LongTensor([600, 600])
model = DiffusionWaveformGen(128,
channel_mult=[1,1.5,2, 3, 4, 6, 8],
channel_mult=[1, 1.5, 2, 3, 4, 6, 8],
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
token_conditioning_resolutions=[1,4,16,64],
token_conditioning_resolutions=[1, 4, 16, 64],
attention_resolutions=[],
num_heads=8,
kernel_size=3,
@ -488,4 +518,3 @@ if __name__ == '__main__':
o = model(clip, ts, aligned_latent)
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence)

@ -4,12 +4,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
Downsample, Upsample, TimestepBlock
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint, print_network
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (Downsample, TimestepBlock,
TimestepEmbedSequential,
Upsample)
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint, print_network
def is_sequence(t):
@ -368,4 +370,3 @@ if __name__ == '__main__':
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence)
print_network(model)

@ -3,12 +3,14 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
Downsample, Upsample, TimestepBlock
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (Downsample, TimestepBlock,
TimestepEmbedSequential,
Upsample)
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
def is_sequence(t):
@ -40,7 +42,8 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
conv_nd(dims, channels, self.out_channels,
eff_kernel, padding=eff_padding),
)
self.emb_layers = nn.Sequential(
@ -55,14 +58,16 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
conv_nd(dims, self.out_channels, self.out_channels,
kernel_size, padding=padding)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
def forward(self, x, emb):
"""
@ -91,6 +96,7 @@ class ResBlock(TimestepBlock):
h = self.out_layers(h)
return self.skip_connection(x) + h
class DiffusionWaveformGen(nn.Module):
"""
The full UNet model with residual blocks and timestep embedding.
@ -122,11 +128,11 @@ class DiffusionWaveformGen(nn.Module):
out_channels=2, # mean and variance
dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
token_conditioning_resolutions=(1,16,),
token_conditioning_resolutions=(1, 16,),
conv_resample=True,
dims=1,
use_fp16=False,
@ -134,10 +140,12 @@ class DiffusionWaveformGen(nn.Module):
scale_factor=2,
time_embed_dim_multiplier=4,
freeze_main_net=False,
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
# Uses kernels with width of 1 in several places rather than 3.
efficient_convs=True,
use_scale_shift_norm=True,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
# Parameters for super-sampling.
super_sampling=False,
super_sampling_max_noising_factor=.1,
@ -145,7 +153,8 @@ class DiffusionWaveformGen(nn.Module):
super().__init__()
if super_sampling:
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
# In super-sampling mode, the LR input is concatenated directly onto the input.
in_channels *= 2
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
@ -175,19 +184,25 @@ class DiffusionWaveformGen(nn.Module):
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
self.mel_converter = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
self.mel_converter = nn.Conv1d(
in_mel_channels, conditioning_dim, 3, padding=1)
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, conditioning_dim, 1))
self.conditioning_timestep_integrator = TimestepEmbedSequential(
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
)
self.conditioning_expansion = conditioning_expansion
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
conv_nd(dims, in_channels, model_channels,
kernel_size, padding=padding)
)
]
)
@ -268,7 +283,8 @@ class DiffusionWaveformGen(nn.Module):
if level and i == num_blocks:
out_ch = ch
layers.append(
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
Upsample(ch, conv_resample, dims=dims,
out_channels=out_ch, factor=scale_factor)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -277,7 +293,8 @@ class DiffusionWaveformGen(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
zero_module(conv_nd(dims, model_channels, out_channels,
kernel_size, padding=padding)),
)
if self.freeze_main_net:
@ -306,8 +323,9 @@ class DiffusionWaveformGen(nn.Module):
cm = ceil_multiple(x.shape[-1], self.alignment_size)
if cm != 0:
pc = (cm-x.shape[-1])/x.shape[-1]
x = F.pad(x, (0,cm-x.shape[-1]))
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
x = F.pad(x, (0, cm-x.shape[-1]))
aligned_conditioning = F.pad(
aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
return x, aligned_conditioning
def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
@ -327,24 +345,29 @@ class DiffusionWaveformGen(nn.Module):
with autocast(x.device.type, enabled=self.enable_fp16):
hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
time_emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels))
# Note: this block does not need to be repeated at inference, since it is not timestep-dependent.
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
code_emb = self.unconditioned_embedding.repeat(
x.shape[0], 1, 1)
else:
code_emb = self.mel_converter(aligned_conditioning)
# Everything after this comment is timestep dependent.
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
code_emb = torch.repeat_interleave(
code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(
code_emb, time_emb)
first = True
time_emb = time_emb.float()
h = x
for k, module in enumerate(self.input_blocks):
if isinstance(module, nn.Conv1d):
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
h_tok = F.interpolate(module(code_emb), size=(
h.shape[-1]), mode='nearest')
h = h + h_tok
else:
with autocast(x.device.type, enabled=self.enable_fp16 and not first):
@ -378,12 +401,12 @@ def register_unet_diffusion_waveform_gen2(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 1, 32868)
aligned_sequence = torch.randn(2,120,220)
aligned_sequence = torch.randn(2, 120, 220)
ts = torch.LongTensor([600, 600])
model = DiffusionWaveformGen(128,
channel_mult=[1,1.5,2, 3, 4, 6, 8],
channel_mult=[1, 1.5, 2, 3, 4, 6, 8],
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
token_conditioning_resolutions=[1,4,16,64],
token_conditioning_resolutions=[1, 4, 16, 64],
kernel_size=3,
scale_factor=2,
time_embed_dim_multiplier=4,
@ -391,4 +414,3 @@ if __name__ == '__main__':
efficient_convs=False)
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence)

@ -1,12 +1,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2PreTrainedModel, GPT2Config
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import TransformerWrapper, Decoder, Encoder
from trainer.networks import register_model
from dlas.models.arch_util import AttentionBlock
from dlas.models.lucidrains.x_transformers import (Decoder, Encoder,
TransformerWrapper)
from dlas.trainer.networks import register_model
class InferenceModel(GPT2PreTrainedModel):
@ -14,6 +15,7 @@ class InferenceModel(GPT2PreTrainedModel):
Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with
this transformer.
"""
def __init__(self, model):
super().__init__(GPT2Config())
self.transformer = model
@ -83,7 +85,8 @@ class InferenceModel(GPT2PreTrainedModel):
):
assert self.context is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
# Training not supported by this inference model.
assert labels is None
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
@ -115,7 +118,8 @@ class InferenceModel(GPT2PreTrainedModel):
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past)
for layer_past in past
)
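_reorder_cache is the hook Hugging Face's beam search calls to keep cached key/value tensors in step with the surviving beams: every layer's past tensors are re-indexed along the batch dimension with beam_idx. A toy demonstration with fake cache tensors:

import torch

# Fake cache: 2 layers, each a (key, value) pair of shape (num_beams, heads, seq, head_dim).
past = tuple((torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8)) for _ in range(2))
beam_idx = torch.tensor([2, 2, 0, 1])    # which old beam each new beam continues from

reordered = tuple(
    tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
    for layer_past in past
)
assert torch.equal(reordered[0][0][0], past[0][0][2])   # new beam 0 copied old beam 2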
@ -124,6 +128,7 @@ class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
@ -148,11 +153,13 @@ class ConditioningEncoder(nn.Module):
super().__init__()
attn = []
self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2),
nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2),
nn.Conv1d(embedding_dim//4, embedding_dim //
2, kernel_size=3, padding=1, stride=2),
ResBlock(embedding_dim//2),
nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2))
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(AttentionBlock(embedding_dim,
num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
@ -165,50 +172,53 @@ class ConditioningEncoder(nn.Module):
class AutoregressiveCodegen(nn.Module):
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
super().__init__()
assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later.
# This is the minimum bound to support the context interleaving that happens later.
assert depth >= 8
self.START_TOKEN=8192
self.STOP_TOKEN=8193
self.START_TOKEN = 8192
self.STOP_TOKEN = 8193
self.START_TEXT_TOKEN = 255
self.STOP_TEXT_TOKEN = 0
self.max_text_token_id = num_text_tokens
self.max_mel_token_id = num_mel_tokens
self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
self.mel_embedding = ConditioningEncoder(
80, model_dim, do_checkpointing=False)
self.encoder = TransformerWrapper(
num_tokens=num_text_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers = Encoder(
depth=depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=1,
rotary_pos_emb=True,
attn_rel_pos_bias=True,
))
self.encoder.norm = nn.Identity() # This layer and the next are unused.
num_tokens=num_text_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers=Encoder(
depth=depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=1,
rotary_pos_emb=True,
attn_rel_pos_bias=True,
))
# This layer and the next are unused.
self.encoder.norm = nn.Identity()
self.encoder.to_logits = nn.Identity()
self.decoder = TransformerWrapper(
num_tokens=num_mel_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers=Decoder(
depth=depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=1,
rotary_pos_emb=True,
cross_attend=True,
attn_rel_pos_bias=True,
))
num_tokens=num_mel_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers=Decoder(
depth=depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=1,
rotary_pos_emb=True,
cross_attend=True,
attn_rel_pos_bias=True,
))
def get_grad_norm_parameter_groups(self):
return {
@ -218,8 +228,10 @@ class AutoregressiveCodegen(nn.Module):
}
def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
assert text_codes.max() < self.max_text_token_id and text_codes.min(
) >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min(
) >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
# Format mel_codes with a stop token on the end.
mel_lengths = wav_lengths // 1024 + 1
@ -235,8 +247,8 @@ class AutoregressiveCodegen(nn.Module):
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
# Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings.
text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
text_codes = F.pad(text_codes, (1, 0), value=self.START_TEXT_TOKEN)
text_codes = F.pad(text_codes, (0, 1), value=self.STOP_TEXT_TOKEN)
_, enc_text = self.encoder(text_codes, return_hiddens=True)
# Interleave cond_emb into the first few contexts.
full_context = enc_text
@ -245,11 +257,11 @@ class AutoregressiveCodegen(nn.Module):
full_context[6] = cond_emb
# Execute the decoder
dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
dec_inputs = F.pad(mel_codes, (1, 0), value=self.START_TOKEN)[:, :-1]
dec = self.decoder(dec_inputs, full_context=full_context)
if not return_loss:
return dec
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
loss_mel = F.cross_entropy(dec.permute(0, 2, 1), mel_codes)
return loss_mel
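The loss above is plain token-level cross entropy; the only wrinkle is that the decoder emits logits as (batch, seq, vocab) while F.cross_entropy expects the class dimension second, hence permute(0, 2, 1). A small shape check:

import torch
import torch.nn.functional as F

B, T, V = 2, 350, 8194
logits = torch.randn(B, T, V)             # decoder output, channels-last
targets = torch.randint(0, V, (B, T))     # mel codes

loss = F.cross_entropy(logits.permute(0, 2, 1), targets)   # cross_entropy wants (B, V, T)
print(loss.item())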
def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs):
@ -261,8 +273,8 @@ class AutoregressiveCodegen(nn.Module):
for i in range(conditioning_signal.shape[1]):
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
text_codes = F.pad(text_codes, (1, 0), value=self.START_TEXT_TOKEN)
text_codes = F.pad(text_codes, (0, 1), value=self.STOP_TEXT_TOKEN)
_, enc_text = self.encoder(text_codes, return_hiddens=True)
# Interleave cond_emb into the first few contexts.
full_context = enc_text
@ -273,7 +285,7 @@ class AutoregressiveCodegen(nn.Module):
gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=True,
**hf_generate_kwargs)
**hf_generate_kwargs)
return gen.sequences
@ -285,8 +297,8 @@ def register_autoregressive_codegen(opt_net, opt):
if __name__ == '__main__':
codegen = AutoregressiveCodegen(256, 10)
torch.save(codegen.state_dict(), 'sample.pth')
#codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
codegen(torch.randint(0,256, (2,200)),
torch.randn(2,80,120),
torch.randint(0,8192, (2,350)),
torch.tensor([192,350]))
# codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
codegen(torch.randint(0, 256, (2, 200)),
torch.randn(2, 80, 120),
torch.randint(0, 8192, (2, 350)),
torch.tensor([192, 350]))

@ -1,12 +1,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2PreTrainedModel, GPT2Config
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from models.arch_util import AttentionBlock
from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder
from trainer.networks import register_model
from dlas.models.arch_util import AttentionBlock
from dlas.models.lucidrains.x_transformers import (Decoder, Encoder,
TransformerWrapper)
from dlas.trainer.networks import register_model
class InferenceModel(GPT2PreTrainedModel):
@ -14,6 +15,7 @@ class InferenceModel(GPT2PreTrainedModel):
Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with
this transformer.
"""
def __init__(self, model):
super().__init__(GPT2Config())
self.transformer = model
@ -83,10 +85,12 @@ class InferenceModel(GPT2PreTrainedModel):
):
assert self.context is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
# Training not supported by this inference model.
assert labels is None
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
hidden_states = self.transformer.decoder(
input_ids, context=self.context, return_embeddings=True)
logits = self.transformer.decoder.to_logits(hidden_states)
if not return_dict:
@ -109,7 +113,8 @@ class InferenceModel(GPT2PreTrainedModel):
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past)
for layer_past in past
)
@ -118,6 +123,7 @@ class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
@ -142,11 +148,13 @@ class ConditioningEncoder(nn.Module):
super().__init__()
attn = []
self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2),
nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2),
nn.Conv1d(embedding_dim//4, embedding_dim //
2, kernel_size=3, padding=1, stride=2),
ResBlock(embedding_dim//2),
nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2))
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(AttentionBlock(embedding_dim,
num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
@ -160,45 +168,46 @@ class AutoregressiveCodegen(nn.Module):
def __init__(self, model_dim, encoder_depth, decoder_depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1, ff_mult=1):
super().__init__()
self.START_TOKEN=8192
self.STOP_TOKEN=8193
self.START_TOKEN = 8192
self.STOP_TOKEN = 8193
self.max_text_token_id = num_text_tokens
self.max_mel_token_id = num_mel_tokens
self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
self.mel_embedding = ConditioningEncoder(
80, model_dim, do_checkpointing=False)
self.encoder = TransformerWrapper(
num_tokens=num_text_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers = Encoder(
depth=encoder_depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=ff_mult,
rotary_pos_emb=True,
attn_rel_pos_bias=True,
))
num_tokens=num_text_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers=Encoder(
depth=encoder_depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=ff_mult,
rotary_pos_emb=True,
attn_rel_pos_bias=True,
))
self.encoder.to_logits = nn.Identity() # This is unused.
self.decoder = TransformerWrapper(
num_tokens=num_mel_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers=Decoder(
depth=decoder_depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=ff_mult,
rotary_pos_emb=True,
cross_attend=True,
attn_rel_pos_bias=True,
))
num_tokens=num_mel_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers=Decoder(
depth=decoder_depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=ff_mult,
rotary_pos_emb=True,
cross_attend=True,
attn_rel_pos_bias=True,
))
def get_grad_norm_parameter_groups(self):
return {
@ -208,8 +217,10 @@ class AutoregressiveCodegen(nn.Module):
}
def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
assert text_codes.max() < self.max_text_token_id and text_codes.min(
) >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
assert mel_codes.max() < self.max_mel_token_id and mel_codes.min(
) >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
# Format mel_codes with a stop token on the end.
mel_lengths = wav_lengths // 1024 + 1
@ -228,11 +239,11 @@ class AutoregressiveCodegen(nn.Module):
context = torch.cat([cond_emb, enc_text], dim=1)
# Execute the decoder
dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
dec_inputs = F.pad(mel_codes, (1, 0), value=self.START_TOKEN)[:, :-1]
dec = self.decoder(dec_inputs, context=context)
if not return_loss:
return dec
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
loss_mel = F.cross_entropy(dec.permute(0, 2, 1), mel_codes)
return loss_mel
def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs):
@ -263,8 +274,9 @@ def register_autoregressive_codegen2(opt_net, opt):
if __name__ == '__main__':
codegen = AutoregressiveCodegen(512, 20)
torch.save(codegen.state_dict(), 'sample.pth')
codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
codegen(torch.randint(0,256, (2,200)),
torch.randn(2,80,120),
torch.randint(0,8192, (2,350)),
torch.tensor([192,350]))
codegen.generate(torch.randn((1, 80, 120)),
torch.randint(0, 256, (1, 200)))
codegen(torch.randint(0, 256, (2, 200)),
torch.randn(2, 80, 120),
torch.randint(0, 8192, (2, 350)),
torch.tensor([192, 350]))

@ -3,12 +3,12 @@ from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml
from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
from models.lucidrains.x_transformers import Encoder
from trainer.networks import register_model
from utils.util import opt_get
import dlas.torch_intermediary as ml
from dlas.models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
from dlas.models.lucidrains.x_transformers import Encoder
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
class CheckpointedXTransformerEncoder(nn.Module):
@ -16,6 +16,7 @@ class CheckpointedXTransformerEncoder(nn.Module):
Wraps a ContinuousTransformerWrapper, applies CheckpointedLayer to each layer, and permutes from channels-mid
to the channels-last layout that XTransformer expects.
"""
def __init__(self, **xtransformer_kwargs):
super().__init__()
self.transformer = XTransformer(**xtransformer_kwargs)
@ -23,7 +24,8 @@ class CheckpointedXTransformerEncoder(nn.Module):
for xform in [self.transformer.encoder, self.transformer.decoder.net]:
for i in range(len(xform.attn_layers.layers)):
n, b, r = xform.attn_layers.layers[i]
xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
xform.attn_layers.layers[i] = nn.ModuleList(
[n, CheckpointedLayer(b), r])
def forward(self, *args, **kwargs):
return self.transformer(*args, **kwargs)
@ -45,20 +47,21 @@ class CtcCodeGenerator(nn.Module):
self.recursive_embedding = ml.Embedding(pred_codes, model_dim)
self.mask_embedding = nn.Parameter(torch.randn(model_dim))
self.encoder = Encoder(
dim=model_dim,
depth=layers,
heads=model_dim//64,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
dim=model_dim,
depth=layers,
heads=model_dim//64,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
self.pred_head = ml.Linear(model_dim, pred_codes)
self.confidence_head = ml.Linear(model_dim, 1)
def inference(self, codes, pads, repeats):
position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))
position_h = self.position_embedding(
torch.arange(0, codes.shape[-1], device=codes.device))
codes_h = self.codes_embedding(codes)
labels = pads + repeats * self.max_pad
@ -83,13 +86,15 @@ class CtcCodeGenerator(nn.Module):
print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
pads = torch.clip(pads, 0, self.max_pad)
if repeats.max() > self.max_repeat:
print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
print(
f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
repeats = torch.clip(repeats, 0, self.max_repeat)
assert codes.max() < self.ctc_codes, codes.max()
labels = pads + repeats * self.max_pad
position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))
position_h = self.position_embedding(
torch.arange(0, codes.shape[-1], device=codes.device))
codes_h = self.codes_embedding(codes)
recursive_h = self.recursive_embedding(labels)
@ -101,12 +106,14 @@ class CtcCodeGenerator(nn.Module):
h = self.encoder(position_h + codes_h + recursive_h)
pred_logits = self.pred_head(h)
loss = F.cross_entropy(pred_logits.permute(0,2,1), labels, reduce=False)
loss = F.cross_entropy(pred_logits.permute(
0, 2, 1), labels, reduce=False)
confidences = self.confidence_head(h).squeeze(-1)
confidences = F.softmax(confidences * mask, dim=-1)
confidence_loss = loss * confidences
loss = loss / loss.shape[-1] # This balances the confidence_loss and loss.
# This balances the confidence_loss and loss.
loss = loss / loss.shape[-1]
return loss.mean(), confidence_loss.mean()
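The two-headed objective above pairs per-position cross entropy with a learned confidence distribution: a softmax over sequence positions turns the confidence head into weights, so confidence_loss is a weighted average of the per-token losses, while the unweighted branch is divided by sequence length to keep the two terms on a similar scale (the "balances" comment). A sketch with random tensors; the padding mask is assumed to be 1 for valid positions, and reduction='none' stands in for the deprecated reduce=False:

import torch
import torch.nn.functional as F

B, T, C = 4, 300, 1000
pred_logits = torch.randn(B, T, C)
labels = torch.randint(0, C, (B, T))
conf_logits = torch.randn(B, T)
mask = torch.ones(B, T)                  # assumed: 1 for valid positions

per_token = F.cross_entropy(pred_logits.permute(0, 2, 1), labels, reduction='none')  # (B, T)
weights = F.softmax(conf_logits * mask, dim=-1)        # sums to 1 over the sequence
confidence_loss = (per_token * weights).mean()
flat_loss = (per_token / per_token.shape[-1]).mean()   # comparable magnitude to the weighted term
print(flat_loss.item(), confidence_loss.item())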
@ -118,8 +125,8 @@ def register_ctc_code_generator(opt_net, opt):
if __name__ == '__main__':
model = CtcCodeGenerator()
inps = torch.randint(0,36, (4, 300))
pads = torch.randint(0,100, (4,300))
repeats = torch.randint(0,20, (4,300))
inps = torch.randint(0, 36, (4, 300))
pads = torch.randint(0, 100, (4, 300))
repeats = torch.randint(0, 20, (4, 300))
loss = model(inps, pads, repeats)
print(loss.shape)
print(loss.shape)

@ -1,16 +1,22 @@
import functools
import math
import random
from functools import partial
import torch
import torch.nn as nn
import torch_intermediary as ml
from x_transformers.x_transformers import (DEFAULT_DIM_HEAD,
AlibiPositionalBias, Attention,
AttentionLayers, FeedForward,
FixedPositionalEmbedding, GRUGating,
LayerIntermediates,
LearnedAlibiPositionalBias,
RelativePositionBias, Residual,
RMSNorm, RotaryEmbedding, Scale,
ScaleNorm, ShiftTokens, cast_tuple,
default, equals, exists,
groupby_prefix_and_trim, not_equals)
from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
AttentionLayers, not_equals
import dlas.torch_intermediary as ml
class TimeIntegrationBlock(nn.Module):
@ -36,40 +42,41 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
"""
Modification of x-transformers.AttentionLayers that integrates timestep embeddings and applies layerdrop.
"""
def __init__(
self,
dim,
timestep_dim,
depth,
heads = 8,
causal = False,
cross_attend = False,
only_cross = False,
use_scalenorm = False,
use_rmsnorm = False,
use_rezero = False,
alibi_pos_bias = False,
alibi_num_heads = None,
alibi_learned = False,
rel_pos_bias = False,
rel_pos_num_buckets = 32,
rel_pos_max_distance = 128,
position_infused_attn = False,
rotary_pos_emb = False,
rotary_emb_dim = None,
custom_layers = None,
sandwich_coef = None,
par_ratio = None,
residual_attn = False,
cross_residual_attn = False,
macaron = False,
gate_residual = False,
scale_residual = False,
shift_tokens = 0,
use_qk_norm_attn = False,
qk_norm_attn_seq_len = None,
zero_init_branch_output = False,
layerdrop_percent = .1,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
alibi_pos_bias=False,
alibi_num_heads=None,
alibi_learned=False,
rel_pos_bias=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
rotary_pos_emb=False,
rotary_emb_dim=None,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
gate_residual=False,
scale_residual=False,
shift_tokens=0,
use_qk_norm_attn=False,
qk_norm_attn_seq_len=None,
zero_init_branch_output=False,
layerdrop_percent=.1,
**kwargs
):
super().__init__(dim, depth)
@ -84,21 +91,26 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
self.layerdrop_percent = layerdrop_percent
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
self.pia_pos_emb = FixedPositionalEmbedding(
dim) if position_infused_attn else None
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
self.rotary_pos_emb = RotaryEmbedding(
rotary_emb_dim) if rotary_pos_emb else None
assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
assert not (
alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
if rel_pos_bias:
self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
self.rel_pos = RelativePositionBias(
scale=dim_head ** 0.5, causal=causal, heads=heads, num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
elif alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads)
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, bidirectional = not causal)
self.rel_pos = alibi_pos_klass(
heads=alibi_num_heads, bidirectional=not causal)
else:
self.rel_pos = None
@ -125,8 +137,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
# qk normalization
if use_qk_norm_attn:
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len **
2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
attn_kwargs = {**attn_kwargs, 'qk_norm': True,
'scale_init_value': attn_scale_init_value}
# zero init
if zero_init_branch_output:
@ -140,16 +154,20 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_attn = par_depth // par_ratio
# 2 / 3 attention layer cutoff suggested by PAR paper
depth_cut = par_depth * 2 // 3
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
assert len(
default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + \
('f',) * (par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
layer_types = ('a',) * sandwich_coef + default_block * \
(depth - sandwich_coef) + ('f',) * sandwich_coef
else:
layer_types = default_block * depth
@ -163,9 +181,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
# iterate and construct layers
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
if layer_type == 'a':
layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
layer = Attention(dim, heads=heads,
causal=causal, **attn_kwargs)
elif layer_type == 'c':
layer = Attention(dim, heads = heads, **attn_kwargs)
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
@ -175,19 +194,22 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
if layer_shift_tokens > 0:
shift_range_upper = layer_shift_tokens + 1
shift_range_lower = -layer_shift_tokens if not causal else 0
layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
layer = ShiftTokens(
range(shift_range_lower, shift_range_upper), layer)
if exists(branch_fn):
layer = branch_fn(layer)
residual_fn = GRUGating if gate_residual else Residual
residual = residual_fn(dim, scale_residual = scale_residual)
residual = residual_fn(dim, scale_residual=scale_residual)
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
pre_branch_norm = TimeIntegrationBlock(timestep_dim, dim, norm_fn())
pre_branch_norm = TimeIntegrationBlock(
timestep_dim, dim, norm_fn())
post_branch_norm = norm_fn() if layer_uses_qk_norm else None
post_main_norm = None # Always do prenorm for timestep integration.
# Always do prenorm for timestep integration.
post_main_norm = None
norms = nn.ModuleList([
pre_branch_norm,
@ -204,15 +226,16 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
def forward(
self,
x,
time_emb = None,
context = None,
mask = None,
context_mask = None,
attn_mask = None,
mems = None,
return_hiddens = False
time_emb=None,
context=None,
mask=None,
context_mask=None,
attn_mask=None,
mems=None,
return_hiddens=False
):
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
assert not (self.cross_attend ^ exists(
context)), 'context must be passed in if cross_attend is set to True'
assert time_emb is not None, 'must specify a timestep embedding.'
hiddens = []
@ -224,8 +247,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
rotary_pos_emb = None
if exists(self.rotary_pos_emb):
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
max_rotary_emb_length = max(
list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
rotary_pos_emb = self.rotary_pos_emb(
max_rotary_emb_length, x.device)
unused_params = []
to_drop = 0
@ -253,9 +278,11 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
x = pre_branch_norm(x, time_emb)
if layer_type == 'a':
out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
out, inter = block(x, mask=mask, attn_mask=attn_mask, sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos, rotary_pos_emb=rotary_pos_emb, prev_attn=prev_attn, mem=layer_mem)
elif layer_type == 'c':
out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
out, inter = block(
x, context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
elif layer_type == 'f':
out = block(x)
@ -283,10 +310,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
if return_hiddens:
intermediates = LayerIntermediates(
hiddens = hiddens,
attn_intermediates = intermediates
hiddens=hiddens,
attn_intermediates=intermediates
)
return x, intermediates
return x
return x
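TimeIntegrationBlock is installed here as the pre-branch norm so that every attention and feed-forward layer sees the diffusion timestep; its definition sits outside this hunk. Assuming it follows the scale-shift (FiLM-style) pattern that use_scale_shift_norm enables elsewhere in these UNets, a hypothetical stand-in looks like:

import torch
import torch.nn as nn

class ScaleShiftTimeNorm(nn.Module):
    # Hypothetical stand-in: normalize, then modulate with a scale and shift
    # projected from the timestep embedding.
    def __init__(self, time_dim, dim, norm):
        super().__init__()
        self.norm = norm
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, dim * 2))

    def forward(self, x, time_emb):       # x: (B, T, C), time_emb: (B, time_dim)
        scale, shift = self.proj(time_emb).unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift

block = ScaleShiftTimeNorm(256, 512, nn.LayerNorm(512))
out = block(torch.randn(2, 100, 512), torch.randn(2, 256))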

@ -1,17 +1,15 @@
import functools
import math
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from vector_quantize_pytorch import VectorQuantize
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import opt_get
from dlas.models.vqvae.vqvae import Quantize
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
def default(val, d):
@ -32,9 +30,9 @@ class ResBlock(nn.Module):
def __init__(self, chan, conv, activation):
super().__init__()
self.net = nn.Sequential(
conv(chan, chan, 3, padding = 1),
conv(chan, chan, 3, padding=1),
activation(),
conv(chan, chan, 3, padding = 1),
conv(chan, chan, 3, padding=1),
activation(),
conv(chan, chan, 1)
)
@ -52,7 +50,8 @@ class UpsampledConv(nn.Module):
self.conv = conv(*args, **kwargs)
def forward(self, x):
up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
up = nn.functional.interpolate(
x, scale_factor=self.stride, mode='nearest')
return self.conv(up)
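UpsampledConv pairs nearest-neighbour interpolation with an ordinary convolution, a common substitute for transposed convolutions that sidesteps checkerboard artifacts. A standalone 1d version with a usage example (names are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class UpsampleConv1d(nn.Module):
    # Nearest-neighbour upsample by `stride`, then convolve at the higher resolution.
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=2):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        up = F.interpolate(x, scale_factor=self.stride, mode='nearest')
        return self.conv(up)

up = UpsampleConv1d(64, 32)
print(up(torch.randn(2, 64, 100)).shape)    # torch.Size([2, 32, 200])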
@ -60,23 +59,23 @@ class DiscreteVAE(nn.Module):
def __init__(
self,
positional_dims=2,
num_tokens = 512,
codebook_dim = 512,
num_layers = 3,
num_resnet_blocks = 0,
hidden_dim = 64,
channels = 3,
stride = 2,
kernel_size = 4,
use_transposed_convs = True,
encoder_norm = False,
activation = 'relu',
smooth_l1_loss = False,
straight_through = False,
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
record_codes = False,
use_lr_quantizer = False,
lr_quantizer_args = {},
num_tokens=512,
codebook_dim=512,
num_layers=3,
num_resnet_blocks=0,
hidden_dim=64,
channels=3,
stride=2,
kernel_size=4,
use_transposed_convs=True,
encoder_norm=False,
activation='relu',
smooth_l1_loss=False,
straight_through=False,
normalization=None, # ((0.5,) * 3, (0.5,) * 3),
record_codes=False,
use_lr_quantizer=False,
lr_quantizer_args={},
):
super().__init__()
has_resblocks = num_resnet_blocks > 0
@ -86,7 +85,8 @@ class DiscreteVAE(nn.Module):
self.straight_through = straight_through
self.positional_dims = positional_dims
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
# This VAE only supports 1d and 2d inputs for now.
assert positional_dims > 0 and positional_dims < 3
if positional_dims == 2:
conv = nn.Conv2d
conv_transpose = nn.ConvTranspose2d
@ -103,7 +103,6 @@ class DiscreteVAE(nn.Module):
else:
assert NotImplementedError()
enc_layers = []
dec_layers = []
@ -116,18 +115,22 @@ class DiscreteVAE(nn.Module):
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
dec_chans = [dec_init_chan, *dec_chans]
enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
enc_chans_io, dec_chans_io = map(lambda t: list(
zip(t[:-1], t[1:])), (enc_chans, dec_chans))
pad = (kernel_size - 1) // 2
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
enc_layers.append(nn.Sequential(
conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act()))
if encoder_norm:
enc_layers.append(nn.GroupNorm(8, enc_out))
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
dec_layers.append(nn.Sequential(conv_transpose(
dec_in, dec_out, kernel_size, stride=stride, padding=pad), act()))
dec_out_chans = dec_chans[-1]
innermost_dim = dec_chans[0]
else:
enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
enc_layers.append(nn.Sequential(
conv(channels, hidden_dim, 1), act()))
dec_out_chans = hidden_dim
innermost_dim = hidden_dim
@ -138,7 +141,6 @@ class DiscreteVAE(nn.Module):
if num_resnet_blocks > 0:
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
dec_layers.append(conv(dec_out_chans, channels, 1))
@ -148,9 +150,11 @@ class DiscreteVAE(nn.Module):
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
if use_lr_quantizer:
self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
self.codebook = VectorQuantize(
dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args)
else:
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
self.codebook = Quantize(
codebook_dim, num_tokens, new_return_order=True)
# take care of normalization within class
self.normalization = normalization
@ -165,7 +169,8 @@ class DiscreteVAE(nn.Module):
if not self.normalization is not None:
return images
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
means, stds = map(lambda t: torch.as_tensor(
t).to(images), self.normalization)
arrange = 'c -> () c () ()' if self.positional_dims == 2 else 'c -> () c ()'
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
images = images.clone()
@ -183,7 +188,8 @@ class DiscreteVAE(nn.Module):
@eval_decorator
def get_codebook_indices(self, images):
img = self.norm(images)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
logits = self.encoder(img).permute(
(0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, _ = self.codebook(logits)
self.log_codes(codes)
return codes
@ -214,7 +220,8 @@ class DiscreteVAE(nn.Module):
def infer(self, img):
img = self.norm(img)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
logits = self.encoder(img).permute(
(0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, commitment_loss = self.codebook(logits)
return self.decode(codes)
@ -226,9 +233,11 @@ class DiscreteVAE(nn.Module):
img
):
img = self.norm(img)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
logits = self.encoder(img).permute(
(0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, commitment_loss = self.codebook(logits)
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
sampled = sampled.permute(
(0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1))
if self.training:
out = sampled
@ -249,7 +258,8 @@ class DiscreteVAE(nn.Module):
if self.record_codes and self.internal_step % 10 == 0:
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
i = self.code_ind if (
self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
@ -264,14 +274,14 @@ def register_lucidrains_dvae(opt_net, opt):
if __name__ == '__main__':
#v = DiscreteVAE()
#o=v(torch.randn(1,3,256,256))
#print(o.shape)
# v = DiscreteVAE()
# o=v(torch.randn(1,3,256,256))
# print(o.shape)
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False,
use_lr_quantizer=True)
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
#v.eval()
r,l,o=v(torch.randn(1,80,256))
v.decode(torch.randint(0,8192,(1,256)))
# v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
# v.eval()
r, l, o = v(torch.randn(1, 80, 256))
v.decode(torch.randint(0, 8192, (1, 256)))
print(o.shape, l.shape)
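
A note on the codebook lookup that `Quantize`/`VectorQuantize` perform in the DiscreteVAE above: at its core it is a nearest-neighbour search over `num_tokens` codebook vectors. The following standalone sketch (plain PyTorch, not this repo's classes) illustrates that idea with the 8192-token, 2048-dim configuration exercised in the `__main__` block:

import torch

def nearest_codebook_indices(latents, codebook):
    # latents: (b, n, d) encoder outputs; codebook: (num_tokens, d).
    # Squared L2 distance from every latent vector to every codebook entry.
    dists = (latents.pow(2).sum(-1, keepdim=True)
             - 2 * latents @ codebook.t()
             + codebook.pow(2).sum(-1))
    return dists.argmin(dim=-1)          # (b, n) integer codes

codebook = torch.randn(8192, 2048)       # num_tokens=8192, codebook_dim=2048
latents = torch.randn(1, 256, 2048)      # e.g. a 256-step 1d sequence
codes = nearest_codebook_indices(latents, codebook)
print(codes.shape)                       # torch.Size([1, 256])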

@ -1,14 +1,14 @@
import torch
import torch.nn as nn
import torch_intermediary as ml
from models.diffusion.nn import normalization, conv_nd, zero_module
from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy, Upsample
import dlas.torch_intermediary as ml
from dlas.models.diffusion.nn import conv_nd, normalization, zero_module
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
QKVAttention,
QKVAttentionLegacy, Upsample)
# Combined resnet & full-attention encoder for converting an audio clip into an embedding.
from trainer.networks import register_model
from utils.util import checkpoint, opt_get, sequential_checkpoint
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint, opt_get, sequential_checkpoint
class ResBlock(nn.Module):
@ -37,7 +37,8 @@ class ResBlock(nn.Module):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
conv_nd(dims, channels, self.out_channels,
kernel_size, padding=padding),
)
self.updown = up or down
@ -56,7 +57,8 @@ class ResBlock(nn.Module):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
conv_nd(dims, self.out_channels, self.out_channels,
kernel_size, padding=padding)
),
)
@ -67,7 +69,8 @@ class ResBlock(nn.Module):
dims, channels, self.out_channels, kernel_size, padding=padding
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1)
def forward(self, x):
if self.do_checkpoint:
@ -111,8 +114,10 @@ class AudioMiniEncoder(nn.Module):
self.layers = depth
for l in range(depth):
for r in range(resnet_blocks):
res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size))
res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor))
res.append(ResBlock(ch, dropout, dims=1,
do_checkpoint=False, kernel_size=kernel_size))
res.append(Downsample(ch, use_conv=True, dims=1,
out_channels=ch*2, factor=downsample_factor))
ch *= 2
self.res = nn.Sequential(*res)
self.final = nn.Sequential(
@ -122,7 +127,8 @@ class AudioMiniEncoder(nn.Module):
)
attn = []
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
attn.append(AttentionBlock(embedding_dim,
num_attn_heads, do_checkpoint=False))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
@ -150,10 +156,12 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
return logits
else:
if self.distribute_zero_label:
oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
oh_labels = nn.functional.one_hot(
labels, num_classes=self.num_classes)
zeros_indices = (labels == 0).unsqueeze(-1)
# Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
zero_extra_mass = torch.full_like(
oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
zero_extra_mass[:, 0] = -.2
zero_extra_mass = zero_extra_mass * zeros_indices
oh_labels = oh_labels + zero_extra_mass
@ -167,6 +175,7 @@ class QueryProvidedAttentionBlock(nn.Module):
"""
An attention block that provides a separate signal for the query vs the keys/parameters.
"""
def __init__(
self,
channels,
@ -200,12 +209,13 @@ class QueryProvidedAttentionBlock(nn.Module):
return checkpoint(self._forward, qx, kvx, mask)
def _forward(self, qx, kvx, mask=None):
q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(1, kvx.shape[1], 1).permute(0,2,1)
kv = self.kv(self.norm(kvx.permute(0,2,1)))
q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(
1, kvx.shape[1], 1).permute(0, 2, 1)
kv = self.kv(self.norm(kvx.permute(0, 2, 1)))
qkv = torch.cat([q, kv], dim=1)
h = self.attention(qkv, mask)
h = self.proj_out(h)
return kvx + h.permute(0,2,1)
return kvx + h.permute(0, 2, 1)
# Next up: combine multiple embeddings given a conditioning signal into a single embedding.
@ -213,7 +223,8 @@ class EmbeddingCombiner(nn.Module):
def __init__(self, embedding_dim, attn_blocks=3, num_attn_heads=2, cond_provided=True):
super().__init__()
block = QueryProvidedAttentionBlock if cond_provided else AttentionBlock
self.attn = nn.ModuleList([block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
self.attn = nn.ModuleList(
[block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
self.cond_provided = cond_provided
# x_s: (b,n,d); b=batch_sz, n=number of embeddings, d=embedding_dim
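
The 20%-mass redistribution in `AudioMiniEncoderWithClassifierHead` above is easiest to see with small numbers. A toy reproduction of that label-smoothing step (num_classes=5 chosen purely for illustration):

import torch
import torch.nn.functional as F

num_classes = 5
labels = torch.tensor([0, 3])                        # class 0 is the noisy "unspecified" label
oh = F.one_hot(labels, num_classes).float()
zeros_idx = (labels == 0).unsqueeze(-1)
extra = torch.full_like(oh, 0.2 / (num_classes - 1))
extra[:, 0] = -0.2
oh = oh + extra * zeros_idx                          # only rows labelled 0 are smoothed
print(oh)
# tensor([[0.8000, 0.0500, 0.0500, 0.0500, 0.0500],
#         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000]])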

@ -3,10 +3,10 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml
from trainer.networks import register_model
from utils.util import opt_get
import dlas.torch_intermediary as ml
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
@ -61,4 +61,4 @@ def register_random_latent_converter(opt_net, opt):
if __name__ == '__main__':
model = RandomLatentConverter(512)
model(torch.randn(5,512))
model(torch.randn(5, 512))

@ -1,6 +1,6 @@
from models.audio.tts.tacotron2.taco_utils import *
from models.audio.tts.tacotron2.text import *
from models.audio.tts.tacotron2.tacotron2 import *
from models.audio.tts.tacotron2.stft import *
from models.audio.tts.tacotron2.layers import *
from models.audio.tts.tacotron2.loss import *
from dlas.models.audio.tts.tacotron2.layers import *
from dlas.models.audio.tts.tacotron2.loss import *
from dlas.models.audio.tts.tacotron2.stft import *
from dlas.models.audio.tts.tacotron2.taco_utils import *
from dlas.models.audio.tts.tacotron2.tacotron2 import *
from dlas.models.audio.tts.tacotron2.text import *

@ -1,7 +1,7 @@
import torch
import librosa.util as librosa_util
import numpy as np
import torch
from scipy.signal import get_window
import librosa.util as librosa_util
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
@ -52,7 +52,8 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
x[sample:min(n, sample + n_fft)
] += win_sq[:max(0, min(n_fft, n - sample))]
return x
@ -90,4 +91,4 @@ def dynamic_range_decompression(x, C=1):
------
C: compression factor used to compress
"""
return torch.exp(x) / C
return torch.exp(x) / C
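
`dynamic_range_decompression` above inverts the log-compression applied when building MEL spectrograms; the companion `dynamic_range_compression` in the same file is not shown in this hunk and is assumed here to be `log(clamp(x) * C)`. A quick round-trip check under that assumption:

import torch

C = 1.0

def compress(x, clip_val=1e-5):          # assumed form of dynamic_range_compression
    return torch.log(torch.clamp(x, min=clip_val) * C)

def decompress(x):                       # matches the torch.exp(x) / C above
    return torch.exp(x) / C

mags = torch.rand(4) + 0.1               # magnitudes comfortably above clip_val
print(torch.allclose(mags, decompress(compress(mags))))   # True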

@ -1,5 +1,5 @@
#import tensorflow as tf
from models.audio.tts.tacotron2.text import symbols
# import tensorflow as tf
from dlas.models.audio.tts.tacotron2.text import symbols
def create_hparams(hparams_string=None, verbose=False):
@ -33,7 +33,8 @@ def create_hparams(hparams_string=None, verbose=False):
# Audio Parameters #
################################
max_wav_value=32768.0,
input_sample_rate=22050, # When different from sampling_rate, dataset automatically interpolates to sampling_rate
# When different from sampling_rate, dataset automatically interpolates to sampling_rate
input_sample_rate=22050,
sampling_rate=22050,
filter_length=1024,
hop_length=256, # This means a MEL is 1/256th the equivalent audio.
@ -86,4 +87,4 @@ def create_hparams(hparams_string=None, verbose=False):
mask_padding=True # set model's padded outputs to padded values
)
return hparams
return hparams
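
The `hop_length` comment above translates directly into frame arithmetic: with `sampling_rate=22050` and `hop_length=256`, one second of audio maps to roughly 86 MEL frames. A quick sanity check (the exact count depends on STFT padding):

sampling_rate = 22050
hop_length = 256

seconds = 1.0
num_samples = int(seconds * sampling_rate)
approx_frames = 1 + num_samples // hop_length   # centered STFT, one frame per hop
print(approx_frames)                            # 87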

@ -1,9 +1,10 @@
import torch
from librosa.filters import mel as librosa_mel_fn
from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
from models.audio.tts.tacotron2.stft import STFT
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.models.audio.tts.tacotron2.audio_processing import (
dynamic_range_compression, dynamic_range_decompression)
from dlas.models.audio.tts.tacotron2.stft import STFT
class LinearNorm(torch.nn.Module):
@ -24,7 +25,7 @@ class ConvNorm(torch.nn.Module):
padding=None, dilation=1, bias=True, w_init_gain='linear'):
super(ConvNorm, self).__init__()
if padding is None:
assert(kernel_size % 2 == 1)
assert (kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(in_channels, out_channels,
@ -71,8 +72,8 @@ class TacotronSTFT(torch.nn.Module):
-------
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
"""
assert(torch.min(y.data) >= -10)
assert(torch.max(y.data) <= 10)
assert (torch.min(y.data) >= -10)
assert (torch.max(y.data) <= 10)
y = torch.clip(y, min=-1, max=1)
magnitudes, phases = self.stft_fn.transform(y)
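
The padding computed in `ConvNorm` above, `dilation * (kernel_size - 1) / 2` with odd kernels only, is the usual 'same' padding for stride-1 convolutions, so the time dimension is preserved. For example:

import torch
import torch.nn as nn

kernel_size, dilation = 5, 1
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)

conv = nn.Conv1d(80, 80, kernel_size, padding=padding, dilation=dilation)
x = torch.randn(1, 80, 123)
print(conv(x).shape)    # torch.Size([1, 80, 123]) -- length unchanged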

@ -1,6 +1,6 @@
from torch import nn
from trainer.losses import ConfigurableLoss
from dlas.trainer.losses import ConfigurableLoss
class Tacotron2Loss(ConfigurableLoss):
@ -20,7 +20,8 @@ class Tacotron2Loss(ConfigurableLoss):
gate_target.requires_grad = False
gate_target = gate_target.view(-1, 1)
mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[self.mel_output_postnet_key], state[self.gate_output_key]
mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[
self.mel_output_postnet_key], state[self.gate_output_key]
gate_out = gate_out.view(-1, 1)
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
nn.MSELoss()(mel_out_postnet, mel_target)
@ -61,4 +62,4 @@ class Tacotron2LossRaw(nn.Module):
return {
'mel_loss': self.last_mel_loss,
'gate_loss': self.last_gate_loss
}
}
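
For reference, the Tacotron2 loss assembled above is two MSE terms on the pre- and post-postnet MELs plus a BCE term on the stop-gate logits. A minimal standalone sketch with random tensors (not the `ConfigurableLoss` wiring used by the trainer):

import torch
import torch.nn as nn

mel_target = torch.randn(2, 80, 100)
mel_out = torch.randn_like(mel_target)
mel_out_postnet = torch.randn_like(mel_target)
gate_target = torch.zeros(2, 100).view(-1, 1)   # 1.0 would mark the stop frame
gate_out = torch.randn(2, 100).view(-1, 1)      # raw logits

mel_loss = nn.MSELoss()(mel_out, mel_target) + nn.MSELoss()(mel_out_postnet, mel_target)
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
print((mel_loss + gate_loss).item())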

@ -30,17 +30,19 @@ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny
from models.audio.tts.tacotron2.audio_processing import window_sumsquare
from scipy.signal import get_window
from torch.autograd import Variable
from dlas.models.audio.tts.tacotron2.audio_processing import window_sumsquare
class STFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(self, filter_length=800, hop_length=200, win_length=800,
window='hann'):
super(STFT, self).__init__()
@ -61,7 +63,7 @@ class STFT(torch.nn.Module):
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window is not None:
assert(filter_length >= win_length)
assert (filter_length >= win_length)
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, filter_length)
@ -125,17 +127,19 @@ class STFT(torch.nn.Module):
window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False)
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
inverse_transform[:, :,
approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
# scale by hop ratio
inverse_transform *= float(self.filter_length) / self.hop_length
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
inverse_transform = inverse_transform[:,
:, :-int(self.filter_length/2):]
return inverse_transform
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction
return reconstruction
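
The division by `window_sum[approx_nonzero_indices]` in `STFT.inverse` above undoes the overlap-add envelope of the squared analysis window. A hypothetical standalone illustration of that envelope (it skips the window normalization `window_sumsquare` also applies):

import numpy as np
from librosa.util import tiny
from scipy.signal import get_window

n_frames, hop_length, win_length = 40, 200, 800
win_sq = get_window('hann', win_length, fftbins=True) ** 2
envelope = np.zeros(n_frames * hop_length + win_length)
for i in range(n_frames):
    start = i * hop_length
    envelope[start:start + win_length] += win_sq

safe = envelope > tiny(envelope)        # only these samples get divided by the envelope
print(safe.mean())                      # fraction of samples with a usable correction factor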

@ -8,7 +8,8 @@ from scipy.io.wavfile import read
def get_mask_from_lengths(lengths, max_len=None):
if max_len is None:
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device)
ids = torch.arange(0, max_len, out=torch.LongTensor(
max_len)).to(lengths.device)
mask = (ids < lengths.unsqueeze(1)).bool()
return mask
@ -22,24 +23,29 @@ def load_wav_to_torch(full_path):
elif data.dtype == np.float16 or data.dtype == np.float32:
norm_fix = 1.
else:
raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
raise NotImplementedError(
f"Provided data dtype not supported: {data.dtype}")
return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
def load_filepaths_and_text_type(filename, type, split="|"):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = [list(line.strip().split(split)) + [type] for line in f]
filepaths_and_text = [
list(line.strip().split(split)) + [type] for line in f]
base = os.path.dirname(filename)
for j in range(len(filepaths_and_text)):
filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
filepaths_and_text[j][0] = os.path.join(
base, filepaths_and_text[j][0])
return filepaths_and_text
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = [line.strip().split(split) for line in f]
base = os.path.dirname(filename)
for j in range(len(filepaths_and_text)):
filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])
filepaths_and_text[j][0] = os.path.join(
base, filepaths_and_text[j][0])
return filepaths_and_text
@ -48,4 +54,4 @@ def to_gpu(x):
if torch.cuda.is_available():
x = x.cuda(non_blocking=True)
return torch.autograd.Variable(x)
return torch.autograd.Variable(x)
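
`get_mask_from_lengths` above builds the padding mask used throughout the Tacotron2 code. With toy lengths it looks like this:

import torch

lengths = torch.tensor([3, 5])
max_len = int(torch.max(lengths).item())
ids = torch.arange(0, max_len)
mask = (ids < lengths.unsqueeze(1)).bool()
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])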

@ -1,14 +1,16 @@
from math import sqrt
import torch
from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
from models.audio.tts.tacotron2.hparams import create_hparams
from trainer.networks import register_model
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.models.audio.tts.tacotron2.hparams import create_hparams
from dlas.models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
from dlas.models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
from dlas.trainer.networks import register_model
class LocationLayer(nn.Module):
@ -59,7 +61,8 @@ class Attention(nn.Module):
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
processed_attention_weights = self.location_layer(
attention_weights_cat)
energies = self.v(torch.tanh(
processed_query + processed_attention_weights + processed_memory))
@ -128,7 +131,8 @@ class Postnet(nn.Module):
ConvNorm(hparams.postnet_embedding_dim,
hparams.postnet_embedding_dim,
kernel_size=hparams.postnet_kernel_size, stride=1,
padding=int((hparams.postnet_kernel_size - 1) / 2),
padding=int(
(hparams.postnet_kernel_size - 1) / 2),
dilation=1, w_init_gain='tanh'),
nn.BatchNorm1d(hparams.postnet_embedding_dim))
)
@ -140,11 +144,12 @@ class Postnet(nn.Module):
padding=int((hparams.postnet_kernel_size - 1) / 2),
dilation=1, w_init_gain='linear'),
nn.BatchNorm1d(hparams.n_mel_channels))
)
)
def forward(self, x):
for i in range(len(self.convolutions) - 1):
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
x = F.dropout(torch.tanh(
self.convolutions[i](x)), 0.5, self.training)
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
return x
@ -155,6 +160,7 @@ class Encoder(nn.Module):
- Three 1-d convolution banks
- Bidirectional LSTM
"""
def __init__(self, hparams):
super(Encoder, self).__init__()
@ -408,7 +414,8 @@ class Decoder(nn.Module):
mel_outputs, gate_outputs, alignments = [], [], []
while len(mel_outputs) < decoder_inputs.size(0) - 1:
decoder_input = decoder_inputs[len(mel_outputs)]
mel_output, gate_output, attention_weights = self.decode(decoder_input)
mel_output, gate_output, attention_weights = self.decode(
decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze(1)]
alignments += [attention_weights]
@ -529,9 +536,9 @@ def register_nv_tacotron2(opt_net, opt):
if __name__ == '__main__':
tron = register_nv_tacotron2({}, {})
inputs = torch.randint(high=24, size=(1,12)), \
torch.tensor([12]), \
torch.randn((1,80,749)), \
torch.tensor([749])
inputs = torch.randint(high=24, size=(1, 12)), \
torch.tensor([12]), \
torch.randn((1, 80, 749)), \
torch.tensor([749])
out = tron(*inputs)
print(out)
print(out)
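
The energies computed in `Attention` above are the standard location-sensitive attention score: a learned vector `v` applied to `tanh(query + location + memory)` projections. A toy-shape sketch of just that arithmetic:

import torch
import torch.nn as nn

B, T, D = 2, 7, 16                                  # batch, encoder steps, attention dim
processed_query = torch.randn(B, 1, D)              # query_layer(query).unsqueeze(1)
processed_memory = torch.randn(B, T, D)             # memory_layer(encoder_outputs)
processed_attention = torch.randn(B, T, D)          # location_layer(attention_weights_cat)

v = nn.Linear(D, 1, bias=False)
energies = v(torch.tanh(processed_query + processed_attention + processed_memory)).squeeze(-1)
weights = torch.softmax(energies, dim=1)            # one attention weight per encoder step
print(weights.shape)                                # torch.Size([2, 7])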

@ -3,9 +3,8 @@ import re
import torch
from models.audio.tts.tacotron2.text import cleaners
from models.audio.tts.tacotron2.text.symbols import symbols
from dlas.models.audio.tts.tacotron2.text import cleaners
from dlas.models.audio.tts.tacotron2.text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
@ -16,72 +15,73 @@ _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names=['english_cleaners']):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(
_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
return sequence
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if isinstance(symbol_id, torch.Tensor):
symbol_id = symbol_id.item()
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if isinstance(symbol_id, torch.Tensor):
symbol_id = symbol_id.item()
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
def tacotron_symbols():
return list(_symbol_to_id.keys())
return list(_symbol_to_id.keys())
def tacotron_symbol_mapping():
return _symbol_to_id.copy()
return _symbol_to_id.copy()
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def _symbols_to_sequence(symbols):
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
def _arpabet_to_sequence(text):
return _symbols_to_sequence(['@' + s for s in text.split()])
return _symbols_to_sequence(['@' + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s != '_' and s != '~'
return s in _symbol_to_id and s != '_' and s != '~'
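
Usage of the re-indented `text_to_sequence`/`sequence_to_text` pair above, assuming the package is importable under the new `dlas` namespace these imports point at; the exact cleaned text depends on the cleaners and installed `inflect` version:

from dlas.models.audio.tts.tacotron2.text import sequence_to_text, text_to_sequence

seq = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.")
print(seq[:8])                # integer symbol IDs
print(sequence_to_text(seq))  # roughly: "turn left on {HH AW1 S S T AH0 N} street."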

@ -12,80 +12,79 @@ hyperparameter. Some cleaners are English-specific. You'll typically want to use
the symbols in symbols.py to match your data).
'''
# Regular expression matching whitespace:
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
return normalize_numbers(text)
def lowercase(text):
return text.lower()
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
return unidecode(text)
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
'''Pipeline for English text, including number and abbreviation expansion.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
text = text.replace('"', '')
return text
'''Pipeline for English text, including number and abbreviation expansion.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
text = text.replace('"', '')
return text
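
The `english_cleaners` pipeline above composes the helpers in order: ASCII transliteration, lowercasing, number expansion, abbreviation expansion, whitespace collapsing, and quote stripping. Roughly (assuming the `dlas` namespace; exact wording depends on the installed `inflect` version):

from dlas.models.audio.tts.tacotron2.text.cleaners import english_cleaners

print(english_cleaners('Mr. Smith bought 2 books for $3.50.'))
# roughly: 'mister smith bought two books for three dollars, fifty cents.'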

@ -2,64 +2,62 @@
import re
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]
_valid_symbol_set = set(valid_symbols)
class CMUDict:
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
self._entries = entries
def __len__(self):
return len(self._entries)
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {word: pron for word,
pron in entries.items() if len(pron) == 1}
self._entries = entries
def lookup(self, word):
'''Returns list of ARPAbet pronunciations of the given word.'''
return self._entries.get(word.upper())
def __len__(self):
return len(self._entries)
def lookup(self, word):
'''Returns list of ARPAbet pronunciations of the given word.'''
return self._entries.get(word.upper())
_alt_re = re.compile(r'\([0-9]+\)')
def _parse_cmudict(file):
cmudict = {}
for line in file:
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
cmudict[word].append(pronunciation)
else:
cmudict[word] = [pronunciation]
return cmudict
cmudict = {}
for line in file:
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
cmudict[word].append(pronunciation)
else:
cmudict[word] = [pronunciation]
return cmudict
def _get_pronunciation(s):
parts = s.strip().split(' ')
for part in parts:
if part not in _valid_symbol_set:
return None
return ' '.join(parts)
parts = s.strip().split(' ')
for part in parts:
if part not in _valid_symbol_set:
return None
return ' '.join(parts)

@ -1,8 +1,8 @@
""" from https://github.com/keithito/tacotron """
import inflect
import re
import inflect
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
@ -14,58 +14,58 @@ _number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
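
`normalize_numbers` above applies its substitutions in a fixed order (commas, pounds, dollars, decimals, ordinals, plain numbers), which is why '$1,500' becomes a single dollar amount rather than two separate numbers. An illustrative run (module path assumed from the relative import in cleaners.py; word forms come from `inflect`):

from dlas.models.audio.tts.tacotron2.text.numbers import normalize_numbers

print(normalize_numbers('The 2nd car cost $1,500 in 1995.'))
# roughly: 'The second car cost one thousand, five hundred dollars in nineteen ninety-five.'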

@ -4,8 +4,9 @@
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
from models.audio.tts.tacotron2.text import cmudict
from dlas.models.audio.tts.tacotron2.text import cmudict
_pad = '_'
_punctuation = '!\'(),.:;? '
_special = '-'
@ -15,4 +16,5 @@ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_arpabet = ['@' + s for s in cmudict.valid_symbols]
# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
symbols = [_pad] + list(_special) + list(_punctuation) + \
list(_letters) + _arpabet

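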
@ -1,20 +1,20 @@
from math import sqrt
import torch
from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
from models.arch_util import ConvGnSilu
from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
from models.audio.tts.tacotron2.layers import LinearNorm
from models.audio.tts.tacotron2.hparams import create_hparams
from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
from trainer.networks import register_model
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import checkpoint
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.models.arch_util import ConvGnSilu
from dlas.models.audio.tts.tacotron2.hparams import create_hparams
from dlas.models.audio.tts.tacotron2.layers import LinearNorm
from dlas.models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
from dlas.models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
from dlas.models.diffusion.unet_diffusion import AttentionPool2d, UNetModel
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
class WavDecoder(nn.Module):
@ -24,26 +24,33 @@ class WavDecoder(nn.Module):
self.K = int(sample_rate * (K_ms/1000))
self.clarifier = UNetModel(image_size=self.K,
in_channels=1,
model_channels=dec_channels // 4, # This is required so the embedding produced by the decoder can be loaded into the unet model.
# This is required so the embedding produced by the decoder can be loaded into the unet model.
model_channels=dec_channels // 4,
out_channels=2, # 2 channels: eps_pred and variance_pred
num_res_blocks=2,
attention_resolutions=(8,),
dims=1,
dropout=.1,
channel_mult=(1,2,4,8),
channel_mult=(1, 2, 4, 8),
use_raw_y_as_embedding=True)
assert self.K % 64 == 0 # Otherwise the UNetModel breaks.
self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d),
ConvGnSilu(32,64,kernel_size=5,stride=4,convnd=nn.Conv1d),
ConvGnSilu(64,128,kernel_size=5,stride=4,convnd=nn.Conv1d),
ConvGnSilu(128,256,kernel_size=5,stride=4,convnd=nn.Conv1d),
ConvGnSilu(256,dec_channels,kernel_size=1,convnd=nn.Conv1d),
AttentionPool2d(self.K//64,dec_channels,dec_channels//4))
self.pre_rnn = nn.Sequential(ConvGnSilu(1, 32, kernel_size=5, convnd=nn.Conv1d),
ConvGnSilu(32, 64, kernel_size=5,
stride=4, convnd=nn.Conv1d),
ConvGnSilu(64, 128, kernel_size=5,
stride=4, convnd=nn.Conv1d),
ConvGnSilu(128, 256, kernel_size=5,
stride=4, convnd=nn.Conv1d),
ConvGnSilu(256, dec_channels,
kernel_size=1, convnd=nn.Conv1d),
AttentionPool2d(self.K//64, dec_channels, dec_channels//4))
self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
self.attention_layer = Attention(
dec_channels, dec_channels, dec_channels)
self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
self.linear_projection = LinearNorm(dec_channels*2, self.dec_channels)
self.gate_layer = LinearNorm(self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid')
self.gate_layer = LinearNorm(
self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid')
self.dropout_probability = dropout_probability
def chunk_wav(self, wav):
@ -51,16 +58,18 @@ class WavDecoder(nn.Module):
# Pad the last chunk as needed.
padding_needed = self.K - wavs[-1].shape[-1]
if padding_needed > 0:
wavs[-1] = F.pad(wavs[-1], (0,padding_needed))
wavs[-1] = F.pad(wavs[-1], (0, padding_needed))
wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length
# wavs.shape = (b,s,K) where s=decoder sequence length
wavs = torch.stack(wavs, dim=1)
return wavs, padding_needed
def prepare_decoder_inputs(self, inp):
# inp.shape = (b,s,K) chunked waveform.
b,s,K = inp.shape
first_frame = torch.zeros(b,1,K).to(inp.device)
x = torch.cat([first_frame, inp[:,:-1]], dim=1) # It is now aligned for teacher forcing.
b, s, K = inp.shape
first_frame = torch.zeros(b, 1, K).to(inp.device)
# It is now aligned for teacher forcing.
x = torch.cat([first_frame, inp[:, :-1]], dim=1)
return x
def initialize_decoder_states(self, memory, mask):
@ -75,18 +84,25 @@ class WavDecoder(nn.Module):
B = memory.size(0)
MAX_TIME = memory.size(1)
self.attention_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
self.attention_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
self.attention_hidden = Variable(
memory.data.new(B, self.dec_channels).zero_())
self.attention_cell = Variable(
memory.data.new(B, self.dec_channels).zero_())
self.decoder_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
self.decoder_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
self.decoder_hidden = Variable(
memory.data.new(B, self.dec_channels).zero_())
self.decoder_cell = Variable(
memory.data.new(B, self.dec_channels).zero_())
self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
self.attention_context = Variable(memory.data.new(B, self.dec_channels).zero_())
self.attention_weights_cum = Variable(
memory.data.new(B, MAX_TIME).zero_())
self.attention_context = Variable(
memory.data.new(B, self.dec_channels).zero_())
self.memory = memory
self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory)
self.processed_memory = checkpoint(
self.attention_layer.memory_layer, memory)
self.mask = mask
def teardown_states(self):
@ -113,20 +129,28 @@ class WavDecoder(nn.Module):
attention_weights:
"""
cell_input = torch.cat((decoder_input, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(
self.attention_hidden, self.dropout_probability, self.training)
attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(
1), self.attention_weights_cum.unsqueeze(1)), dim=1)
self.attention_context, self.attention_weights = checkpoint(self.attention_layer, self.attention_hidden, self.memory,
self.processed_memory, attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training)
decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context)
decoder_input = torch.cat(
(self.attention_hidden, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(
self.decoder_hidden, self.dropout_probability, self.training)
decoder_hidden_attention_context = torch.cat(
(self.decoder_hidden, self.attention_context), dim=1)
decoder_output = checkpoint(
self.linear_projection, decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return decoder_output, gate_prediction, self.attention_weights
@ -137,11 +161,11 @@ class WavDecoder(nn.Module):
# (T_out, B) -> (B, T_out)
gate_outputs = torch.stack(gate_outputs, dim=1).repeat(1, self.K)
b,s,_,K = diffusion_eps.shape
b, s, _, K = diffusion_eps.shape
# (B, S, 2, K) -> (B, 2, S*K)
diffusion_eps = diffusion_eps.permute(0,2,1,3).reshape(b, 2, s*K)
diffusion_eps = diffusion_eps.permute(0, 2, 1, 3).reshape(b, 2, s*K)
return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]
return diffusion_eps[:, :, :-padding_added], gate_outputs[:, :-padding_added], alignments[:, :-padding_added]
def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths):
'''
@ -155,14 +179,17 @@ class WavDecoder(nn.Module):
wav_noised, padding_added = self.chunk_wav(wav_noised)
wav_real, _ = self.chunk_wav(wav_real)
wav_real = self.prepare_decoder_inputs(wav_real)
b,s,K = wav_real.shape
wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b*s,1,K)).reshape(b,s,self.dec_channels)
b, s, K = wav_real.shape
wav_real = checkpoint(self.pre_rnn, wav_real.reshape(
b*s, 1, K)).reshape(b, s, self.dec_channels)
self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
self.initialize_decoder_states(
text_enc, mask=~get_mask_from_lengths(memory_lengths))
decoder_contexts, gate_outputs, alignments = [], [], []
while len(decoder_contexts) < wav_real.size(1):
decoder_input = wav_real[:, len(decoder_contexts)]
dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
dec_context, gate_output, attention_weights = self.produce_context(
decoder_input)
decoder_contexts += [dec_context.squeeze(1)]
gate_outputs += [gate_output.squeeze(1)]
alignments += [attention_weights]
@ -170,12 +197,14 @@ class WavDecoder(nn.Module):
# diffusion_inputs and wavs need to have the sequence and batch dimensions combined, and need a channel dimension
diffusion_emb = torch.stack(decoder_contexts, dim=1)
b,s,c = diffusion_emb.shape
diffusion_emb = diffusion_emb.reshape(b*s,c)
wav_noised = wav_noised.reshape(b*s,1,self.K)
diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
b, s, c = diffusion_emb.shape
diffusion_emb = diffusion_emb.reshape(b*s, c)
wav_noised = wav_noised.reshape(b*s, 1, self.K)
diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(
s), diffusion_emb).reshape(b, s, 2, self.K)
# Recombine diffusion outputs across the sequence into a single prediction.
diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
diffusion_eps, gate_outputs, alignments = self.recombine(
diffusion_eps, gate_outputs, alignments, padding_added)
return diffusion_eps, gate_outputs, alignments
@ -199,11 +228,11 @@ class WaveTacotron2(nn.Module):
if self.mask_padding and output_lengths is not None:
mask_fill = outputs[0].shape[-1]
mask = ~get_mask_from_lengths(output_lengths, mask_fill)
mask = mask.unsqueeze(1).repeat(1,2,1)
mask = mask.unsqueeze(1).repeat(1, 2, 1)
outputs[0].data.masked_fill_(mask, 0.0)
outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension.
outputs[1].data.masked_fill_(mask[:,0], 1e3) # gate energies
outputs[1].data.masked_fill_(mask[:, 0], 1e3) # gate energies
return outputs
@ -214,7 +243,8 @@ class WaveTacotron2(nn.Module):
text_lengths, output_lengths = text_lengths.data, output_lengths.data
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths)
encoder_outputs = checkpoint(
self.encoder, embedded_inputs, text_lengths)
eps_pred, gate_outputs, alignments = self.decoder(
wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths)
@ -234,7 +264,7 @@ if __name__ == '__main__':
out = tron(wavs_diffused=torch.randn(2, 1, 22000),
wavs_corrected=torch.randn(2, 1, 22000),
timesteps=torch.LongTensor([555, 543]),
text_inputs=torch.randint(high=24, size=(2,12)),
text_inputs=torch.randint(high=24, size=(2, 12)),
text_lengths=torch.tensor([12, 12]),
output_lengths=torch.tensor([21995]))
print([o.shape for o in out])
print([o.shape for o in out])
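
The chunking that `WavDecoder.chunk_wav` performs above is plain fixed-size splitting with right-padding of the final chunk. A hypothetical standalone re-creation of that step (K chosen as 1024 purely for illustration; the real value is `sample_rate * K_ms / 1000` and must be a multiple of 64 per the assert above):

import torch
import torch.nn.functional as F

K = 1024                                  # illustrative chunk size, divisible by 64
wav = torch.randn(2, 22000)               # (b, samples)
chunks = list(torch.split(wav, K, dim=-1))
padding_needed = K - chunks[-1].shape[-1]
if padding_needed > 0:
    chunks[-1] = F.pad(chunks[-1], (0, padding_needed))
wavs = torch.stack(chunks, dim=1)         # (b, s, K) with s = ceil(22000 / 1024) = 22
print(wavs.shape, padding_needed)         # torch.Size([2, 22, 1024]) 528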

@ -23,11 +23,13 @@ Returns:
import functools
import random
from time import time
import torch
import torch.nn as nn
import torch_intermediary as ml
from tqdm import tqdm
import dlas.torch_intermediary as ml
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@ -61,13 +63,13 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
"""
from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(vocab_size=256, # Unused.
n_positions=max_mel_seq_len+max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
n_positions=max_mel_seq_len+max_text_seq_len,
n_ctx=max_mel_seq_len+max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
gpt = GPT2Model(gpt_config)
# Override the built in positional embeddings
del gpt.wpe
@ -75,8 +77,10 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
# Built-in token embeddings are unused.
del gpt.wte
mel_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
mel_pos_emb = LearnedPositionEmbeddings(
max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
text_pos_emb = LearnedPositionEmbeddings(
max_text_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim)
return gpt, mel_pos_emb, text_pos_emb, None, None
@ -85,7 +89,8 @@ def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_l
lucidrains Performer implementation, https://github.com/lucidrains/performer-pytorch
"""
from models.lucidrains.performer.performer_pytorch import Performer
model = Performer(dim=model_dim, depth=layers, heads=heads, dim_head=model_dim, causal=True)
model = Performer(dim=model_dim, depth=layers,
heads=heads, dim_head=model_dim, causal=True)
return model
@ -104,14 +109,14 @@ def build_lr_xformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len
def test_all_performance(**kwargs):
transformer_builders = [#build_hf_gpt_transformer,
build_lr_performer,]
# build_lr_reformer,
# build_lr_xformer]
transformer_builders = [ # build_hf_gpt_transformer,
build_lr_performer,]
# build_lr_reformer,
# build_lr_xformer]
for builder in transformer_builders:
model = builder(**kwargs)
start = time()
args = torch.randint(0, 8192, (16,450))
args = torch.randint(0, 8192, (16, 450))
for k in tqdm(range(10)):
model(args)
stop = time()
@ -119,4 +124,5 @@ def test_all_performance(**kwargs):
if __name__ == '__main__':
test_all_performance(layers=12, model_dim=512, heads=8, num_tokens=8192, max_seq_len=1000, checkpointing=False)
test_all_performance(layers=12, model_dim=512, heads=8,
num_tokens=8192, max_seq_len=1000, checkpointing=False)
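
`null_position_embeddings` above is the stand-in used when a sequence-length cap of -1 disables the learned positional embeddings: it returns zeros of the right shape so the downstream addition is a no-op. A small usage sketch (function body copied for a standalone demo, parameter renamed to avoid shadowing the built-in `range`):

import functools
import torch

def null_position_embeddings(token_range, dim):
    return torch.zeros((token_range.shape[0], token_range.shape[1], dim), device=token_range.device)

pos_emb = functools.partial(null_position_embeddings, dim=512)
tokens = torch.zeros(4, 450, dtype=torch.long)
print(pos_emb(tokens).shape)    # torch.Size([4, 450, 512]) -- a no-op positional signal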

@ -2,17 +2,23 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
from trainer.networks import register_model
from utils.util import checkpoint
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (TimestepBlock,
TimestepEmbedSequential)
from dlas.models.lucidrains.x_transformers import (Attention, Encoder,
FeedForward,
RMSScaleShiftNorm,
RotaryEmbedding)
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
def is_latent(t):
return t.dtype == torch.float
def is_sequence(t):
return t.dtype == torch.long
@ -21,7 +27,8 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
# nn.Embedding
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
self.m = nn.ModuleList(
[ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -41,13 +48,16 @@ class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock):
class AttentionBlock(TimestepBlock):
def __init__(self, dim, heads, dropout):
super().__init__()
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False)
self.ff = FeedForward(dim, mult=1, dropout=dropout, zero_init_output=True)
self.attn = Attention(dim, heads=heads, causal=False,
dropout=dropout, zero_init_output=False)
self.ff = FeedForward(
dim, mult=1, dropout=dropout, zero_init_output=True)
self.rms_scale_norm = RMSScaleShiftNorm(dim)
def forward(self, x, timestep_emb, rotary_emb):
h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb)
h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
h, _, _, _ = checkpoint(self.attn, h, None, None,
None, None, None, rotary_emb)
h = checkpoint(self.ff, h)
return h + x
@ -56,6 +66,7 @@ class TransformerDiffusionTTS(nn.Module):
"""
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
"""
def __init__(
self,
model_channels=512,
@ -71,7 +82,8 @@ class TransformerDiffusionTTS(nn.Module):
dropout=0,
use_fp16=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
):
super().__init__()
@ -91,17 +103,17 @@ class TransformerDiffusionTTS(nn.Module):
linear(model_channels, model_channels),
)
self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels // 2, 3, padding=1, stride=2),
nn.Conv1d(model_channels//2, model_channels,3,padding=1,stride=2))
nn.Conv1d(model_channels//2, model_channels, 3, padding=1, stride=2))
self.conditioning_encoder = Encoder(
dim=model_channels,
depth=4,
heads=heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
dim=model_channels,
depth=4,
heads=heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels)
# nn.Embedding
self.type_embedding = ml.Embedding(types, model_channels)
@ -114,43 +126,48 @@ class TransformerDiffusionTTS(nn.Module):
# nn.Embedding
self.embeddings = ml.Embedding(token_count, model_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.embeddings = MultiGroupEmbedding(
token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential(
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
Encoder(
dim=model_channels,
depth=2,
heads=heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
dim=model_channels,
depth=2,
heads=heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
)
self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels))
self.latent_fade = nn.Parameter(torch.zeros(1, 1, model_channels))
self.code_converter = Encoder(
dim=model_channels,
depth=3,
heads=heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
dim=model_channels,
depth=3,
heads=heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, 1, model_channels))
self.mel_head = nn.Conv1d(
model_channels, in_channels, kernel_size=3, padding=1)
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.intg = ml.Linear(model_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
self.layers = TimestepRotaryEmbedSequential(
*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
zero_module(conv_nd(1, model_channels,
out_channels, 3, padding=1)),
)
self.debug_codes = {}
@ -165,7 +182,8 @@ class TransformerDiffusionTTS(nn.Module):
return groups
def timestep_independent(self, codes, conditioning_input, expected_seq_len, prenet_latent=None, return_code_pred=False):
cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
cond_emb = self.conditioning_embedder(
conditioning_input).permute(0, 2, 1)
cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
code_emb = self.embeddings(codes)
@ -173,7 +191,8 @@ class TransformerDiffusionTTS(nn.Module):
latent_conditioning = self.latent_conditioner(prenet_latent)
code_emb = code_emb + latent_conditioning * self.latent_fade
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
unconditioned_batches = torch.zeros(
(code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
@ -182,57 +201,65 @@ class TransformerDiffusionTTS(nn.Module):
code_emb)
code_emb = self.code_converter(code_emb)
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
expanded_code_emb = F.interpolate(code_emb.permute(
0, 2, 1), size=expected_seq_len, mode='nearest').permute(0, 2, 1)
if not return_code_pred:
return expanded_code_emb, cond_emb
else:
# Perform the mel_head computation on the pre-expanded code embeddings, then interpolate it separately.
mel_pred = self.mel_head(code_emb.permute(0,2,1))
mel_pred = F.interpolate(mel_pred, size=expected_seq_len, mode='nearest')
mel_pred = self.mel_head(code_emb.permute(0, 2, 1))
mel_pred = F.interpolate(
mel_pred, size=expected_seq_len, mode='nearest')
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches.
# This is because we don't want that gradient being used to train parameters through the codes_embedder as
# it unbalances contributions to that network from the MSE loss.
mel_pred = mel_pred * unconditioned_batches.logical_not()
return expanded_code_emb, cond_emb, mel_pred
def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None,
precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
if precomputed_code_embeddings is not None:
assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
assert not (
return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
assert type is not None, "Type is required."
unused_params = []
if not return_code_pred:
unused_params.extend(list(self.mel_head.parameters()))
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
code_emb = self.unconditioned_embedding.repeat(
x.shape[0], 1, x.shape[-1])
unused_params.extend(
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
unused_params.extend(list(self.latent_conditioner.parameters()))
else:
if precomputed_code_embeddings is not None:
code_emb = precomputed_code_embeddings
cond_emb = precomputed_cond_embeddings
else:
code_emb, cond_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], prenet_latent, True)
code_emb, cond_emb, mel_pred = self.timestep_independent(
codes, conditioning_input, x.shape[-1], prenet_latent, True)
if prenet_latent is None:
unused_params.extend(list(self.latent_conditioner.parameters()) + [self.latent_fade])
unused_params.extend(
list(self.latent_conditioner.parameters()) + [self.latent_fade])
unused_params.append(self.unconditioned_embedding)
clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
clvp_emb = torch.zeros_like(
cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
type_emb = self.type_embedding(type)
if clvp_input is None:
unused_params.extend(self.clvp_encoder.parameters())
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb
x = self.inp_block(x).permute(0,2,1)
blk_emb = self.time_embed(timestep_embedding(
timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb
x = self.inp_block(x).permute(0, 2, 1)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
x = self.intg(torch.cat([x, code_emb], dim=-1))
x = self.layers(x, blk_emb, rotary_pos_emb)
x = x.float().permute(0,2,1)
x = x.float().permute(0, 2, 1)
out = self.out(x)
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
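# Sketch of the zero-weighted "touch every unused parameter" pattern referenced in the comment
# above (the module and tensors below are stand-ins, not names from this file): summing p.mean()
# over each skipped parameter and adding it times zero keeps DistributedDataParallel from
# complaining about parameters that received no gradient, without changing the loss.
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
unused_params = list(layer.parameters())                # parameters skipped on this code path
out = torch.randn(2, 8)
extraneous = sum(p.mean() for p in unused_params)
out = out + extraneous * 0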
@ -253,13 +280,14 @@ def register_transformer_diffusion_tts(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 256, 400)
aligned_latent = torch.randn(2,100,512)
aligned_sequence = torch.randint(0,8,(2,100,8))
aligned_latent = torch.randn(2, 100, 512)
aligned_sequence = torch.randint(0, 8, (2, 100, 8))
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
clvp = torch.randn(2,768)
type = torch.LongTensor([0,1])
model = TransformerDiffusionTTS(512, unconditioned_percentage=.5, in_groups=8)
o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type, return_code_pred=True)
#o = model(clip, ts, aligned_sequence, cond, aligned_latent)
clvp = torch.randn(2, 768)
type = torch.LongTensor([0, 1])
model = TransformerDiffusionTTS(
512, unconditioned_percentage=.5, in_groups=8)
o = model(clip, ts, aligned_sequence, cond,
clvp_input=clvp, type=type, return_code_pred=True)
# o = model(clip, ts, aligned_sequence, cond, aligned_latent)

@ -1,18 +1,24 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock
from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding
from trainer.networks import register_model
from utils.util import checkpoint, print_network
import dlas.torch_intermediary as ml
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (TimestepBlock,
TimestepEmbedSequential)
from dlas.models.lucidrains.x_transformers import (Attention, Encoder,
FeedForward,
RMSScaleShiftNorm,
RotaryEmbedding)
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint, print_network
def is_latent(t):
return t.dtype == torch.float
def is_sequence(t):
return t.dtype == torch.long
@ -21,7 +27,8 @@ class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
# nn.Embedding
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)])
self.m = nn.ModuleList(
[ml.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
@ -44,12 +51,14 @@ class DietAttentionBlock(TimestepBlock):
self.rms_scale_norm = RMSScaleShiftNorm(in_dim)
self.proj = ml.Linear(in_dim, dim)
self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout)
self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True)
self.ff = FeedForward(dim, in_dim, mult=1,
dropout=dropout, zero_init_output=True)
def forward(self, x, timestep_emb, rotary_emb):
h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb)
h = self.proj(h)
h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb)
h, _, _, _ = checkpoint(self.attn, h, None, None,
None, None, None, rotary_emb)
h = checkpoint(self.ff, h)
return h + x
@ -58,6 +67,7 @@ class TransformerDiffusionTTS(nn.Module):
"""
A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
"""
def __init__(
self,
prenet_channels=256,
@ -75,7 +85,8 @@ class TransformerDiffusionTTS(nn.Module):
dropout=0,
use_fp16=False,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
):
super().__init__()
@ -96,17 +107,17 @@ class TransformerDiffusionTTS(nn.Module):
)
prenet_heads = prenet_channels//64
self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, prenet_channels // 2, 3, padding=1, stride=2),
nn.Conv1d(prenet_channels//2, prenet_channels,3,padding=1,stride=2))
nn.Conv1d(prenet_channels//2, prenet_channels, 3, padding=1, stride=2))
self.conditioning_encoder = Encoder(
dim=prenet_channels,
depth=4,
heads=prenet_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
dim=prenet_channels,
depth=4,
heads=prenet_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels)
# nn.Embedding
self.type_embedding = ml.Embedding(types, prenet_channels)
@ -119,45 +130,48 @@ class TransformerDiffusionTTS(nn.Module):
# nn.Embedding
self.embeddings = ml.Embedding(token_count, prenet_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels)
self.embeddings = MultiGroupEmbedding(
token_count, in_groups, prenet_channels)
self.latent_conditioner = nn.Sequential(
nn.Conv1d(in_latent_channels, prenet_channels, 3, padding=1),
Encoder(
dim=prenet_channels,
depth=2,
heads=prenet_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
dim=prenet_channels,
depth=2,
heads=prenet_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
)
self.latent_fade = nn.Parameter(torch.zeros(1,1,prenet_channels))
self.latent_fade = nn.Parameter(torch.zeros(1, 1, prenet_channels))
self.code_converter = Encoder(
dim=prenet_channels,
depth=3,
heads=prenet_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels))
dim=prenet_channels,
depth=3,
heads=prenet_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
)
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, 1, prenet_channels))
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.cond_intg = ml.Linear(prenet_channels*4, model_channels)
self.intg = ml.Linear(prenet_channels*2, model_channels)
self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])
self.layers = TimestepRotaryEmbedSequential(
*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
zero_module(conv_nd(1, model_channels,
out_channels, 3, padding=1)),
)
self.debug_codes = {}
@ -172,7 +186,8 @@ class TransformerDiffusionTTS(nn.Module):
return groups
def timestep_independent(self, codes, conditioning_input, expected_seq_len, prenet_latent=None):
cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1)
cond_emb = self.conditioning_embedder(
conditioning_input).permute(0, 2, 1)
cond_emb = self.conditioning_encoder(cond_emb)[:, 0]
code_emb = self.embeddings(codes)
@ -188,11 +203,11 @@ class TransformerDiffusionTTS(nn.Module):
code_emb)
code_emb = self.code_converter(code_emb)
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
expanded_code_emb = F.interpolate(code_emb.permute(
0, 2, 1), size=expected_seq_len, mode='nearest').permute(0, 2, 1)
return expanded_code_emb, cond_emb
def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None,
precomputed_cond_embeddings=None, conditioning_free=False):
if precomputed_code_embeddings is not None:
@ -202,32 +217,38 @@ class TransformerDiffusionTTS(nn.Module):
unused_params = []
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
code_emb = self.unconditioned_embedding.repeat(
x.shape[0], 1, x.shape[-1])
unused_params.extend(
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
unused_params.extend(list(self.latent_conditioner.parameters()))
else:
if precomputed_code_embeddings is not None:
code_emb = precomputed_code_embeddings
cond_emb = precomputed_cond_embeddings
else:
code_emb, cond_emb = self.timestep_independent(codes, conditioning_input, x.shape[-1], prenet_latent)
code_emb, cond_emb = self.timestep_independent(
codes, conditioning_input, x.shape[-1], prenet_latent)
if prenet_latent is None:
unused_params.extend(list(self.latent_conditioner.parameters()) + [self.latent_fade])
unused_params.extend(
list(self.latent_conditioner.parameters()) + [self.latent_fade])
unused_params.append(self.unconditioned_embedding)
clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
clvp_emb = torch.zeros_like(
cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
type_emb = self.type_embedding(type)
if clvp_input is None:
unused_params.extend(self.clvp_encoder.parameters())
blk_emb = torch.cat([self.time_embed(timestep_embedding(timesteps, self.prenet_channels)), cond_emb, clvp_emb, type_emb], dim=-1)
blk_emb = torch.cat([self.time_embed(timestep_embedding(
timesteps, self.prenet_channels)), cond_emb, clvp_emb, type_emb], dim=-1)
blk_emb = self.cond_intg(blk_emb)
x = self.inp_block(x).permute(0,2,1)
x = self.inp_block(x).permute(0, 2, 1)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
x = self.intg(torch.cat([x, code_emb], dim=-1))
x = self.layers(x, blk_emb, rotary_pos_emb)
x = x.float().permute(0,2,1)
x = x.float().permute(0, 2, 1)
out = self.out(x)
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
@ -246,15 +267,15 @@ def register_transformer_diffusion_tts2(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 256, 400)
aligned_latent = torch.randn(2,100,512)
aligned_sequence = torch.randint(0,8,(2,100,8))
aligned_latent = torch.randn(2, 100, 512)
aligned_sequence = torch.randint(0, 8, (2, 100, 8))
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
clvp = torch.randn(2,768)
type = torch.LongTensor([0,1])
model = TransformerDiffusionTTS(model_channels=3072, num_layers=16, unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024)
clvp = torch.randn(2, 768)
type = torch.LongTensor([0, 1])
model = TransformerDiffusionTTS(model_channels=3072, num_layers=16,
unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024)
print_network(model)
o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type)
torch.save(model.state_dict(), 'test.pth')
#o = model(clip, ts, aligned_sequence, cond, aligned_latent)
# o = model(clip, ts, aligned_sequence, cond, aligned_latent)

@ -5,16 +5,19 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import torch_intermediary as ml
from x_transformers import ContinuousTransformerWrapper, Encoder
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
Downsample, Upsample, TimestepBlock
from models.audio.tts.mini_encoder import AudioMiniEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint
from x_transformers import Encoder, ContinuousTransformerWrapper
import dlas.torch_intermediary as ml
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
TimestepBlock,
TimestepEmbedSequential,
Upsample)
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3, inverted=False):
@ -567,4 +570,3 @@ if __name__ == '__main__':
o.sum().backward()
model.before_step(0)
torch.save(model.state_dict(), 'test_out.pth')

@ -5,21 +5,26 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers import Encoder
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
Downsample, Upsample, TimestepBlock
from models.audio.tts.mini_encoder import AudioMiniEncoder
from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint
import dlas.torch_intermediary as ml
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
from dlas.models.audio.tts.unet_diffusion_tts7 import \
CheckpointedXTransformerEncoder
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
TimestepBlock,
TimestepEmbedSequential,
Upsample)
from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
def is_latent(t):
return t.dtype == torch.float
def is_sequence(t):
return t.dtype == torch.long
@ -49,7 +54,8 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
conv_nd(dims, channels, self.out_channels,
eff_kernel, padding=eff_padding),
)
self.emb_layers = nn.Sequential(
@ -64,14 +70,16 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
conv_nd(dims, self.out_channels, self.out_channels,
kernel_size, padding=padding)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
def forward(self, x, emb):
"""
@ -100,6 +108,7 @@ class ResBlock(TimestepBlock):
h = self.out_layers(h)
return self.skip_connection(x) + h
class DiffusionTts(nn.Module):
"""
The full UNet model with attention and timestep embedding.
@ -143,12 +152,12 @@ class DiffusionTts(nn.Module):
out_channels=2, # mean and variance
dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
token_conditioning_resolutions=(1,16,),
attention_resolutions=(512,1024,2048),
token_conditioning_resolutions=(1, 16,),
attention_resolutions=(512, 1024, 2048),
conv_resample=True,
dims=1,
use_fp16=False,
@ -159,10 +168,12 @@ class DiffusionTts(nn.Module):
scale_factor=2,
time_embed_dim_multiplier=4,
freeze_main_net=False,
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
# Uses kernels with width of 1 in several places rather than 3.
efficient_convs=True,
use_scale_shift_norm=True,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
# Parameters for super-sampling.
super_sampling=False,
super_sampling_max_noising_factor=.1,
@ -173,7 +184,8 @@ class DiffusionTts(nn.Module):
num_heads_upsample = num_heads
if super_sampling:
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
# In super-sampling mode, the LR input is concatenated directly onto the input.
in_channels *= 2
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
@ -224,10 +236,12 @@ class DiffusionTts(nn.Module):
rotary_pos_emb=True,
)
))
self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
self.latent_converter = nn.Conv1d(
in_latent_channels, conditioning_dim, 1)
self.aligned_latent_padding_embedding = nn.Parameter(
torch.randn(1, in_latent_channels, 1))
if in_channels > 60: # It's a spectrogram.
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, conditioning_dim, 3, padding=1, stride=2),
CheckpointedXTransformerEncoder(
needs_permute=True,
max_seq_len=-1,
@ -242,25 +256,33 @@ class DiffusionTts(nn.Module):
ff_glu=True,
rotary_pos_emb=True,
)
))
))
else:
self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
self.conditioning_conv = nn.Conv1d(
conditioning_dim*2, conditioning_dim, 1)
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, conditioning_dim, 1))
self.conditioning_timestep_integrator = TimestepEmbedSequential(
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads,
num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads,
num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim,
dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
)
self.conditioning_expansion = conditioning_expansion
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
conv_nd(dims, in_channels, model_channels,
kernel_size, padding=padding)
)
]
)
@ -371,7 +393,8 @@ class DiffusionTts(nn.Module):
if level and i == num_blocks:
out_ch = ch
layers.append(
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
Upsample(ch, conv_resample, dims=dims,
out_channels=out_ch, factor=scale_factor)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -380,7 +403,8 @@ class DiffusionTts(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
zero_module(conv_nd(dims, model_channels, out_channels,
kernel_size, padding=padding)),
)
if self.freeze_main_net:
@ -410,13 +434,14 @@ class DiffusionTts(nn.Module):
cm = ceil_multiple(x.shape[-1], self.alignment_size)
if cm != 0:
pc = (cm-x.shape[-1])/x.shape[-1]
x = F.pad(x, (0,cm-x.shape[-1]))
x = F.pad(x, (0, cm-x.shape[-1]))
# Also fix aligned_latent, which is aligned to x.
if is_latent(aligned_conditioning):
aligned_conditioning = torch.cat([aligned_conditioning,
self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
else:
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
aligned_conditioning = F.pad(
aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
return x, aligned_conditioning
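# Rough sketch of the alignment padding above; ceil_to_multiple is a local stand-in for the
# imported ceil_multiple helper (its exact behaviour in dlas.scripts.audio.gen.use_diffuse_tts
# is assumed here, not verified).
import torch
import torch.nn.functional as F

def ceil_to_multiple(n, mult):
    return ((n + mult - 1) // mult) * mult

x = torch.randn(1, 100, 990)                            # length not divisible by the alignment size
target = ceil_to_multiple(x.shape[-1], 256)             # 990 -> 1024
pc = (target - x.shape[-1]) / x.shape[-1]               # fractional growth, reused for conditioning
x = F.pad(x, (0, target - x.shape[-1]))
cond = torch.randn(1, 80, 250)
cond = F.pad(cond, (0, int(pc * cond.shape[-1])))       # pad conditioning by the same fraction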
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
@ -435,9 +460,12 @@ class DiffusionTts(nn.Module):
if self.super_sampling_enabled:
assert lr_input is not None
if self.training and self.super_sampling_max_noising_factor > 0:
noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
noising_factor = random.uniform(
0, self.super_sampling_max_noising_factor)
lr_input = torch.randn_like(
lr_input) * noising_factor + lr_input
lr_input = F.interpolate(
lr_input, size=(x.shape[-1],), mode='nearest')
x = torch.cat([x, lr_input], dim=1)
# Shuffle aligned_latent to BxCxS format
@ -451,11 +479,13 @@ class DiffusionTts(nn.Module):
with autocast(x.device.type, enabled=self.enable_fp16):
hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
time_emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels))
# Note: this block does not need to be repeated on inference, since it is not timestep-dependent.
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
code_emb = self.unconditioned_embedding.repeat(
x.shape[0], 1, 1)
else:
cond_emb = self.contextual_embedder(conditioning_input)
if len(cond_emb.shape) == 3: # Just take the first element.
@ -464,8 +494,10 @@ class DiffusionTts(nn.Module):
code_emb = self.latent_converter(aligned_conditioning)
else:
code_emb = self.code_converter(aligned_conditioning)
cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
cond_emb = cond_emb.unsqueeze(-1).repeat(1,
1, code_emb.shape[-1])
code_emb = self.conditioning_conv(
torch.cat([cond_emb, code_emb], dim=1))
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
@ -474,15 +506,18 @@ class DiffusionTts(nn.Module):
code_emb)
# Everything after this comment is timestep dependent.
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
code_emb = torch.repeat_interleave(
code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(
code_emb, time_emb)
first = True
time_emb = time_emb.float()
h = x
for k, module in enumerate(self.input_blocks):
if isinstance(module, nn.Conv1d):
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
h_tok = F.interpolate(module(code_emb), size=(
h.shape[-1]), mode='nearest')
h = h + h_tok
else:
with autocast(x.device.type, enabled=self.enable_fp16 and not first):
@ -501,7 +536,8 @@ class DiffusionTts(nn.Module):
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
extraneous_addition = 0
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
params = [self.aligned_latent_padding_embedding,
self.unconditioned_embedding] + list(self.latent_converter.parameters())
for p in params:
extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0
@ -516,14 +552,14 @@ def register_diffusion_tts9(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 1, 32868)
aligned_latent = torch.randn(2,388,1024)
aligned_sequence = torch.randint(0,8192,(2,388))
aligned_latent = torch.randn(2, 388, 1024)
aligned_sequence = torch.randint(0, 8192, (2, 388))
cond = torch.randn(2, 1, 44000)
ts = torch.LongTensor([600, 600])
model = DiffusionTts(128,
channel_mult=[1,1.5,2, 3, 4, 6, 8],
channel_mult=[1, 1.5, 2, 3, 4, 6, 8],
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
token_conditioning_resolutions=[1,4,16,64],
token_conditioning_resolutions=[1, 4, 16, 64],
attention_resolutions=[],
num_heads=8,
kernel_size=3,
@ -535,4 +571,3 @@ if __name__ == '__main__':
o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence, cond)

@ -5,18 +5,22 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
import torch_intermediary as ml
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
from models.lucidrains.x_transformers import RelativePositionBias
from trainer.networks import register_model
from utils.util import checkpoint
import dlas.torch_intermediary as ml
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (QKVAttentionLegacy,
TimestepBlock,
TimestepEmbedSequential)
from dlas.models.lucidrains.x_transformers import RelativePositionBias
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
def is_latent(t):
return t.dtype == torch.float
def is_sequence(t):
return t.dtype == torch.long
@ -54,7 +58,8 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
if relative_pos_embeddings:
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
self.relative_pos_embeddings = RelativePositionBias(scale=(
channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
else:
self.relative_pos_embeddings = None
@ -92,7 +97,8 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding),
conv_nd(dims, channels, self.out_channels,
eff_kernel, padding=eff_padding),
)
self.emb_layers = nn.Sequential(
@ -107,14 +113,16 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
conv_nd(dims, self.out_channels, self.out_channels,
kernel_size, padding=padding)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, eff_kernel, padding=eff_padding)
def forward(self, x, emb):
"""
@ -147,8 +155,10 @@ class ResBlock(TimestepBlock):
class DiffusionLayer(TimestepBlock):
def __init__(self, model_channels, dropout, num_heads):
super().__init__()
self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
self.resblk = ResBlock(model_channels, model_channels, dropout,
model_channels, dims=1, use_scale_shift_norm=True)
self.attn = AttentionBlock(
model_channels, num_heads, relative_pos_embeddings=True)
def forward(self, x, time_emb):
y = self.resblk(x, time_emb)
@ -170,7 +180,8 @@ class DiffusionTtsFlat(nn.Module):
freeze_everything_except_autoregressive_inputs=False,
# Parameters for regularization.
layer_drop=.1,
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# This implements a mechanism similar to what is used in classifier-free training.
unconditioned_percentage=.1,
):
super().__init__()
@ -198,33 +209,48 @@ class DiffusionTtsFlat(nn.Module):
# nn.Embedding
self.code_embedding = ml.Embedding(in_tokens, model_channels)
self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
)
self.code_norm = normalization(model_channels)
self.latent_conditioner = nn.Sequential(
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads,
relative_pos_embeddings=True),
)
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
nn.Conv1d(
model_channels, model_channels*2, 3, padding=1, stride=2),
AttentionBlock(
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
self.unconditioned_embedding = nn.Parameter(
torch.randn(1, model_channels, 1))
self.conditioning_timestep_integrator = TimestepEmbedSequential(
DiffusionLayer(model_channels, dropout, num_heads),
DiffusionLayer(model_channels, dropout, num_heads),
DiffusionLayer(model_channels, dropout, num_heads),
)
self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.integrating_conv = nn.Conv1d(
model_channels*2, model_channels, kernel_size=1)
self.mel_head = nn.Conv1d(
model_channels, in_channels, kernel_size=3, padding=1)
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
[ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
@ -232,7 +258,8 @@ class DiffusionTtsFlat(nn.Module):
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
zero_module(conv_nd(1, model_channels,
out_channels, 3, padding=1)),
)
if freeze_everything_except_autoregressive_inputs:
@ -263,25 +290,30 @@ class DiffusionTtsFlat(nn.Module):
conditioning_input.shape) == 3 else conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
conds.append(self.contextual_embedder(
speech_conditioning_input[:, j]))
conds = torch.cat(conds, dim=-1)
cond_emb = conds.mean(dim=-1)
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
if is_latent(aligned_conditioning):
code_emb = self.latent_conditioner(aligned_conditioning)
else:
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
code_emb = self.code_embedding(
aligned_conditioning).permute(0, 2, 1)
code_emb = self.code_converter(code_emb)
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
code_emb = self.code_norm(
code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
unconditioned_batches = torch.zeros(
(code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
code_emb)
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
expanded_code_emb = F.interpolate(
code_emb, size=expected_seq_len, mode='nearest')
if not return_code_pred:
return expanded_code_emb
@ -291,7 +323,6 @@ class DiffusionTtsFlat(nn.Module):
mel_pred = mel_pred * unconditioned_batches.logical_not()
return expanded_code_emb, mel_pred
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
"""
Apply the model to an input batch.
@ -304,27 +335,36 @@ class DiffusionTtsFlat(nn.Module):
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs.
"""
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
assert precomputed_aligned_embeddings is not None or (
aligned_conditioning is not None and conditioning_input is not None)
# These two are mutually exclusive.
assert not (
return_code_pred and precomputed_aligned_embeddings is not None)
unused_params = list(self.mel_head.parameters())
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
code_emb = self.unconditioned_embedding.repeat(
x.shape[0], 1, x.shape[-1])
unused_params.extend(
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
unused_params.extend(list(self.latent_conditioner.parameters()))
else:
if precomputed_aligned_embeddings is not None:
code_emb = precomputed_aligned_embeddings
else:
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
code_emb, mel_pred = self.timestep_independent(
aligned_conditioning, conditioning_input, x.shape[-1], True)
if is_latent(aligned_conditioning):
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
unused_params.extend(
list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
else:
unused_params.extend(list(self.latent_conditioner.parameters()))
unused_params.extend(
list(self.latent_conditioner.parameters()))
unused_params.append(self.unconditioned_embedding)
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
time_emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels))
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
x = self.inp_block(x)
x = torch.cat([x, code_emb], dim=1)
@ -356,10 +396,12 @@ class DiffusionTtsFlat(nn.Module):
conditioning_input.shape) == 3 else conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
conds.append(self.contextual_embedder(
speech_conditioning_input[:, j]))
conds = torch.cat(conds, dim=-1)
return conds.mean(dim=-1)
@register_model
def register_diffusion_tts_flat(opt_net, opt):
return DiffusionTtsFlat(**opt_net['kwargs'])
@ -367,17 +409,17 @@ def register_diffusion_tts_flat(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(2, 100, 400)
aligned_latent = torch.randn(2,388,512)
aligned_sequence = torch.randint(0,8192,(2,100))
aligned_latent = torch.randn(2, 388, 512)
aligned_sequence = torch.randint(0, 8192, (2, 100))
cond = torch.randn(2, 100, 400)
ts = torch.LongTensor([600, 600])
model = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=True, num_heads=16,
layer_drop=0, unconditioned_percentage=0)
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=True, num_heads=16,
layer_drop=0, unconditioned_percentage=0)
# Test with latent aligned conditioning
#o = model(clip, ts, aligned_latent, cond)
# o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning
#o = model(clip, ts, aligned_sequence, cond)
# o = model(clip, ts, aligned_sequence, cond)
with torch.no_grad():
proj = torch.randn(2, 100, 1024).cuda()
@ -389,4 +431,3 @@ if __name__ == '__main__':
for k in range(1000):
model(clip, ts, precomputed_aligned_embeddings=ti)
print(f"Elapsed: {time()-start}")

@ -1,11 +1,16 @@
from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
Downsample, Upsample
import torch
import torch.nn as nn
from trainer.networks import register_model
from dlas.models.diffusion.fp16_util import (convert_module_to_f16,
convert_module_to_f32)
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (AttentionBlock,
AttentionPool2d, Downsample,
ResBlock,
TimestepEmbedSequential,
Upsample)
from dlas.trainer.networks import register_model
class DiffusionVocoder(nn.Module):
@ -46,11 +51,12 @@ class DiffusionVocoder(nn.Module):
in_channels=1,
out_channels=2, # mean and variance
spectrogram_channels=80,
spectrogram_conditioning_level=3, # Level at which spectrogram conditioning is applied to the waveform.
# Level at which spectrogram conditioning is applied to the waveform.
spectrogram_conditioning_level=3,
dropout=0,
# 106496 -> 26624 -> 6656 -> 1664 -> 416 -> 104 -> 26 for ~5secs@22050Hz
channel_mult=(1, 2, 4, 8, 16, 32, 64),
attention_resolutions=(16,32,64),
attention_resolutions=(16, 32, 64),
conv_resample=True,
dims=1,
num_classes=None,
@ -95,7 +101,8 @@ class DiffusionVocoder(nn.Module):
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
conv_nd(dims, in_channels, model_channels,
kernel_size, padding=padding)
)
]
)
@ -104,7 +111,8 @@ class DiffusionVocoder(nn.Module):
ch = model_channels
ds = 1
spec_chs = channel_mult[spectrogram_conditioning_level] * model_channels
spec_chs = channel_mult[spectrogram_conditioning_level] * \
model_channels
self.spectrogram_conditioner = nn.Sequential(
conv_nd(dims, self.spectrogram_channels, spec_chs, 1),
normalization(spec_chs),
@ -119,7 +127,8 @@ class DiffusionVocoder(nn.Module):
for level, mult in enumerate(channel_mult):
if level == spectrogram_conditioning_level+1:
ch *= 2 # At this level, the spectrogram is concatenated onto the input.
# At this level, the spectrogram is concatenated onto the input.
ch *= 2
for _ in range(num_res_blocks):
layers = [
@ -248,7 +257,8 @@ class DiffusionVocoder(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
zero_module(conv_nd(dims, model_channels, out_channels,
kernel_size, padding=padding)),
)
def convert_to_fp16(self):
@ -278,14 +288,16 @@ class DiffusionVocoder(nn.Module):
"""
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(timestep_embedding(
timesteps, self.model_channels))
conditioning = self.spectrogram_conditioner(spectrogram)
h = x.type(self.dtype)
for k, module in enumerate(self.input_blocks):
h = module(h, emb)
if k == self.input_block_injection_point:
cond = nn.functional.interpolate(conditioning, size=h.shape[-self.dims:], mode='nearest')
cond = nn.functional.interpolate(
conditioning, size=h.shape[-self.dims:], mode='nearest')
h = torch.cat([h, cond], dim=1)
h = self.convergence_conv(h)
hs.append(h)
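# Usage sketch for the length constraint asserted in forward() above: callers pad the waveform
# to a multiple of 4096 samples first (the padding shown here is illustrative, not code from
# this file).
import torch
import torch.nn.functional as F

wav = torch.randn(1, 1, 110250)                         # ~5s at 22050Hz
pad = (-wav.shape[-1]) % 4096
wav = F.pad(wav, (0, pad))                              # 110250 -> 110592 = 27 * 4096
assert wav.shape[-1] % 4096 == 0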

@ -1,11 +1,14 @@
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, ResBlock, TimestepEmbedSequential, \
Downsample, Upsample
import torch
import torch.nn as nn
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
from dlas.models.diffusion.nn import (conv_nd, linear, normalization,
timestep_embedding, zero_module)
from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample,
ResBlock,
TimestepEmbedSequential,
Upsample)
from dlas.trainer.networks import register_model
class DiscreteSpectrogramConditioningBlock(nn.Module):
@ -23,6 +26,7 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
:param x: bxcxS waveform latent
:param codes: bxN discrete codes, N <= S
"""
def forward(self, x, dvae_in):
b, c, S = x.shape
_, q, N = dvae_in.shape
@ -70,12 +74,12 @@ class DiffusionVocoderWithRef(nn.Module):
discrete_codes=512,
dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
spectrogram_conditioning_resolutions=(512,),
attention_resolutions=(512,1024,2048),
attention_resolutions=(512, 1024, 2048),
conv_resample=True,
dims=1,
use_fp16=False,
@ -122,10 +126,11 @@ class DiffusionVocoderWithRef(nn.Module):
self.conditioning_enabled = conditioning_inputs_provided
if conditioning_inputs_provided:
self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
seqlyr = TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
conv_nd(dims, in_channels, model_channels,
kernel_size, padding=padding)
)
seqlyr.level = 0
self.input_blocks = nn.ModuleList([seqlyr])
@ -137,7 +142,8 @@ class DiffusionVocoderWithRef(nn.Module):
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
if ds in spectrogram_conditioning_resolutions:
spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch, 2 ** level)
spec_cond_block = DiscreteSpectrogramConditioningBlock(
discrete_codes, ch, 2 ** level)
self.input_blocks.append(spec_cond_block)
spectrogram_blocks.append(spec_cond_block)
ch *= 2
@ -172,21 +178,21 @@ class DiffusionVocoderWithRef(nn.Module):
if level != len(channel_mult) - 1:
out_ch = ch
upblk = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
kernel_size=kernel_size,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
)
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
kernel_size=kernel_size,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
)
)
upblk.level = 2 ** level
self.input_blocks.append(upblk)
ch = out_ch
@ -270,7 +276,8 @@ class DiffusionVocoderWithRef(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
zero_module(conv_nd(dims, model_channels, out_channels,
kernel_size, padding=padding)),
)
if freeze_layers_below is not None:
@ -297,7 +304,8 @@ class DiffusionVocoderWithRef(nn.Module):
del p.DO_NOT_TRAIN
p.requires_grad = True
unfrozen_params += 1
print(f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.")
print(
f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.")
def forward(self, x, timesteps, spectrogram, conditioning_input=None):
"""
@ -313,7 +321,8 @@ class DiffusionVocoderWithRef(nn.Module):
assert conditioning_input is not None
hs = []
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb1 = self.time_embed(timestep_embedding(
timesteps, self.model_channels))
if self.conditioning_enabled:
emb2 = self.contextual_embedder(conditioning_input)
emb = emb1 + emb2
@ -363,19 +372,20 @@ if __name__ == '__main__':
move_all_layers_down(path, 'diffuse_new_lyr.pth', layers_to_be_added=2)
clip = torch.randn(2, 1, 40960)
spec = torch.randn(2,80,160)
spec = torch.randn(2, 80, 160)
cond = torch.randn(2, 1, 40960)
ts = torch.LongTensor([555, 556])
model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1,1,1.5,2, 3, 4, 6, 8, 8, 8, 8 ],
num_res_blocks=[1,2, 2, 2, 2, 2, 2, 2, 2, 1, 1 ], spectrogram_conditioning_resolutions=[2,512],
dropout=.05, attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8],
num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1], spectrogram_conditioning_resolutions=[2, 512],
dropout=.05, attention_resolutions=[512, 1024], num_heads=4, kernel_size=3, scale_factor=2,
conditioning_inputs_provided=True, conditioning_input_dim=80, time_embed_dim_multiplier=4,
discrete_codes=80, freeze_layers_below=1)
loading_errors = model.load_state_dict(torch.load('diffuse_new_lyr.pth'), strict=False)
loading_errors = model.load_state_dict(
torch.load('diffuse_new_lyr.pth'), strict=False)
new_params = loading_errors.missing_keys
new_params_trained = []
existing_params_trained = []
for n,p in model.named_parameters():
for n, p in model.named_parameters():
if not hasattr(p, 'DO_NOT_TRAIN'):
if n in new_params:
new_params_trained.append(n)

@ -1,24 +1,26 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from transformers.utils.model_parallel_utils import (assert_device_map,
get_device_map)
from models.arch_util import AttentionBlock
from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
from trainer.networks import register_model
from utils.util import opt_get
import dlas.torch_intermediary as ml
from dlas.models.arch_util import AttentionBlock
from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer
from dlas.models.lucidrains.x_transformers import (RotaryEmbedding,
apply_rotary_pos_emb)
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
import torch_intermediary as ml
class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
@ -48,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
get_device_map(len(self.transformer.h),
range(torch.cuda.device_count()))
if device_map is None
else device_map
)
@ -121,7 +124,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
):
assert self.cached_mel_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
# Training not supported by this inference model.
assert labels is None
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Create embedding
@ -131,13 +135,15 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
text_emb = self.embeddings(text_inputs)
text_emb = text_emb + self.text_pos_embedding(text_emb)
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
mel_emb = self.cached_mel_emb.repeat_interleave(
text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
else:
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
emb = emb + self.text_pos_embedding.get_fixed_embedding(
attention_mask.shape[1]-mel_len, attention_mask.device)
transformer_outputs = self.transformer(
inputs_embeds=emb,
@ -182,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past)
for layer_past in past
)
@ -199,7 +206,8 @@ class ConditioningEncoder(nn.Module):
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(AttentionBlock(embedding_dim,
num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
@ -219,23 +227,27 @@ class MelEncoder(nn.Module):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
nn.Sequential(
*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2,
kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//16, channels//2),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
nn.Sequential(
*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels,
kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
nn.Sequential(
*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
self.reduction = 4
def forward(self, x):
for e in self.encoder:
x = e(x)
return x.permute(0,2,1)
return x.permute(0, 2, 1)
class UnifiedVoice(nn.Module):
@ -276,25 +288,32 @@ class UnifiedVoice(nn.Module):
self.layers = layers
self.heads = heads
self.max_conditioning_inputs = max_conditioning_inputs
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
self.max_mel_tokens = -1 if max_mel_tokens == - \
1 else max_mel_tokens+2+self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
self.model_dim = model_dim
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.conditioning_encoder = ConditioningEncoder(
80, model_dim, num_attn_heads=heads)
self.average_conditioning_embeddings = average_conditioning_embeddings
self.tortoise_compat = tortoise_compat # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b
# credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b
self.tortoise_compat = tortoise_compat
# nn.Embedding
self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
if use_mel_codes_as_input:
# nn.Embedding
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
self.mel_embedding = MelEncoder(
model_dim, resblocks_per_reduction=1)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
build_hf_gpt_transformer(
layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
self.mel_solo_embedding = nn.Parameter(
torch.randn(1, 1, model_dim) * .02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(
torch.randn(1, 1, model_dim) * .02, requires_grad=True)
else:
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
@ -303,7 +322,6 @@ class UnifiedVoice(nn.Module):
self.text_head = ml.Linear(model_dim, self.number_text_tokens)
self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding]
if use_mel_codes_as_input:
@ -328,8 +346,8 @@ class UnifiedVoice(nn.Module):
}
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
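For reference, a minimal standalone sketch (not part of the commit) of what the two F.pad calls above produce; the token values are made up.

# Editor's sketch, not repository code; token values are illustrative.
import torch
import torch.nn.functional as F

codes = torch.tensor([[5, 6, 7]])
start_token, stop_token = 1, 0
inp = F.pad(codes, (1, 0), value=start_token)  # tensor([[1, 5, 6, 7]]) -> fed to the transformer
tar = F.pad(codes, (0, 1), value=stop_token)   # tensor([[5, 6, 7, 0]]) -> shifted target for cross entropy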
def set_mel_padding(self, mel_input_tokens, wav_lengths):
@ -341,22 +359,26 @@ class UnifiedVoice(nn.Module):
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
mel_lengths = wav_lengths // self.mel_length_compression
for b in range(len(mel_lengths)):
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
# Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
actual_end = mel_lengths[b] + 1
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
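A minimal standalone sketch (not part of the commit) of the padding rule above: wav lengths are mapped to MEL-token lengths via the compression factor, and everything past one token beyond the end is overwritten with the stop code. The compression factor and stop code below are illustrative.

# Editor's sketch, not repository code; the constants are illustrative.
import torch

mel_length_compression = 1024
stop_mel_token = 8193
mel_codes = torch.randint(0, 8192, (2, 10))
wav_lengths = torch.tensor([6 * 1024, 10 * 1024])

mel_lengths = wav_lengths // mel_length_compression      # tensor([ 6, 10])
for b in range(len(mel_lengths)):
    actual_end = mel_lengths[b] + 1                      # keep one extra predicted token
    if actual_end < mel_codes.shape[-1]:
        mel_codes[b, actual_end:] = stop_mel_token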
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
if second_inputs is not None:
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
emb = torch.cat([speech_conditioning_inputs,
first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True,
output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
# The first logit is tied to the speech_conditioning_input
enc = gpt_out.last_hidden_state[:, 1:]
enc = self.final_norm(enc)
if return_latent:
@ -364,11 +386,11 @@ class UnifiedVoice(nn.Module):
first_logits = enc[:, :first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0,2,1)
first_logits = first_logits.permute(0, 2, 1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1]:]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0,2,1)
second_logits = second_logits.permute(0, 2, 1)
return first_logits, second_logits
else:
return first_logits
@ -394,24 +416,31 @@ class UnifiedVoice(nn.Module):
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
text_inputs = F.pad(
text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
mel_codes = F.pad(mel_codes[:, :max_mel_len],
(0, 1), value=self.stop_mel_token)
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(
speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(
text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
mel_codes, self.start_mel_token, self.stop_mel_token)
if raw_mels is not None:
mel_inp = F.pad(raw_mels, (0, 8))
else:
@ -421,13 +450,17 @@ class UnifiedVoice(nn.Module):
sub = -2 if self.tortoise_compat else -1
if text_first:
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
text_logits, mel_logits = self.get_logits(
conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
if return_latent:
return mel_logits[:, :sub] # Despite the name, these are not logits.
# Despite the name, these are not logits.
return mel_logits[:, :sub]
else:
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
mel_logits, text_logits = self.get_logits(
conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
if return_latent:
return text_logits[:, :sub] # Despite the name, these are not logits
# Despite the name, these are not logits
return text_logits[:, :sub]
if return_attentions:
return mel_logits
@ -443,18 +476,23 @@ class UnifiedVoice(nn.Module):
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
text_inputs = F.pad(
text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(
speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(
text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
text_logits = self.get_logits(conds, text_emb, self.text_head)
loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean()
@ -468,26 +506,31 @@ class UnifiedVoice(nn.Module):
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
mel_codes = F.pad(mel_codes[:, :max_mel_len],
(0, 1), value=self.stop_mel_token)
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(
speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
mel_codes, self.start_mel_token, self.stop_mel_token)
if raw_mels is not None:
mel_inp = F.pad(raw_mels, (0, 4))
else:
mel_inp = mel_codes
mel_emb = self.mel_embedding(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
mel_emb = mel_emb + \
self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_mel.mean()
@ -507,17 +550,22 @@ class UnifiedVoice(nn.Module):
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.inference_model = GPT2InferenceModel(
gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(
text_inputs) + self.text_pos_embedding(text_inputs)
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(
speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
@ -525,8 +573,9 @@ class UnifiedVoice(nn.Module):
emb = torch.cat([conds, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.start_mel_token
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],),
fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:, -1] = self.start_mel_token
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
max_length=seq_length, output_attentions=return_attentions, return_dict_in_generate=True, **hf_generate_kwargs)
@ -535,15 +584,16 @@ class UnifiedVoice(nn.Module):
else:
return gen.sequences[:, fake_inputs.shape[1]:]
# Turns the (utterly insane) output of HF.generate() into a far more sane output:
# [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence
def make_hf_generate_attentions_sane(self, attentions):
layers = [[] for _ in range(len(attentions[0]))]
full_attention_size = attentions[-1][0].shape[-1]
for i, gen in enumerate(attentions):
for j, lyr in enumerate(gen):
layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
layers[j].append(
F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
catted = []
for lyr in layers:
catted.append(torch.cat(lyr, dim=2))
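A minimal standalone illustration (not part of the commit) of the reshaping described above, using dummy tensors: HF generate() returns attentions indexed as (generation step, layer), with query length 1 after the prompt step; padding every step to the final key length and concatenating along the query axis gives one square (B, H, S, S) tensor per layer. Exact shapes depend on the HF model and version.

# Editor's sketch with dummy tensors, not repository code.
import torch
import torch.nn.functional as F

B, H, layers, prompt, new_tokens = 1, 2, 3, 4, 3
# attentions[step][layer] has shape (B, H, q, k): q == prompt on step 0, then 1.
attentions = [
    tuple(torch.rand(B, H, prompt if step == 0 else 1, prompt + step) for _ in range(layers))
    for step in range(new_tokens)
]

per_layer = [[] for _ in range(layers)]
full_attention_size = attentions[-1][0].shape[-1]
for step_attn in attentions:
    for j, lyr in enumerate(step_attn):
        per_layer[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
sane = [torch.cat(lyr, dim=2) for lyr in per_layer]
print(sane[0].shape)  # torch.Size([1, 2, 6, 6]) -- (B, H, S, S) for each layer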
@ -562,18 +612,21 @@ class UnifiedVoice(nn.Module):
for l, layer in enumerate(attentions):
dec_context = layer[:, :, num_context:, :]
# Mask out everything that isn't text (including the start token, which gets a LOT of attention)
dec_context[:,:,:,:text_padding+1] = 0
dec_context[:,:,:,num_context:] = 0
dec_context[:, :, :, :text_padding+1] = 0
dec_context[:, :, :, num_context:] = 0
for h in range(dec_context.shape[1]):
dec_context_indices = torch.argmax(dec_context[0,h], dim=-1)
dec_context_indices = torch.argmax(dec_context[0, h], dim=-1)
print(f'layer_{l};head_{h}: ' + str(dec_context_indices))
for t, att_tok in enumerate(attentions):
combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
combined_attention_weights = torch.zeros(
(codes.shape[0], num_text), device=codes.device)
for lyr in att_tok:
token_to_text_attentions = lyr[:, :, -1, text_padding:(text_padding + num_text)].sum(dim=1)
token_to_text_attentions = lyr[:, :, -1,
text_padding:(text_padding + num_text)].sum(dim=1)
combined_attention_weights = combined_attention_weights + token_to_text_attentions
break
most_attended_text_token = combined_attention_weights.argmax(dim=-1)
most_attended_text_token = combined_attention_weights.argmax(
dim=-1)
results[:, t] = most_attended_text_token
eos_token_mask = (codes != self.stop_mel_token)
return results * eos_token_mask
@ -585,10 +638,11 @@ def register_unified_voice2(opt_net, opt):
if __name__ == '__main__':
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True,
max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True)
l = gpt(torch.randn(2, 3, 80, 800),
torch.randint(high=256, size=(2,120)),
torch.randint(high=256, size=(2, 120)),
torch.tensor([32, 120]),
torch.randint(high=8192, size=(2,250)),
torch.tensor([250*256,195*256]))
#gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
torch.randint(high=8192, size=(2, 250)),
torch.tensor([250*256, 195*256]))
# gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))

@ -1,24 +1,26 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_intermediary as ml
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from transformers.utils.model_parallel_utils import (assert_device_map,
get_device_map)
from models.arch_util import AttentionBlock
from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
from trainer.networks import register_model
from utils.util import opt_get
import dlas.torch_intermediary as ml
from dlas.models.arch_util import AttentionBlock
from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer
from dlas.models.lucidrains.x_transformers import (RotaryEmbedding,
apply_rotary_pos_emb)
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
@ -48,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
get_device_map(len(self.transformer.h),
range(torch.cuda.device_count()))
if device_map is None
else device_map
)
@ -120,7 +123,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
):
assert self.cached_prior_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
# Training not supported by this inference model.
assert labels is None
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Create embedding
@ -128,15 +132,18 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
if input_ids.shape[1] != 1:
posterior_inputs = input_ids[:, prior_len:]
posterior_emb = self.embeddings(posterior_inputs)
posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb)
posterior_emb = posterior_emb + \
self.posterior_pos_embedding(posterior_emb)
if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]:
prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
prior_emb = self.cached_prior_emb.repeat_interleave(
posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
else:
prior_emb = self.cached_prior_emb
emb = torch.cat([prior_emb, posterior_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device)
emb = emb + self.posterior_pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - prior_len, attention_mask.device)
transformer_outputs = self.transformer(
inputs_embeds=emb,
@ -181,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past)
for layer_past in past
)
@ -198,7 +206,8 @@ class ConditioningEncoder(nn.Module):
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(AttentionBlock(embedding_dim,
num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
@ -218,23 +227,27 @@ class MelEncoder(nn.Module):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
nn.Sequential(
*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2,
kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//16, channels//2),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
nn.Sequential(
*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels,
kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
nn.Sequential(
*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
self.reduction = 4
def forward(self, x):
for e in self.encoder:
x = e(x)
return x.permute(0,2,1)
return x.permute(0, 2, 1)
class UnifiedVoice(nn.Module):
@ -260,7 +273,8 @@ class UnifiedVoice(nn.Module):
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
self.start_text_token = number_text_tokens * \
types if start_text_token is None else start_text_token
self.stop_text_token = 0
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
@ -268,17 +282,21 @@ class UnifiedVoice(nn.Module):
self.layers = layers
self.heads = heads
self.max_conditioning_inputs = max_conditioning_inputs
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
self.max_mel_tokens = -1 if max_mel_tokens == - \
1 else max_mel_tokens+2+self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
self.model_dim = model_dim
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.conditioning_encoder = ConditioningEncoder(
80, model_dim, num_attn_heads=heads)
# nn.Embedding
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
self.text_embedding = ml.Embedding(
self.number_text_tokens*types+1, model_dim)
# nn.Embedding
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
build_hf_gpt_transformer(
layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
@ -306,8 +324,8 @@ class UnifiedVoice(nn.Module):
}
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def set_mel_padding(self, mel_input_tokens, wav_lengths):
@ -319,45 +337,48 @@ class UnifiedVoice(nn.Module):
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
mel_lengths = wav_lengths // self.mel_length_compression
for b in range(len(mel_lengths)):
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
# Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
actual_end = mel_lengths[b] + 1
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
def get_logits(self, speech_conditioning_inputs, text_inputs, text_head, mel_inputs, mel_head, aligned_head, return_latent=False):
if mel_inputs is not None:
emb = torch.cat([speech_conditioning_inputs, text_inputs, mel_inputs], dim=1)
emb = torch.cat([speech_conditioning_inputs,
text_inputs, mel_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_inputs, text_inputs], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
# The first logit is tied to the speech_conditioning_input
enc = gpt_out.last_hidden_state[:, 1:]
enc = self.final_norm(enc)
if return_latent:
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1] + text_inputs.shape[1]], enc[:, -mel_inputs.shape[1]:]
text_logits = enc[:, :text_inputs.shape[1]]
text_logits = text_head(text_logits).permute(0,2,1)
text_logits = text_head(text_logits).permute(0, 2, 1)
mel_logits = enc[:, -mel_inputs.shape[1]:]
aligned_logits = aligned_head(mel_logits).permute(0,2,1)
mel_logits = mel_head(mel_logits).permute(0,2,1)
aligned_logits = aligned_head(mel_logits).permute(0, 2, 1)
mel_logits = mel_head(mel_logits).permute(0, 2, 1)
return text_logits, mel_logits, aligned_logits
def get_conditioning_latent(self, speech_conditioning_input):
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(
speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1).unsqueeze(1)
return conds
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, aligned_codes, types=None, return_latent=False):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@ -376,19 +397,22 @@ class UnifiedVoice(nn.Module):
if types is not None:
text_inputs = text_inputs * (1+types).unsqueeze(-1)
conds = self.get_conditioning_latent(speech_conditioning_input)
ac_expansion_factor = mel_codes.shape[-1] // aligned_codes.shape[-1]  # integer factor; Tensor.repeat() requires ints
aligned_codes = aligned_codes.repeat(1, ac_expansion_factor)
_, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0)
_, aligned_targets = self.build_aligned_inputs_and_targets(
aligned_codes, 0, 0)
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(
text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
mel_codes, self.start_mel_token, self.stop_mel_token)
mel_inp = mel_codes
mel_emb = self.mel_embedding(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
@ -396,7 +420,8 @@ class UnifiedVoice(nn.Module):
text_logits, mel_logits, aligned_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head,
self.aligned_head, return_latent=return_latent)
if return_latent:
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
# Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
return mel_logits[:, :-2]
loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
@ -418,25 +443,31 @@ class UnifiedVoice(nn.Module):
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.inference_model = GPT2InferenceModel(
gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(
text_inputs) + self.text_pos_embedding(text_inputs)
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(
speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1).unsqueeze(1)
emb = torch.cat([conds, text_emb], dim=1)
self.inference_model.store_prior_emb(emb)
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.start_mel_token
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],),
fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:, -1] = self.start_mel_token
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs)
@ -449,11 +480,12 @@ def register_unified_voice3(opt_net, opt):
if __name__ == '__main__':
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
mel = torch.randint(high=8192, size=(2,250))
ac = torch.randint(high=256, size=(2,250*1024//443))
gpt = UnifiedVoice(model_dim=256, heads=4,
max_conditioning_inputs=4, types=2)
mel = torch.randint(high=8192, size=(2, 250))
ac = torch.randint(high=256, size=(2, 250*1024//443))
l = gpt(torch.randn(2, 3, 80, 800),
torch.randint(high=256, size=(2,120)),
torch.randint(high=256, size=(2, 120)),
torch.tensor([32, 120]),
mel, torch.tensor([250*256,195*256]), ac,
mel, torch.tensor([250*256, 195*256]), ac,
types=torch.tensor([0, 1]))

@ -4,20 +4,23 @@ import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from transformers.utils.model_parallel_utils import (assert_device_map,
get_device_map)
from models.arch_util import AttentionBlock
from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
from trainer.networks import register_model
from utils.util import opt_get
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.models.arch_util import AttentionBlock
from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer
from dlas.models.lucidrains.x_transformers import (RotaryEmbedding,
apply_rotary_pos_emb)
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
@ -47,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
get_device_map(len(self.transformer.h),
range(torch.cuda.device_count()))
if device_map is None
else device_map
)
@ -119,7 +123,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
):
assert self.cached_prior_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
# Training not supported by this inference model.
assert labels is None
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Create embedding
@ -127,15 +132,18 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
if input_ids.shape[1] != 1:
posterior_inputs = input_ids[:, prior_len:]
posterior_emb = self.embeddings(posterior_inputs)
posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb)
posterior_emb = posterior_emb + \
self.posterior_pos_embedding(posterior_emb)
if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]:
prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
prior_emb = self.cached_prior_emb.repeat_interleave(
posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0)
else:
prior_emb = self.cached_prior_emb
emb = torch.cat([prior_emb, posterior_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device)
emb = emb + self.posterior_pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - prior_len, attention_mask.device)
transformer_outputs = self.transformer(
inputs_embeds=emb,
@ -180,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past)
for layer_past in past
)
@ -197,7 +206,8 @@ class ConditioningEncoder(nn.Module):
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
attn.append(AttentionBlock(embedding_dim,
num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
@ -217,23 +227,27 @@ class MelEncoder(nn.Module):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
nn.Sequential(
*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//4, channels//2,
kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//16, channels//2),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
nn.Sequential(
*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels//2, channels,
kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels//8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
nn.Sequential(
*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
self.reduction = 4
def forward(self, x):
for e in self.encoder:
x = e(x)
return x.permute(0,2,1)
return x.permute(0, 2, 1)
class UnifiedVoice(nn.Module):
@ -243,7 +257,8 @@ class UnifiedVoice(nn.Module):
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
self.start_text_token = number_text_tokens * \
types if start_text_token is None else start_text_token
self.stop_text_token = 0
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
@ -251,17 +266,21 @@ class UnifiedVoice(nn.Module):
self.layers = layers
self.heads = heads
self.max_conditioning_inputs = max_conditioning_inputs
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs
self.max_mel_tokens = -1 if max_mel_tokens == - \
1 else max_mel_tokens+2+self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2
self.model_dim = model_dim
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.conditioning_encoder = ConditioningEncoder(
80, model_dim, num_attn_heads=heads)
# nn.Embedding
self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim)
self.text_embedding = ml.Embedding(
self.number_text_tokens*types+1, model_dim)
# nn.Embedding
self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
build_hf_gpt_transformer(
layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing)
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1)
@ -289,8 +308,8 @@ class UnifiedVoice(nn.Module):
}
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def set_mel_padding(self, mel_input_tokens, wav_lengths):
@ -302,17 +321,20 @@ class UnifiedVoice(nn.Module):
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
mel_lengths = wav_lengths // self.mel_length_compression
for b in range(len(mel_lengths)):
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
# Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
actual_end = mel_lengths[b] + 1
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
def get_logits(self, speech_conditioning_inputs, first_inputs, second_inputs, return_latent=False):
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
emb = torch.cat([speech_conditioning_inputs,
first_inputs, second_inputs], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True)
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
# The first logit is tied to the speech_conditioning_input
enc = gpt_out.last_hidden_state[:, 1:]
enc = self.final_norm(enc)
if return_latent:
@ -320,29 +342,29 @@ class UnifiedVoice(nn.Module):
text_logits = enc[:, :first_inputs.shape[1]]
text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1)
text_logits = text_logits.permute(0, 2, 1)
mel_logits = enc[:, -second_inputs.shape[1]:]
mel_logits = self.mel_head(mel_logits)
mel_logits = mel_logits.permute(0,2,1)
mel_logits = mel_logits.permute(0, 2, 1)
alignment_logits = enc[:, -second_inputs.shape[1]:]
alignment_logits = self.alignment_head(alignment_logits)
alignment_logits = alignment_logits.permute(0,2,1)
alignment_logits = alignment_logits.permute(0, 2, 1)
return text_logits, mel_logits, alignment_logits
def get_conditioning_latent(self, speech_conditioning_input):
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(
speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1).unsqueeze(1)
return conds
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, ctc_codes, wav_lengths, types=None, return_latent=False):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@ -368,24 +390,30 @@ class UnifiedVoice(nn.Module):
ctc_codes[b][j] = last_code
else:
last_code = ctc_codes[b][j]
alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(), size=(mel_codes.shape[-1],), mode='nearest').long().squeeze()
alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(), size=(
mel_codes.shape[-1],), mode='nearest').long().squeeze()
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
alignment_targets = F.pad(alignment_targets, (0,2), value=0)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
alignment_targets = F.pad(alignment_targets, (0, 2), value=0)
conds = self.get_conditioning_latent(speech_conditioning_input)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(
text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
mel_codes, self.start_mel_token, self.stop_mel_token)
mel_inp = mel_codes
mel_emb = self.mel_embedding(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
text_logits, mel_logits, alignment_logits = self.get_logits(conds, text_emb, mel_emb, return_latent=return_latent)
text_logits, mel_logits, alignment_logits = self.get_logits(
conds, text_emb, mel_emb, return_latent=return_latent)
if return_latent:
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
# Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
return mel_logits[:, :-2]
loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
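A minimal standalone sketch (not part of the commit) of the alignment-target construction a few lines above: CTC blanks (code 0) are forward-filled with the previous code, then the sequence is resampled to the MEL length with nearest-neighbour interpolation. The code values and lengths here are made up.

# Editor's sketch, not repository code; values are illustrative.
import torch
import torch.nn.functional as F

ctc_codes = torch.tensor([[3, 0, 0, 7, 0, 2]])  # 0 is the CTC blank
last_code = 0
for j in range(ctc_codes.shape[1]):
    if ctc_codes[0, j] == 0:
        ctc_codes[0, j] = last_code
    else:
        last_code = ctc_codes[0, j]
# ctc_codes is now tensor([[3, 3, 3, 7, 7, 2]])

mel_len = 12
alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(),
                                  size=(mel_len,), mode='nearest').long().squeeze(1)
print(alignment_targets.shape)  # torch.Size([1, 12])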
@ -407,25 +435,31 @@ class UnifiedVoice(nn.Module):
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.inference_model = GPT2InferenceModel(
gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(
text_inputs) + self.text_pos_embedding(text_inputs)
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(
speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(
speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1).unsqueeze(1)
emb = torch.cat([conds, text_emb], dim=1)
self.inference_model.store_prior_emb(emb)
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.start_mel_token
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],),
fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:, -1] = self.start_mel_token
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs)
@ -438,10 +472,11 @@ def register_unified_voice4(opt_net, opt):
if __name__ == '__main__':
gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2)
gpt = UnifiedVoice(model_dim=256, heads=4,
max_conditioning_inputs=4, types=2)
l = gpt(torch.randn(2, 3, 80, 800),
torch.randint(high=256, size=(2,120)),
torch.randint(high=256, size=(2, 120)),
torch.tensor([32, 120]),
torch.randint(high=8192, size=(2,250)),
torch.tensor([250*256,195*256]),
torch.randint(high=8192, size=(2, 250)),
torch.tensor([250*256, 195*256]),
types=torch.tensor([0, 1]))

@ -1,14 +1,15 @@
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.injectors.spec_augment import spec_augment
from trainer.networks import register_model
from utils.util import opt_get
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
from dlas.trainer.injectors.spec_augment import spec_augment
from dlas.trainer.networks import register_model
from dlas.utils.util import opt_get
def exists(val):
@ -17,7 +18,7 @@ def exists(val):
def masked_mean(t, mask, dim=1):
t = t.masked_fill(~mask[:, :, None], 0.)
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
class VoiceCLIP(nn.Module):
@ -36,7 +37,8 @@ class VoiceCLIP(nn.Module):
super().__init__()
self.encoder = AudioMiniEncoder(80, encoder_output)
if pretrained_encoder_dict_path is not None:
self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
self.encoder.load_state_dict(
torch.load(pretrained_encoder_dict_path))
self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False)
self.temperature = nn.Parameter(torch.tensor(1.))
self.mel_compression_ratio = mel_compression_ratio
@ -47,7 +49,8 @@ class VoiceCLIP(nn.Module):
speech_lengths,
return_loss=True
):
half_length = min(speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2
half_length = min(
speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2
# Extract two speech MELs from the same clip, apply some random noise to them and also apply specaugment to them.
first_half = speech_mels[:, :, :half_length]
@ -81,7 +84,8 @@ class VoiceCLIP(nn.Module):
second_emb = self.encoder(second_half)
second_latents = self.to_latent(second_emb)
first_latents, second_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (first_latents, second_latents))
first_latents, second_latents = map(lambda t: F.normalize(
t, p=2, dim=-1), (first_latents, second_latents))
temp = self.temperature.exp()
@ -90,8 +94,10 @@ class VoiceCLIP(nn.Module):
return sim
sim = einsum('i d, j d -> i j', first_latents, second_latents) * temp
labels = torch.arange(first_latents.shape[0], device=first_latents.device)
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
labels = torch.arange(
first_latents.shape[0], device=first_latents.device)
loss = (F.cross_entropy(sim, labels) +
F.cross_entropy(sim.t(), labels)) / 2
return loss
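For reference, the symmetric CLIP-style objective computed above, in isolation (a minimal sketch, not part of the commit; the batch size and latent width are arbitrary):

# Editor's sketch, not repository code; sizes are arbitrary.
import torch
import torch.nn.functional as F

first_latents = F.normalize(torch.randn(4, 512), p=2, dim=-1)
second_latents = F.normalize(torch.randn(4, 512), p=2, dim=-1)
temp = torch.tensor(1.0).exp()

sim = torch.einsum('i d, j d -> i j', first_latents, second_latents) * temp  # (4, 4)
labels = torch.arange(first_latents.shape[0])        # matching clips sit on the diagonal
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2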
def inference(self, speech_mels):
@ -111,6 +117,6 @@ def register_voice_to_voice_clip(opt_net, opt):
if __name__ == '__main__':
clip = VoiceCLIP()
for k in range(1000):
clip(torch.randn((2,80,156)),
torch.randint(130*1024,156*1024,(2,)),
return_loss=True)
clip(torch.randn((2, 80, 156)),
torch.randint(130*1024, 156*1024, (2,)),
return_loss=True)

@ -3,11 +3,11 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper
from x_transformers import ContinuousTransformerWrapper, Decoder, Encoder
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model
import torch_intermediary as ml
import dlas.torch_intermediary as ml
from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder
from dlas.trainer.networks import register_model
class CheckpointedLayer(nn.Module):
@ -15,13 +15,15 @@ class CheckpointedLayer(nn.Module):
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
checkpoint for all other args.
"""
def __init__(self, wrap):
super().__init__()
self.wrap = wrap
def forward(self, x, *args, **kwargs):
for k, v in kwargs.items():
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
# This would screw up checkpointing.
assert not (isinstance(v, torch.Tensor) and v.requires_grad)
partial = functools.partial(self.wrap, **kwargs)
return torch.utils.checkpoint.checkpoint(partial, x, *args)
@ -31,20 +33,22 @@ class CheckpointedXTransformer(nn.Module):
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
to channels-last that XTransformer expects.
"""
def __init__(self, **xtransformer_kwargs):
super().__init__()
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
for i in range(len(self.transformer.attn_layers.layers)):
n, b, r = self.transformer.attn_layers.layers[i]
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
self.transformer.attn_layers.layers[i] = nn.ModuleList(
[n, CheckpointedLayer(b), r])
def forward(self, x, **kwargs):
return self.transformer(x, **kwargs)
class Wav2VecMatcher(nn.Module):
W2V_COMPRESSION=320
W2V_COMPRESSION = 320
def __init__(self,
model_dim,
@ -56,41 +60,42 @@ class Wav2VecMatcher(nn.Module):
WAV2VEC_CHANNELS = 1024
self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
# nn.Embedding
self.text_embedding = ml.Embedding(num_text_tokens, model_dim)
self.encoder = CheckpointedXTransformer(
max_seq_len=-1,
use_pos_emb=False,
attn_layers=Encoder(
dim=model_dim,
depth=encoder_depth,
heads=model_dim//64,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_emb_dim=True,
)
max_seq_len=-1,
use_pos_emb=False,
attn_layers=Encoder(
dim=model_dim,
depth=encoder_depth,
heads=model_dim//64,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_emb_dim=True,
)
self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim))
self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim))
)
self.decoder_start_embedding = nn.Parameter(
torch.randn(1, 1, model_dim))
self.decoder_stop_embedding = nn.Parameter(torch.randn(1, model_dim))
self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim)
self.decoder = CheckpointedXTransformer(
max_seq_len=-1, # Should be unused
use_pos_emb=False,
attn_layers=Decoder(
dim=model_dim,
depth=decoder_depth,
heads=model_dim//64,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
cross_attend=True,
)
max_seq_len=-1, # Should be unused
use_pos_emb=False,
attn_layers=Decoder(
dim=model_dim,
depth=decoder_depth,
heads=model_dim//64,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
cross_attend=True,
)
)
def get_grad_norm_parameter_groups(self):
@ -111,28 +116,32 @@ class Wav2VecMatcher(nn.Module):
enc_inputs = torch.cat([cond_emb.unsqueeze(1), text_emb], dim=1)
dec_context = self.encoder(enc_inputs)
w2v_values = self.w2v_value_encoder(w2v_logits)
dec_inputs = torch.cat([self.decoder_start_embedding.repeat(w2v_values.shape[0],1,1), w2v_values], dim=1)
dec_inputs = torch.cat([self.decoder_start_embedding.repeat(
w2v_values.shape[0], 1, 1), w2v_values], dim=1)
dec_out = self.decoder(dec_inputs, context=dec_context)[:, :-1]
w2v_queries = self.w2v_query_encoder(w2v_logits)
# Compute losses, A CLIP-like dot product matcher and a mechanism to force pad prediction.
b,l,c = dec_out.shape
b, l, c = dec_out.shape
keys_uncompressed = dec_out.reshape(b*l, c)
queries_uncompressed = w2v_queries.reshape(b*l, c)
dot = torch.einsum("i c, j c -> i j", keys_uncompressed, queries_uncompressed)
dot = torch.einsum("i c, j c -> i j",
keys_uncompressed, queries_uncompressed)
labels = torch.arange(0, b*l, 1, device=dot.device)
ce_loss1 = F.cross_entropy(dot, labels, reduction="none")
ce_loss2 = F.cross_entropy(dot.t(), labels, reduction="none")
mse_pad_loss = F.mse_loss(keys_uncompressed, self.decoder_stop_embedding.repeat(b*l,1), reduction="none").sum(dim=-1)
mse_pad_loss = F.mse_loss(keys_uncompressed, self.decoder_stop_embedding.repeat(
b*l, 1), reduction="none").sum(dim=-1)
# Create a mask based on w2v_lengths that will be used to ensure the encodings of padding tokens are not considered in the cross entropy loss
loss_mask = torch.ones((b,l), device=ce_loss1.device)
loss_mask = torch.ones((b, l), device=ce_loss1.device)
w2v_lengths = clip_lengths // self.W2V_COMPRESSION
for i in range(b):
loss_mask[i, w2v_lengths[i]:] = 0
loss_mask_collapsed = loss_mask.reshape(b*l)
ce_loss = (ce_loss1 * loss_mask_collapsed + ce_loss2 * loss_mask_collapsed).mean()
ce_loss = (ce_loss1 * loss_mask_collapsed +
ce_loss2 * loss_mask_collapsed).mean()
mse_loss = (mse_pad_loss * (loss_mask_collapsed == 0)).mean()
return ce_loss, mse_loss
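A minimal standalone sketch (not part of the commit) of the length masking above: per-position losses are computed with reduction='none' and positions beyond each clip's wav2vec length are zeroed before averaging. b, l, c and the lengths below are made up.

# Editor's sketch, not repository code; sizes and lengths are illustrative.
import torch
import torch.nn.functional as F

b, l, c = 2, 6, 8
keys_uncompressed = torch.randn(b, l, c).reshape(b * l, c)
queries_uncompressed = torch.randn(b, l, c).reshape(b * l, c)

dot = torch.einsum('i c, j c -> i j', keys_uncompressed, queries_uncompressed)
labels = torch.arange(b * l)
ce = F.cross_entropy(dot, labels, reduction='none')   # one loss per position, (b*l,)

w2v_lengths = torch.tensor([4, 6])                    # valid positions per batch item
loss_mask = torch.ones(b, l)
for i in range(b):
    loss_mask[i, w2v_lengths[i]:] = 0
ce_loss = (ce * loss_mask.reshape(b * l)).mean()      # padded positions contribute zero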
@ -151,15 +160,18 @@ class Wav2VecMatcher(nn.Module):
dec_out = self.decoder(dec_inputs, context=dec_context)
# Check if that was EOS.
l2 = F.mse_loss(dec_out[:,-1], self.decoder_stop_embedding)
l2 = F.mse_loss(dec_out[:, -1], self.decoder_stop_embedding)
if l2 < .1: # TODO: fix threshold.
break
# Find a matching w2v logit from the given iterable.
matching_logit_index = self.find_matching_w2v_logit(dec_out[:,-1], w2v_logit_iterable)
matching_logit_index = self.find_matching_w2v_logit(
dec_out[:, -1], w2v_logit_iterable)
matching_logit = w2v_logit_iterable[matching_logit_index]
dec_inputs = torch.cat([dec_inputs, self.w2v_value_encoder(matching_logit).unsqueeze(1)], dim=1)
produced_audio = torch.cat([produced_audio, audio_clip_iterable[matching_logit_index]], dim=-1)
dec_inputs = torch.cat(
[dec_inputs, self.w2v_value_encoder(matching_logit).unsqueeze(1)], dim=1)
produced_audio = torch.cat(
[produced_audio, audio_clip_iterable[matching_logit_index]], dim=-1)
return produced_audio
@ -170,9 +182,9 @@ def register_w2v_matcher(opt_net, opt):
if __name__ == '__main__':
model = Wav2VecMatcher(512, 8, 8)
toks = torch.randint(0, 100, (4,100))
tok_lens = torch.tensor([50,60,70,80])
cond = torch.randn(4,1,44000)
logits = torch.randn(4,120,1024)
logit_lens = torch.tensor([60,70,80,90])
model(toks, cond, logits, tok_lens, logit_lens)
toks = torch.randint(0, 100, (4, 100))
tok_lens = torch.tensor([50, 60, 70, 80])
cond = torch.randn(4, 1, 44000)
logits = torch.randn(4, 120, 1024)
logit_lens = torch.tensor([60, 70, 80, 90])
model(toks, cond, logits, tok_lens, logit_lens)

@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from omegaconf import OmegaConf
from models.audio.vocoders.univnet.lvcnet import LVCBlock
from trainer.networks import register_model
from dlas.models.audio.vocoders.univnet.lvcnet import LVCBlock
from dlas.trainer.networks import register_model
MAX_WAV_VALUE = 32768.0
@ -11,7 +11,7 @@ MAX_WAV_VALUE = 32768.0
class UnivNetGenerator(nn.Module):
"""UnivNet Generator"""
def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
def __init__(self, noise_dim=64, channel_size=32, dilations=[1, 3, 9, 27], strides=[8, 8, 4], lReLU_slope=.2, kpnet_conv_size=3,
# Below are MEL configurations options that this generator requires.
hop_length=256, n_mel_channels=100):
super(UnivNetGenerator, self).__init__()
@ -38,11 +38,13 @@ class UnivNetGenerator(nn.Module):
)
self.conv_pre = \
nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
nn.utils.weight_norm(
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
nn.utils.weight_norm(
nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
nn.Tanh(),
)
@ -84,11 +86,13 @@ class UnivNetGenerator(nn.Module):
def inference(self, c, z=None):
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
zero = torch.full(
(c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
mel = torch.cat((c, zero), dim=2)
if z is None:
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
z = torch.randn(c.shape[0], self.noise_dim,
mel.size(2)).to(mel.device)
audio = self.forward(mel, z)
audio = audio[:, :, :-(self.hop_length * 10)]
@ -112,5 +116,6 @@ if __name__ == '__main__':
print(y.shape)
assert y.shape == torch.Size([3, 1, 2560])
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
pytorch_total_params = sum(p.numel()
for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)
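A quick usage sketch for the generator above, under the defaults visible in this hunk (100 mel channels, hop length 256). The import path is an assumption; inference() pads the mel with 10 frames of -11.5129 (approximately log 1e-5, i.e. near-silence in a log-mel scale) and trims the corresponding hop_length * 10 samples from the waveform:

import torch
# module path assumed from the lvcnet import above; adjust to the actual location
from dlas.models.audio.vocoders.univnet.generator import UnivNetGenerator

model = UnivNetGenerator().eval()      # defaults: noise_dim=64, n_mel_channels=100, hop_length=256
mel = torch.randn(1, 100, 80)          # (batch, mel_channels, frames)
with torch.no_grad():
    audio = model.inference(mel)       # z is sampled internally when not supplied
print(audio.shape)                     # expected (1, 1, 80 * 256)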

@ -35,12 +35,15 @@ class KernelPredictor(torch.nn.Module):
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_kernel_channels = conv_in_channels * \
conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
getattr(nn, kpnet_nonlinear_activation)(
**kpnet_nonlinear_activation_params),
)
self.residual_convs = nn.ModuleList()
@ -52,11 +55,13 @@ class KernelPredictor(torch.nn.Module):
nn.utils.weight_norm(
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
getattr(nn, kpnet_nonlinear_activation)(
**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
getattr(nn, kpnet_nonlinear_activation)(
**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.weight_norm(
@ -170,7 +175,8 @@ class LVCBlock(torch.nn.Module):
for i, conv in enumerate(self.conv_blocks):
output = conv(x) # (B, c_g, stride * L')
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
# (B, 2 * c_g, c_g, kernel_size, cond_length)
k = kernels[:, i, :, :, :, :]
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
output = self.location_variable_convolution(output, k, b,
@ -194,23 +200,29 @@ class LVCBlock(torch.nn.Module):
'''
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
assert in_length == (
kernel_length * hop_size), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
# (batch, in_channels, in_length + 2*padding)
x = F.pad(x, (padding, padding), 'constant', 0)
# (batch, in_channels, kernel_length, hop_size + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size)
if hop_size < dilation:
x = F.pad(x, (0, dilation), 'constant', 0)
x = x.unfold(3, dilation,
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
# (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.transpose(3, 4)
# (batch, in_channels, kernel_length, dilation, _, kernel_size)
x = x.unfold(4, kernel_size, 1)
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(
memory_format=torch.channels_last_3d)
o = o + bias
o = o.contiguous().view(batch, out_channels, -1)
@ -220,4 +232,4 @@ class LVCBlock(torch.nn.Module):
self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1])
for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1])
nn.utils.remove_weight_norm(block[1])
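To make the unfold/einsum bookkeeping in location_variable_convolution easier to follow, a standalone shape walk-through with assumed toy sizes (the channels_last_3d memory-format hints from the hunk are omitted; only the shapes are the point):

import torch
import torch.nn.functional as F

# Assumed toy sizes: batch=2, in_channels=8, out_channels=16,
# kernel_size=3, hop_size=64, kernel_length=5, dilation=1
b, c_in, c_out, ks, hop, klen, dil = 2, 8, 16, 3, 64, 5, 1
x = torch.randn(b, c_in, klen * hop)        # (2, 8, 320): in_length == kernel_length * hop_size
kernel = torch.randn(b, c_in, c_out, ks, klen)
bias = torch.randn(b, c_out, klen)

pad = dil * (ks - 1) // 2                   # 1
x = F.pad(x, (pad, pad))                    # (2, 8, 322)
x = x.unfold(2, hop + 2 * pad, hop)         # (2, 8, 5, 66): one window per kernel position
x = x.unfold(3, dil, dil)[..., :hop]        # (2, 8, 5, 66, 1): dilation split (a no-op for dil=1)
x = x.transpose(3, 4)                       # (2, 8, 5, 1, 66)
x = x.unfold(4, ks, 1)                      # (2, 8, 5, 1, 64, 3): sliding kernel-size windows
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)   # (2, 16, 5, 64, 1)
o = o + bias.unsqueeze(-1).unsqueeze(-1)             # one bias per kernel position, broadcast
o = o.contiguous().view(b, c_out, -1)                # (2, 16, 320)
print(o.shape)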

@ -1,9 +1,10 @@
import sys
from models.audio.tts.tacotron2.stft import STFT
import torch
from dlas.models.audio.tts.tacotron2.stft import STFT
sys.path.append('tacotron2')
import torch
class Denoiser(torch.nn.Module):

@ -25,11 +25,12 @@
#
# *****************************************************************************
import copy
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch.autograd import Variable
from trainer.networks import register_model
from dlas.trainer.networks import register_model
@torch.jit.script
@ -57,7 +58,8 @@ class WaveGlowLoss(torch.nn.Module):
log_s_total = log_s_total + torch.sum(log_s)
log_det_W_total += log_det_W_list[i]
loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - \
log_s_total - log_det_W_total
return loss/(z.size(0)*z.size(1)*z.size(2))
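For reference, the loss reformatted in the hunk above is the usual Gaussian normalizing-flow negative log-likelihood (up to an additive constant): sum(z^2)/(2*sigma^2) minus the accumulated log-scales and 1x1-conv log-determinants, averaged per element. A standalone restatement with dummy tensors standing in for the network outputs:

import torch

sigma = 1.0
z = torch.randn(3, 8, 2000)                                # flow output ("latent")
log_s_list = [torch.randn(3, 4, 2000) for _ in range(6)]   # per-coupling log-scales
log_det_W_list = [torch.randn(()) for _ in range(6)]       # per-1x1-conv log-determinants

log_s_total = sum(ls.sum() for ls in log_s_list)
log_det_W_total = sum(log_det_W_list)
loss = (z * z).sum() / (2 * sigma ** 2) - log_s_total - log_det_W_total
loss = loss / (z.size(0) * z.size(1) * z.size(2))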
@ -67,6 +69,7 @@ class Invertible1x1Conv(torch.nn.Module):
of its weight matrix. If reverse=True it does convolution with
inverse
"""
def __init__(self, c):
super(Invertible1x1Conv, self).__init__()
self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
@ -77,7 +80,7 @@ class Invertible1x1Conv(torch.nn.Module):
# Ensure determinant is 1.0 not -1.0
if torch.det(W) < 0:
W[:,0] = -1*W[:,0]
W[:, 0] = -1*W[:, 0]
W = W.view(c, c, 1)
self.conv.weight.data = W
@ -110,11 +113,12 @@ class WN(torch.nn.Module):
from WaveNet is the convolutions need not be causal. There is also no dilation
size reset. The dilation only doubles on each layer
"""
def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
kernel_size):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
assert(n_channels % 2 == 0)
assert (kernel_size % 2 == 1)
assert (n_channels % 2 == 0)
self.n_layers = n_layers
self.n_channels = n_channels
self.in_layers = torch.nn.ModuleList()
@ -142,14 +146,14 @@ class WN(torch.nn.Module):
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2*n_channels
else:
res_skip_channels = n_channels
res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
res_skip_layer = torch.nn.utils.weight_norm(
res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)
def forward(self, forward_input):
@ -164,13 +168,13 @@ class WN(torch.nn.Module):
spect_offset = i*2*self.n_channels
acts = fused_add_tanh_sigmoid_multiply(
self.in_layers[i](audio),
spect[:,spect_offset:spect_offset+2*self.n_channels,:],
spect[:, spect_offset:spect_offset+2*self.n_channels, :],
n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
audio = audio + res_skip_acts[:,:self.n_channels,:]
output = output + res_skip_acts[:,self.n_channels:,:]
audio = audio + res_skip_acts[:, :self.n_channels, :]
output = output + res_skip_acts[:, self.n_channels:, :]
else:
output = output + res_skip_acts
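The slicing in this hunk is the standard WaveNet residual/skip split after a gated tanh-sigmoid activation. A small standalone sketch of one such layer step (toy tensors; the 1x1 res_skip conv mirrors the non-final-layer case where res_skip_channels = 2*n_channels):

import torch
import torch.nn as nn

n_channels, T = 4, 100
audio = torch.randn(2, n_channels, T)                  # running residual signal
in_act = torch.randn(2, 2 * n_channels, T)             # dilated conv output + conditioning
acts = torch.tanh(in_act[:, :n_channels]) * torch.sigmoid(in_act[:, n_channels:])
res_skip_layer = nn.Conv1d(n_channels, 2 * n_channels, 1)
res_skip = res_skip_layer(acts)
audio = audio + res_skip[:, :n_channels]               # residual half feeds the next layer
skip = res_skip[:, n_channels:]                        # skip half accumulates into the output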
@ -185,7 +189,7 @@ class WaveGlow(torch.nn.Module):
self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
n_mel_channels,
1024, stride=256)
assert(n_group % 2 == 0)
assert (n_group % 2 == 0)
self.n_flows = n_flows
self.n_group = n_group
self.n_early_every = n_early_every
@ -215,7 +219,7 @@ class WaveGlow(torch.nn.Module):
# Upsample spectrogram to size of audio
spect = self.upsample(spect)
assert(spect.size(2) >= audio.size(1))
assert (spect.size(2) >= audio.size(1))
if spect.size(2) > audio.size(1):
spect = spect[:, :, :audio.size(1)]
@ -229,15 +233,15 @@ class WaveGlow(torch.nn.Module):
for k in range(self.n_flows):
if k % self.n_early_every == 0 and k > 0:
output_audio.append(audio[:,:self.n_early_size,:])
audio = audio[:,self.n_early_size:,:]
output_audio.append(audio[:, :self.n_early_size, :])
audio = audio[:, self.n_early_size:, :]
audio, log_det_W = self.convinv[k](audio)
log_det_W_list.append(log_det_W)
n_half = int(audio.size(1)/2)
audio_0 = audio[:,:n_half,:]
audio_1 = audio[:,n_half:,:]
audio_0 = audio[:, :n_half, :]
audio_1 = audio[:, n_half:, :]
output = self.WN[k]((audio_0, spect))
log_s = output[:, n_half:, :]
@ -245,10 +249,10 @@ class WaveGlow(torch.nn.Module):
audio_1 = torch.exp(log_s)*audio_1 + b
log_s_list.append(log_s)
audio = torch.cat([audio_0, audio_1],1)
audio = torch.cat([audio_0, audio_1], 1)
output_audio.append(audio)
return torch.cat(output_audio,1), log_s_list, log_det_W_list
return torch.cat(output_audio, 1), log_s_list, log_det_W_list
def infer(self, spect, sigma=1.0):
spect = self.upsample(spect)
@ -272,26 +276,29 @@ class WaveGlow(torch.nn.Module):
for k in reversed(range(self.n_flows)):
n_half = int(audio.size(1)/2)
audio_0 = audio[:,:n_half,:]
audio_1 = audio[:,n_half:,:]
audio_0 = audio[:, :n_half, :]
audio_1 = audio[:, n_half:, :]
output = self.WN[k]((audio_0, spect))
s = output[:, n_half:, :]
b = output[:, :n_half, :]
audio_1 = (audio_1 - b)/torch.exp(s)
audio = torch.cat([audio_0, audio_1],1)
audio = torch.cat([audio_0, audio_1], 1)
audio = self.convinv[k](audio, reverse=True)
if k % self.n_early_every == 0 and k > 0:
if spect.type() == 'torch.cuda.HalfTensor':
z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
z = torch.cuda.HalfTensor(spect.size(
0), self.n_early_size, spect.size(2)).normal_()
else:
z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
audio = torch.cat((sigma*z, audio),1)
z = torch.cuda.FloatTensor(spect.size(
0), self.n_early_size, spect.size(2)).normal_()
audio = torch.cat((sigma*z, audio), 1)
audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
audio = audio.permute(0, 2, 1).contiguous().view(
audio.size(0), -1).data
return audio
@staticmethod
@ -315,4 +322,4 @@ def remove(conv_list):
@register_model
def register_nv_waveglow(opt_net, opt):
return WaveGlow(**opt_net['args'])
return WaveGlow(**opt_net['args'])
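A quick standalone check (not part of the diff) of why the affine coupling in forward() and infer() above invert each other exactly:

import torch

n_half, T = 4, 50
audio_1 = torch.randn(2, n_half, T)
log_s = torch.randn(2, n_half, T)      # stand-in for the WN output's scale half
b = torch.randn(2, n_half, T)          # stand-in for the WN output's bias half

y1 = torch.exp(log_s) * audio_1 + b             # forward direction (training)
x1 = (y1 - b) / torch.exp(log_s)                # inverse direction (inference)
print(torch.allclose(x1, audio_1, atol=1e-4))   # True up to floating point error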

@ -8,11 +8,11 @@
https://arxiv.org/abs/1512.03385v1
"""
import torch
import torch.nn as nn
import torch_intermediary as ml
from trainer.networks import register_model
import dlas.torch_intermediary as ml
from dlas.trainer.networks import register_model
class BasicBlock(nn.Module):
@ -20,53 +20,60 @@ class BasicBlock(nn.Module):
"""
#BasicBlock and BottleNeck block
#have different output size
#we use class attribute expansion
#to distinct
# BasicBlock and BottleNeck blocks
# have different output sizes;
# we use the class attribute expansion
# to distinguish between them
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
#residual function
# residual function
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
)
#shortcut
# shortcut
self.shortcut = nn.Sequential()
#the shortcut output dimension is not the same with residual function
#use 1*1 convolution to match the dimension
# when the shortcut output dimension does not match that of the residual function,
# use a 1x1 convolution to match the dimensions
if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
)
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class BottleNeck(nn.Module):
"""Residual block for resnet over 50 layers
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
nn.Conv2d(out_channels, out_channels, stride=stride,
kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
nn.Conv2d(out_channels, out_channels *
BottleNeck.expansion, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
)
@ -74,13 +81,15 @@ class BottleNeck(nn.Module):
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion,
stride=stride, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
)
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class ResNet(nn.Module):
def __init__(self, block, num_block, num_classes=100):
@ -92,8 +101,8 @@ class ResNet(nn.Module):
nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True))
#we use a different inputsize than the original paper
#so conv2_x's stride is 1
# we use a different input size than the original paper,
# so conv2_x's stride is 1
self.conv2_x = self._make_layer(block, 32, num_block[0], 1)
self.conv3_x = self._make_layer(block, 64, num_block[1], 2)
self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
@ -138,30 +147,33 @@ class ResNet(nn.Module):
return output
@register_model
def register_cifar_resnet18(opt_net, opt):
""" return a ResNet 18 object
"""
return ResNet(BasicBlock, [2, 2, 2, 2])
def resnet34():
""" return a ResNet 34 object
"""
return ResNet(BasicBlock, [3, 4, 6, 3])
def resnet50():
""" return a ResNet 50 object
"""
return ResNet(BottleNeck, [3, 4, 6, 3])
def resnet101():
""" return a ResNet 101 object
"""
return ResNet(BottleNeck, [3, 4, 23, 3])
def resnet152():
""" return a ResNet 152 object
"""
return ResNet(BottleNeck, [3, 8, 36, 3])
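A quick sanity-check sketch for the constructors above (CIFAR-sized input; this assumes ResNet and BottleNeck are imported from this module, and that the forward ends in a global pool plus a num_classes-way linear head, as the num_classes argument and the 3x3 stride-1 stem suggest):

import torch

model = ResNet(BottleNeck, [3, 4, 6, 3], num_classes=100)  # the resnet50() configuration
x = torch.randn(2, 3, 32, 32)                              # CIFAR-style batch
logits = model(x)
print(logits.shape)                                        # expected: torch.Size([2, 100])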

@ -1,18 +1,17 @@
# A direct copy of torchvision's resnet.py modified to support gradient checkpointing.
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
import torchvision
import torch_intermediary as ml
from torchvision.models.resnet import BasicBlock, Bottleneck
import dlas.torch_intermediary as ml
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
from trainer.networks import register_model
from utils.util import checkpoint
from dlas.trainer.networks import register_model
from dlas.utils.util import checkpoint
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
