diff --git a/.gitignore b/.gitignore index 531dd7d9..69b3218b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,30 +1,3 @@ -dlas/experiments/* -dlas/codes/*.txt -dlas/codes/wandb/* -dlas/codes/pretrained_models/* -dlas/codes/scripts/audio/pretrained_models/* - -results/* -tb_logger/* -datasets/* -options/* -data/* -.vscode - -*.html -*.png -*.jpg -*.gif -*.pth -*.pytorch -*.zip -*.cu -*.pt -*.pth -*.pdf -*.tsv - -# template # Byte-compiled / optimized / DLL files __pycache__/ @@ -36,6 +9,7 @@ __pycache__/ # Distribution / packaging .Python +env/ build/ develop-eggs/ dist/ @@ -47,12 +21,9 @@ lib64/ parts/ sdist/ var/ -wheels/ -pretrained/* *.egg-info/ .installed.cfg *.egg -MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -72,9 +43,8 @@ htmlcov/ .cache nosetests.xml coverage.xml -*.cover +*,cover .hypothesis/ -.pytest_cache/ # Translations *.mo @@ -83,14 +53,6 @@ coverage.xml # Django stuff: *.log local_settings.py -db.sqlite3 - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy # Sphinx documentation docs/_build/ @@ -98,36 +60,8 @@ docs/_build/ # PyBuilder target/ -# Jupyter Notebook +#Ipython Notebook .ipynb_checkpoints # pyenv .python-version - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ diff --git a/MANIFEST.in b/MANIFEST.in index 1085ae5e..72a4f936 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include codes/* +recursive-include dlas/* diff --git a/dlas/data/__init__.py b/dlas/data/__init__.py index 81a915ad..a65d453b 100644 --- a/dlas/data/__init__.py +++ b/dlas/data/__init__.py @@ -3,7 +3,7 @@ import torch import torch.utils.data from munch import munchify -from utils.util import opt_get +from dlas.utils.util import opt_get def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=None, shuffle=True): @@ -33,77 +33,90 @@ def create_dataset(dataset_opt, return_collate=False): # datasets for image restoration if mode == 'fullimage': - from data.images.full_image_dataset import FullImageDataset as D + from dlas.data.images.full_image_dataset import FullImageDataset as D elif mode == 'single_image_extensible': - from data.images.single_image_dataset import SingleImageDataset as D + from dlas.data.images.single_image_dataset import \ + SingleImageDataset as D elif mode == 'multi_frame_extensible': - from data.images.multi_frame_dataset import MultiFrameDataset as D + from dlas.data.images.multi_frame_dataset import MultiFrameDataset as D elif mode == 'combined': - from data.combined_dataset import CombinedDataset as D + from dlas.data.combined_dataset import CombinedDataset as D elif mode == 'multiscale': - from data.images.multiscale_dataset import MultiScaleDataset as D + from dlas.data.images.multiscale_dataset import MultiScaleDataset as D elif mode == 'paired_frame': - from data.images.paired_frame_dataset import PairedFrameDataset as D + from dlas.data.images.paired_frame_dataset import \ + PairedFrameDataset as D elif mode == 'stylegan2': - from data.images.stylegan2_dataset import Stylegan2Dataset as D + from dlas.data.images.stylegan2_dataset import Stylegan2Dataset as D elif mode == 'imagefolder': - from data.images.image_folder_dataset import ImageFolderDataset as D + from dlas.data.images.image_folder_dataset 
import \ + ImageFolderDataset as D elif mode == 'torch_dataset': from data.torch_dataset import TorchDataset as D elif mode == 'byol_dataset': - from data.images.byol_attachment import ByolDatasetWrapper as D + from dlas.data.images.byol_attachment import ByolDatasetWrapper as D elif mode == 'byol_structured_dataset': - from data.images.byol_attachment import StructuredCropDatasetWrapper as D + from dlas.data.images.byol_attachment import \ + StructuredCropDatasetWrapper as D elif mode == 'random_aug_wrapper': - from data.images.byol_attachment import DatasetRandomAugWrapper as D + from dlas.data.images.byol_attachment import \ + DatasetRandomAugWrapper as D elif mode == 'random_dataset': - from data.images.random_dataset import RandomDataset as D + from dlas.data.images.random_dataset import RandomDataset as D elif mode == 'zipfile': - from data.images.zip_file_dataset import ZipFileDataset as D + from dlas.data.images.zip_file_dataset import ZipFileDataset as D elif mode == 'nv_tacotron': - from data.audio.nv_tacotron_dataset import TextWavLoader as D - from data.audio.nv_tacotron_dataset import TextMelCollate as C - from models.audio.tts.tacotron2 import create_hparams + from dlas.data.audio.nv_tacotron_dataset import TextMelCollate as C + from dlas.data.audio.nv_tacotron_dataset import TextWavLoader as D + from dlas.models.audio.tts.tacotron2 import create_hparams default_params = create_hparams() default_params.update(dataset_opt) dataset_opt = munchify(default_params) if opt_get(dataset_opt, ['needs_collate'], True): collate = C() elif mode == 'paired_voice_audio': - from data.audio.paired_voice_audio_dataset import TextWavLoader as D - from models.audio.tts.tacotron2 import create_hparams + from dlas.data.audio.paired_voice_audio_dataset import \ + TextWavLoader as D + from dlas.models.audio.tts.tacotron2 import create_hparams default_params = create_hparams() default_params.update(dataset_opt) dataset_opt = munchify(default_params) elif mode == 'fast_paired_voice_audio': - from data.audio.fast_paired_dataset import FastPairedVoiceDataset as D - from models.audio.tts.tacotron2 import create_hparams + from dlas.data.audio.fast_paired_dataset import \ + FastPairedVoiceDataset as D + from dlas.models.audio.tts.tacotron2 import create_hparams default_params = create_hparams() default_params.update(dataset_opt) dataset_opt = munchify(default_params) elif mode == 'fast_paired_voice_audio_with_phonemes': - from data.audio.fast_paired_dataset_with_phonemes import FastPairedVoiceDataset as D - from models.audio.tts.tacotron2 import create_hparams + from dlas.data.audio.fast_paired_dataset_with_phonemes import \ + FastPairedVoiceDataset as D + from dlas.models.audio.tts.tacotron2 import create_hparams default_params = create_hparams() default_params.update(dataset_opt) dataset_opt = munchify(default_params) elif mode == 'gpt_tts': - from data.audio.gpt_tts_dataset import GptTtsDataset as D - from data.audio.gpt_tts_dataset import GptTtsCollater as C + from dlas.data.audio.gpt_tts_dataset import GptTtsCollater as C + from dlas.data.audio.gpt_tts_dataset import GptTtsDataset as D collate = C(dataset_opt) elif mode == 'unsupervised_audio': - from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset as D + from dlas.data.audio.unsupervised_audio_dataset import \ + UnsupervisedAudioDataset as D elif mode == 'unsupervised_audio_with_noise': - from data.audio.audio_with_noise_dataset import AudioWithNoiseDataset as D + from dlas.data.audio.audio_with_noise_dataset import \ + 
AudioWithNoiseDataset as D elif mode == 'preprocessed_mel': - from data.audio.preprocessed_mel_dataset import PreprocessedMelDataset as D + from dlas.data.audio.preprocessed_mel_dataset import \ + PreprocessedMelDataset as D elif mode == 'grand_conjoined_voice': - from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D - from data.zero_pad_dict_collate import ZeroPadDictCollate as C + from dlas.data.audio.grand_conjoined_dataset import \ + GrandConjoinedDataset as D + from dlas.data.zero_pad_dict_collate import ZeroPadDictCollate as C if opt_get(dataset_opt, ['needs_collate'], False): collate = C() else: - raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) + raise NotImplementedError( + 'Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) if return_collate: @@ -115,9 +128,10 @@ def create_dataset(dataset_opt, return_collate=False): def get_dataset_debugger(dataset_opt): mode = dataset_opt['mode'] if mode == 'paired_voice_audio': - from data.audio.paired_voice_audio_dataset import PairedVoiceDebugger + from dlas.data.audio.paired_voice_audio_dataset import \ + PairedVoiceDebugger return PairedVoiceDebugger() elif mode == 'fast_paired_voice_audio': - from data.audio.fast_paired_dataset import FastPairedVoiceDebugger + from dlas.data.audio.fast_paired_dataset import FastPairedVoiceDebugger return FastPairedVoiceDebugger() - return None \ No newline at end of file + return None diff --git a/dlas/data/audio/__init__.py b/dlas/data/audio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dlas/data/audio/audio_with_noise_dataset.py b/dlas/data/audio/audio_with_noise_dataset.py index 32d6e67c..16eb69eb 100644 --- a/dlas/data/audio/audio_with_noise_dataset.py +++ b/dlas/data/audio/audio_with_noise_dataset.py @@ -2,18 +2,17 @@ import random import sys from math import pi -import librosa import torch +import torch.nn.functional as F import torchaudio from torch.utils.data import Dataset from tqdm import tqdm -import torch.nn.functional as F -from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset, load_audio -from data.util import load_paths_from_cache, find_files_of_type, is_audio_file - -# Just all ones. -from utils.util import opt_get +from dlas.data.audio.unsupervised_audio_dataset import ( + UnsupervisedAudioDataset, load_audio) +from dlas.data.util import (find_files_of_type, is_audio_file, + load_paths_from_cache) +from dlas.utils.util import opt_get def _integration_fn_fully_enabled(n): @@ -23,7 +22,7 @@ def _integration_fn_fully_enabled(n): # Randomly assigns up to 5 blocks of the output tensor the value '1'. Rest is zero def _integration_fn_spiky(n): fn = torch.zeros((n,)) - spikes = random.randint(1,5) + spikes = random.randint(1, 5) for _ in range(spikes): sz = random.randint(n//8, n//2) pos = random.randint(0, n) @@ -35,18 +34,19 @@ def _integration_fn_spiky(n): # Uses a sinusoidal ramp up and down (of random length) to a peak which is held for a random duration. 
def _integration_fn_smooth(n): center = random.randint(1, n-2) - max_duration=n-center-1 + max_duration = n-center-1 duration = random.randint(max_duration//4, max_duration) end = center+duration - ramp_up_sz = random.randint(n//16,n//4) - ramp_up = torch.sin(pi*torch.arange(0,ramp_up_sz)/(2*ramp_up_sz)) + ramp_up_sz = random.randint(n//16, n//4) + ramp_up = torch.sin(pi*torch.arange(0, ramp_up_sz)/(2*ramp_up_sz)) if ramp_up_sz > center: ramp_up = ramp_up[(ramp_up_sz-center):] ramp_up_sz = center - ramp_down_sz = random.randint(n//16,n//4) - ramp_down = torch.flip(torch.sin(pi*torch.arange(0,ramp_down_sz)/(2*ramp_down_sz)), dims=[0]) + ramp_down_sz = random.randint(n//16, n//4) + ramp_down = torch.flip( + torch.sin(pi*torch.arange(0, ramp_down_sz)/(2*ramp_down_sz)), dims=[0]) if ramp_down_sz > (n-end): ramp_down = ramp_down[:(n-end)] ramp_down_sz = n-end @@ -71,16 +71,22 @@ def load_rir(path, sr, max_sz): Wraps a unsupervised_audio_dataset and applies noise to the output clips, then provides labels depending on what noise was added. ''' + + class AudioWithNoiseDataset(Dataset): def __init__(self, opt): self.underlying_dataset = UnsupervisedAudioDataset(opt) - self.env_noise_paths = load_paths_from_cache(opt['env_noise_paths'], opt['env_noise_cache']) - self.music_paths = load_paths_from_cache(opt['music_paths'], opt['music_cache']) - self.openair_paths = find_files_of_type('img', opt['openair_path'], qualifier=is_audio_file)[0] + self.env_noise_paths = load_paths_from_cache( + opt['env_noise_paths'], opt['env_noise_cache']) + self.music_paths = load_paths_from_cache( + opt['music_paths'], opt['music_cache']) + self.openair_paths = find_files_of_type( + 'img', opt['openair_path'], qualifier=is_audio_file)[0] self.min_volume = opt_get(opt, ['min_noise_volume'], .2) self.max_volume = opt_get(opt, ['max_noise_volume'], .5) self.sampling_rate = self.underlying_dataset.sampling_rate - self.use_gpu_for_reverb_compute = opt_get(opt, ['use_gpu_for_reverb_compute'], True) + self.use_gpu_for_reverb_compute = opt_get( + opt, ['use_gpu_for_reverb_compute'], True) self.openair_kernels = None self.current_item_fetch = 0 self.fetch_error_count = 0 @@ -90,7 +96,8 @@ class AudioWithNoiseDataset(Dataset): # Load the openair reverbs as CUDA tensors. self.openair_kernels = [] for oa in self.openair_paths: - self.openair_kernels.append(load_rir(oa, self.underlying_dataset.sampling_rate, self.underlying_dataset.sampling_rate*2).cuda()) + self.openair_kernels.append(load_rir( + oa, self.underlying_dataset.sampling_rate, self.underlying_dataset.sampling_rate*2).cuda()) def __getitem__(self, item): if self.current_item_fetch != item: @@ -113,10 +120,11 @@ class AudioWithNoiseDataset(Dataset): clip = clip * clipvol label = random.randint(0, 4) # Current excludes GSM corruption. - #label = 3 + # label = 3 if label > 0 and label < 4: # 0 is basically "leave it alone" aug_needed = True - augvol = (random.random() * (self.max_volume-self.min_volume) + self.min_volume) + augvol = (random.random() * (self.max_volume - + self.min_volume) + self.min_volume) if label == 1: # Add environmental noise. augpath = random.choice(self.env_noise_paths) @@ -131,13 +139,15 @@ class AudioWithNoiseDataset(Dataset): # This can take two forms: if padding_room < 22000 or random.random() < .5: # (1) The voices talk over one another. If there is no padding room, we always take this choice. 
- intg_fns = [_integration_fn_smooth, _integration_fn_fully_enabled] + intg_fns = [_integration_fn_smooth, + _integration_fn_fully_enabled] else: # (2) There are simply two voices in the clip, separated from one another. # This is a special case that does not use the same logic as the rest of the augmentations. - aug = load_audio(augpath, self.underlying_dataset.sampling_rate) + aug = load_audio( + augpath, self.underlying_dataset.sampling_rate) # Pad with some random silence - aug = F.pad(aug, (random.randint(20,4000), 0)) + aug = F.pad(aug, (random.randint(20, 4000), 0)) # Fit what we can given the padding room we have. aug = aug[:, :padding_room] clip = torch.cat([clip, aug], dim=1) @@ -146,7 +156,8 @@ class AudioWithNoiseDataset(Dataset): out['clip_lengths'] = torch.tensor(clip.shape[-1]) aug_needed = False if aug_needed: - aug = load_audio(augpath, self.underlying_dataset.sampling_rate) + aug = load_audio( + augpath, self.underlying_dataset.sampling_rate) if aug.shape[1] > clip.shape[1]: n, cn = aug.shape[1], clip.shape[1] gap = n-cn @@ -157,7 +168,8 @@ class AudioWithNoiseDataset(Dataset): if aug.shape[1] < clip.shape[1]: gap = clip.shape[1] - aug.shape[1] placement = random.randint(0, gap-1) - aug = torch.nn.functional.pad(aug, (placement, gap-placement)) + aug = torch.nn.functional.pad( + aug, (placement, gap-placement)) clip = clip + aug elif label == 4: # Perform reverb (to simulate being in a large room with an omni-mic). This is performed by convolving @@ -166,19 +178,23 @@ class AudioWithNoiseDataset(Dataset): rir = random.choice(self.openair_kernels) else: augpath = random.choice(self.openair_paths) - rir = load_rir(augpath, self.underlying_dataset.sampling_rate, clip.shape[-1]) + rir = load_rir( + augpath, self.underlying_dataset.sampling_rate, clip.shape[-1]) clip = torch.nn.functional.pad(clip, (rir.shape[1]-1, 0)) if self.use_gpu_for_reverb_compute: clip = clip.cuda() - clip = torch.nn.functional.conv1d(clip.unsqueeze(0), rir.unsqueeze(0)).squeeze(0).cpu() + clip = torch.nn.functional.conv1d( + clip.unsqueeze(0), rir.unsqueeze(0)).squeeze(0).cpu() elif label == 5: # Apply the GSM codec to simulate cellular phone audio. - clip = torchaudio.functional.apply_codec(clip, self.underlying_dataset.sampling_rate, format="gsm") + clip = torchaudio.functional.apply_codec( + clip, self.underlying_dataset.sampling_rate, format="gsm") except: if self.fetch_error_count > 10: - print(f"Exception encountered processing {item}, re-trying because this is often just a failed aug.") + print( + f"Exception encountered processing {item}, re-trying because this is often just a failed aug.") print(sys.exc_info()) - #raise # Uncomment to surface exceptions. + # raise # Uncomment to surface exceptions. 
self.fetch_error_count += 1 return self[item] @@ -187,7 +203,7 @@ class AudioWithNoiseDataset(Dataset): clip = F.pad(clip, (0, padding_room)) out['clip'] = clip out['label'] = label - #out['aug'] = aug + # out['aug'] = aug out['augpath'] = augpath out['augvol'] = augvol out['clipvol'] = clipvol @@ -216,14 +232,15 @@ if __name__ == '__main__': 'openair_path': 'D:\\data\\audio\\openair\\resampled', 'use_gpu_for_reverb_compute': False, } - from data import create_dataset, create_dataloader, util + from data import create_dataloader, create_dataset, util ds = create_dataset(params) dl = create_dataloader(ds, params, pin_memory=False) i = 0 for b in tqdm(dl): for b_ in range(b['clip'].shape[0]): - #torchaudio.save(f'{i}_clip_{b_}_{b["label"][b_].item()}.wav', b['clip'][b_][:, :b['clip_lengths'][b_]], ds.sampling_rate) - #torchaudio.save(f'{i}_clip_{b_}_aug.wav', b['aug'][b_], ds.sampling_rate) - print(f'{i} aug path: {b["augpath"][b_]} aug volume: {b["augvol"][b_]} clip volume: {b["clipvol"][b_]}') + # torchaudio.save(f'{i}_clip_{b_}_{b["label"][b_].item()}.wav', b['clip'][b_][:, :b['clip_lengths'][b_]], ds.sampling_rate) + # torchaudio.save(f'{i}_clip_{b_}_aug.wav', b['aug'][b_], ds.sampling_rate) + print( + f'{i} aug path: {b["augpath"][b_]} aug volume: {b["augvol"][b_]} clip volume: {b["clipvol"][b_]}') i += 1 diff --git a/dlas/data/audio/fast_paired_dataset.py b/dlas/data/audio/fast_paired_dataset.py index c1d2c24f..45b3085f 100644 --- a/dlas/data/audio/fast_paired_dataset.py +++ b/dlas/data/audio/fast_paired_dataset.py @@ -12,13 +12,15 @@ import torchaudio from tqdm import tqdm from transformers import Wav2Vec2CTCTokenizer -from data.audio.paired_voice_audio_dataset import CharacterTokenizer -from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips -from utils.util import opt_get +from dlas.data.audio.paired_voice_audio_dataset import CharacterTokenizer +from dlas.data.audio.unsupervised_audio_dataset import (load_audio, + load_similar_clips) +from dlas.utils.util import opt_get def parse_tsv_aligned_codes(line, base_path): fpt = line.strip().split('\t') + def convert_string_list_to_tensor(strlist): if strlist.startswith('['): strlist = strlist[1:] @@ -43,6 +45,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): The upshot is that this dataset loads extremely quickly and consumes almost no system memory. 
""" + def __init__(self, hparams): self.paths = hparams['path'] if not isinstance(self.paths, list): @@ -52,26 +55,33 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): self.types = opt_get(hparams, ['types'], [0 for _ in self.paths]) self.load_conditioning = opt_get(hparams, ['load_conditioning'], False) - self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1) - self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100) - self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False) - self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False) + self.conditioning_candidates = opt_get( + hparams, ['num_conditioning_candidates'], 1) + self.conditioning_length = opt_get( + hparams, ['conditioning_length'], 44100) + self.produce_ctc_metadata = opt_get( + hparams, ['produce_ctc_metadata'], False) + self.debug_failures = opt_get( + hparams, ['debug_loading_failures'], False) self.text_cleaners = hparams.text_cleaners self.sample_rate = hparams.sample_rate self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050 self.max_wav_len = opt_get(hparams, ['max_wav_length'], None) - self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False) + self.load_aligned_codes = opt_get( + hparams, ['load_aligned_codes'], False) if self.max_wav_len is not None: self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio self.max_text_len = opt_get(hparams, ['max_text_length'], None) assert self.max_wav_len is not None and self.max_text_len is not None self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], False) if self.use_bpe_tokenizer: - from data.audio.voice_tokenizer import VoiceBpeTokenizer - self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json')) + from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer + self.tokenizer = VoiceBpeTokenizer(opt_get( + hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json')) else: self.tokenizer = CharacterTokenizer() - self.skipped_items = 0 # records how many items are skipped when accessing an index. + # records how many items are skipped when accessing an index. + self.skipped_items = 0 self.load_times = torch.zeros((256,)) self.load_ind = 0 @@ -110,7 +120,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): try: # This can fail when seeking to a UTF-8 escape byte. f.readline() except: - return self.load_random_line(depth=depth + 1), type # On failure, just recurse and try again. + # On failure, just recurse and try again. + return self.load_random_line(depth=depth + 1), type l2 = f.readline() if l2: @@ -119,14 +130,16 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): return parse_tsv_aligned_codes(l2, base_path), type except: print(f"error parsing random offset: {sys.exc_info()}") - return self.load_random_line(depth=depth+1), type # On failure, just recurse and try again. + # On failure, just recurse and try again. + return self.load_random_line(depth=depth+1), type def get_ctc_metadata(self, codes): grouped = groupby(codes.tolist()) rcodes, repeats, seps = [], [], [0] for val, group in grouped: if val == 0: - seps[-1] = len(list(group)) # This is a very important distinction! It means the padding belongs to the character proceeding it. + # This is a very important distinction! It means the padding belongs to the character proceeding it. 
+ seps[-1] = len(list(group)) else: rcodes.append(val) repeats.append(len(list(group))) @@ -142,7 +155,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): if rcodes.shape[0] < self.max_text_len: gap = self.max_text_len - rcodes.shape[0] rcodes = F.pad(rcodes, (0, gap)) - repeats = F.pad(repeats, (0, gap), value=1) # The minimum value for repeats is 1, hence this is the pad value too. + # The minimum value for repeats is 1, hence this is the pad value too. + repeats = F.pad(repeats, (0, gap), value=1) seps = F.pad(seps, (0, gap)) elif rcodes.shape[0] > self.max_text_len: rcodes = rcodes[:self.max_text_len] @@ -165,7 +179,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): if text is None or len(text.strip()) == 0: raise ValueError cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate, - n=self.conditioning_candidates) if self.load_conditioning else (None, False) + n=self.conditioning_candidates) if self.load_conditioning else (None, False) except: if self.skipped_items > 100: raise # Rethrow if we have nested too far. @@ -179,12 +193,13 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): self.skipped_items = 0 if wav is None or \ (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ - (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): + (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): # Basically, this audio file is nonexistent or too long to be supported by the dataset. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. if self.debug_failures: - print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}") - rv = random.randint(0,len(self)-1) + print( + f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}") + rv = random.randint(0, len(self)-1) return self[rv] orig_output = wav.shape[-1] orig_text_len = tseq.shape[0] @@ -192,7 +207,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): if wav.shape[-1] != self.max_wav_len: wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1])) # These codes are aligned to audio inputs, so make sure to pad them as well. - aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0])) + aligned_codes = F.pad( + aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0])) if tseq.shape[0] != self.max_text_len: tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0])) @@ -223,7 +239,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): return res def __len__(self): - return self.total_size_bytes // 1000 # 1000 cuts down a TSV file to the actual length pretty well. + # 1000 cuts down a TSV file to the actual length pretty well. 
+ return self.total_size_bytes // 1000 class FastPairedVoiceDebugger: @@ -243,7 +260,8 @@ class FastPairedVoiceDebugger: if isinstance(state, dict): self.total_items = opt_get(state, ['total_items'], 0) self.loaded_items = opt_get(state, ['loaded_items'], 0) - self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0) + self.self_conditioning_items = opt_get( + state, ['self_conditioning_items'], 0) def update(self, batch): self.total_items += batch['wav'].shape[0] @@ -252,7 +270,8 @@ class FastPairedVoiceDebugger: for filename in batch['filenames']: self.unique_files.add(hashlib.sha256(filename.encode('utf-8'))) if 'conditioning' in batch.keys(): - self.self_conditioning_items += batch['conditioning_contains_self'].sum().item() + self.self_conditioning_items += batch['conditioning_contains_self'].sum( + ).item() def get_debugging_map(self): return { @@ -269,13 +288,13 @@ if __name__ == '__main__': params = { 'mode': 'fast_paired_voice_audio', 'path': ['y:/libritts/train-other-500/transcribed-oco.tsv', - 'y:/libritts/train-clean-100/transcribed-oco.tsv', - 'y:/libritts/train-clean-360/transcribed-oco.tsv', - 'y:/clips/books1/transcribed-oco.tsv', - 'y:/clips/books2/transcribed-oco.tsv', - 'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv', - 'y:/clips/podcasts-1/transcribed-oco.tsv',], - 'types': [0,1,1,1,2,2,0], + 'y:/libritts/train-clean-100/transcribed-oco.tsv', + 'y:/libritts/train-clean-360/transcribed-oco.tsv', + 'y:/clips/books1/transcribed-oco.tsv', + 'y:/clips/books2/transcribed-oco.tsv', + 'y:/bigasr_dataset/hifi_tts/transcribed-oco.tsv', + 'y:/clips/podcasts-1/transcribed-oco.tsv',], + 'types': [0, 1, 1, 1, 2, 2, 0], 'phase': 'train', 'n_workers': 0, 'batch_size': batch_sz, @@ -289,11 +308,12 @@ if __name__ == '__main__': 'load_aligned_codes': True, 'produce_ctc_metadata': True, } - from data import create_dataset, create_dataloader + from data import create_dataloader, create_dataset def save(b, i, ib, key, c=None): if c is not None: - torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050) + torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', + b[key][ib][c], 22050) else: torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050) @@ -304,8 +324,8 @@ if __name__ == '__main__': max_pads, max_repeats = 0, 0 for i, b in tqdm(enumerate(dl)): for ib in range(batch_sz): - #max_pads = max(max_pads, b['ctc_pads'].max()) - #max_repeats = max(max_repeats, b['ctc_repeats'].max()) + # max_pads = max(max_pads, b['ctc_pads'].max()) + # max_repeats = max(max_repeats, b['ctc_repeats'].max()) print(f'{i} {ib} {b["real_text"][ib]}') save(b, i, ib, 'wav') save(b, i, ib, 'conditioning', 0) @@ -314,4 +334,3 @@ if __name__ == '__main__': if i > 15: break print(max_pads, max_repeats) - diff --git a/dlas/data/audio/fast_paired_dataset_with_phonemes.py b/dlas/data/audio/fast_paired_dataset_with_phonemes.py index 54b208b9..4a2bec25 100644 --- a/dlas/data/audio/fast_paired_dataset_with_phonemes.py +++ b/dlas/data/audio/fast_paired_dataset_with_phonemes.py @@ -12,13 +12,15 @@ import torchaudio from tqdm import tqdm from transformers import Wav2Vec2Processor -from data.audio.paired_voice_audio_dataset import CharacterTokenizer -from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips -from utils.util import opt_get +from dlas.data.audio.paired_voice_audio_dataset import CharacterTokenizer +from dlas.data.audio.unsupervised_audio_dataset import (load_audio, + load_similar_clips) +from dlas.utils.util import opt_get def parse_tsv_aligned_codes(line, 
base_path): fpt = line.strip().split('\t') + def convert_string_list_to_tensor(strlist): if strlist.startswith('['): strlist = strlist[1:] @@ -43,10 +45,12 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): The upshot is that this dataset loads extremely quickly and consumes almost no system memory. """ + def __init__(self, hparams): self.paths = hparams['path'] phoneme_paths = hparams['phoneme_paths'] - self.paths = [(p, False) for p in self.paths] + [(p, True) for p in phoneme_paths] + self.paths = [(p, False) for p in self.paths] + [(p, True) + for p in phoneme_paths] self.paths_size_bytes = [os.path.getsize(p) for p, _ in self.paths] self.total_size_bytes = sum(self.paths_size_bytes) @@ -54,28 +58,36 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): self.normal_text_end_token = hparams['normal_text_end_token'] self.load_conditioning = opt_get(hparams, ['load_conditioning'], False) - self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1) - self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100) - self.produce_ctc_metadata = opt_get(hparams, ['produce_ctc_metadata'], False) - self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False) + self.conditioning_candidates = opt_get( + hparams, ['num_conditioning_candidates'], 1) + self.conditioning_length = opt_get( + hparams, ['conditioning_length'], 44100) + self.produce_ctc_metadata = opt_get( + hparams, ['produce_ctc_metadata'], False) + self.debug_failures = opt_get( + hparams, ['debug_loading_failures'], False) self.text_cleaners = hparams.text_cleaners self.sample_rate = hparams.sample_rate self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050 self.max_wav_len = opt_get(hparams, ['max_wav_length'], None) - self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False) + self.load_aligned_codes = opt_get( + hparams, ['load_aligned_codes'], False) if self.max_wav_len is not None: self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio self.max_text_len = opt_get(hparams, ['max_text_length'], None) assert self.max_wav_len is not None and self.max_text_len is not None self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], False) if self.use_bpe_tokenizer: - from data.audio.voice_tokenizer import VoiceBpeTokenizer - self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json')) + from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer + self.tokenizer = VoiceBpeTokenizer(opt_get( + hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json')) else: self.tokenizer = CharacterTokenizer() - self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer + self.ipa_phoneme_tokenizer = Wav2Vec2Processor.from_pretrained( + "facebook/wav2vec2-lv-60-espeak-cv-ft").tokenizer self.ipa_phoneme_tokenizer.do_phonemize = False - self.skipped_items = 0 # records how many items are skipped when accessing an index. + # records how many items are skipped when accessing an index. + self.skipped_items = 0 self.load_times = torch.zeros((256,)) self.load_ind = 0 @@ -117,7 +129,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): try: # This can fail when seeking to a UTF-8 escape byte. f.readline() except: - return self.load_random_line(depth=depth + 1) # On failure, just recurse and try again. + # On failure, just recurse and try again. 
+ return self.load_random_line(depth=depth + 1) l2 = f.readline() if l2: @@ -126,14 +139,16 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): return parse_tsv_aligned_codes(l2, base_path), type, is_phonetic except: print(f"error parsing random offset: {sys.exc_info()}") - return self.load_random_line(depth=depth+1) # On failure, just recurse and try again. + # On failure, just recurse and try again. + return self.load_random_line(depth=depth+1) def get_ctc_metadata(self, codes): grouped = groupby(codes.tolist()) rcodes, repeats, seps = [], [], [0] for val, group in grouped: if val == 0: - seps[-1] = len(list(group)) # This is a very important distinction! It means the padding belongs to the character proceeding it. + # This is a very important distinction! It means the padding belongs to the character proceeding it. + seps[-1] = len(list(group)) else: rcodes.append(val) repeats.append(len(list(group))) @@ -149,7 +164,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): if rcodes.shape[0] < self.max_text_len: gap = self.max_text_len - rcodes.shape[0] rcodes = F.pad(rcodes, (0, gap)) - repeats = F.pad(repeats, (0, gap), value=1) # The minimum value for repeats is 1, hence this is the pad value too. + # The minimum value for repeats is 1, hence this is the pad value too. + repeats = F.pad(repeats, (0, gap), value=1) seps = F.pad(seps, (0, gap)) elif rcodes.shape[0] > self.max_text_len: rcodes = rcodes[:self.max_text_len] @@ -171,7 +187,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): if text is None or len(text.strip()) == 0: raise ValueError cond, cond_is_self = load_similar_clips(apt[0], self.conditioning_length, self.sample_rate, - n=self.conditioning_candidates) if self.load_conditioning else (None, False) + n=self.conditioning_candidates) if self.load_conditioning else (None, False) except: if self.skipped_items > 100: raise # Rethrow if we have nested too far. @@ -185,12 +201,13 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): self.skipped_items = 0 if wav is None or \ (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ - (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): + (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): # Basically, this audio file is nonexistent or too long to be supported by the dataset. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. if self.debug_failures: - print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}") - rv = random.randint(0,len(self)-1) + print( + f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}") + rv = random.randint(0, len(self)-1) return self[rv] # Shift phonetic token and aligned_code tokens over. @@ -206,7 +223,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): if wav.shape[-1] != self.max_wav_len: wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1])) # These codes are aligned to audio inputs, so make sure to pad them as well. 
- aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0])) + aligned_codes = F.pad( + aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0])) if tseq.shape[0] != self.max_text_len: tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0])) @@ -237,7 +255,8 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset): return res def __len__(self): - return self.total_size_bytes // 1000 # 1000 cuts down a TSV file to the actual length pretty well. + # 1000 cuts down a TSV file to the actual length pretty well. + return self.total_size_bytes // 1000 class FastPairedVoiceDebugger: @@ -257,7 +276,8 @@ class FastPairedVoiceDebugger: if isinstance(state, dict): self.total_items = opt_get(state, ['total_items'], 0) self.loaded_items = opt_get(state, ['loaded_items'], 0) - self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0) + self.self_conditioning_items = opt_get( + state, ['self_conditioning_items'], 0) def update(self, batch): self.total_items += batch['wav'].shape[0] @@ -266,7 +286,8 @@ class FastPairedVoiceDebugger: for filename in batch['filenames']: self.unique_files.add(hashlib.sha256(filename.encode('utf-8'))) if 'conditioning' in batch.keys(): - self.self_conditioning_items += batch['conditioning_contains_self'].sum().item() + self.self_conditioning_items += batch['conditioning_contains_self'].sum( + ).item() def get_debugging_map(self): return { @@ -284,7 +305,7 @@ if __name__ == '__main__': 'mode': 'fast_paired_voice_audio_with_phonemes', 'path': ['y:/libritts/train-clean-100/transcribed-oco.tsv',], 'phoneme_paths': ['y:/libritts/train-other-500/transcribed-phoneme-oco.tsv'], - 'types': [0,0], + 'types': [0, 0], 'normal_text_end_token': 256, 'phase': 'train', 'n_workers': 0, @@ -299,11 +320,12 @@ if __name__ == '__main__': 'load_aligned_codes': False, 'debug_loading_failures': True, } - from data import create_dataset, create_dataloader + from data import create_dataloader, create_dataset def save(b, i, ib, key, c=None): if c is not None: - torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050) + torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', + b[key][ib][c], 22050) else: torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050) @@ -314,14 +336,13 @@ if __name__ == '__main__': max_pads, max_repeats = 0, 0 for i, b in tqdm(enumerate(dl)): for ib in range(batch_sz): - #max_pads = max(max_pads, b['ctc_pads'].max()) - #max_repeats = max(max_repeats, b['ctc_repeats'].max()) + # max_pads = max(max_pads, b['ctc_pads'].max()) + # max_repeats = max(max_repeats, b['ctc_repeats'].max()) print(f'{i} {ib} {b["real_text"][ib]}') - #save(b, i, ib, 'wav') - #save(b, i, ib, 'conditioning', 0) - #save(b, i, ib, 'conditioning', 1) + # save(b, i, ib, 'wav') + # save(b, i, ib, 'conditioning', 0) + # save(b, i, ib, 'conditioning', 1) pass if i > 15: break print(max_pads, max_repeats) - diff --git a/dlas/data/audio/gpt_tts_dataset.py b/dlas/data/audio/gpt_tts_dataset.py index a872fd67..35c83f0b 100644 --- a/dlas/data/audio/gpt_tts_dataset.py +++ b/dlas/data/audio/gpt_tts_dataset.py @@ -6,9 +6,8 @@ import torch.utils.data from torch import LongTensor from tqdm import tqdm -from models.audio.tts.tacotron2 import load_filepaths_and_text -from models.audio.tts.tacotron2 import symbols -from models.audio.tts.tacotron2 import text_to_sequence +from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text, symbols, + text_to_sequence) class GptTtsDataset(torch.utils.data.Dataset): @@ -21,7 +20,7 @@ class 
GptTtsDataset(torch.utils.data.Dataset): def __init__(self, opt): self.path = os.path.dirname(opt['path']) self.audiopaths_and_text = load_filepaths_and_text(opt['path']) - self.text_cleaners=['english_cleaners'] + self.text_cleaners = ['english_cleaners'] self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3 self.MEL_START_TOKEN = LongTensor([self.MEL_DICTIONARY_SIZE-3]) @@ -32,7 +31,8 @@ class GptTtsDataset(torch.utils.data.Dataset): audiopath_and_text = self.audiopaths_and_text[index] audiopath, text = audiopath_and_text[0], audiopath_and_text[1] text = torch.IntTensor(text_to_sequence(text, self.text_cleaners)) - text = torch.cat([self.TEXT_START_TOKEN, text, self.TEXT_STOP_TOKEN], dim=0) + text = torch.cat([self.TEXT_START_TOKEN, text, + self.TEXT_STOP_TOKEN], dim=0) # Fetch quantized MELs quant_path = audiopath.replace('wavs/', 'quantized_mels/') + '.pth' @@ -57,8 +57,9 @@ class GptTtsCollater(): def __call__(self, batch): text_lens = [len(x[0]) for x in batch] - #max_text_len = max(text_lens) - max_text_len = self.MAX_SYMBOLS_PER_PHRASE # This forces all outputs to have the full 200 characters. Testing if this makes a difference. + # max_text_len = max(text_lens) + # This forces all outputs to have the full 200 characters. Testing if this makes a difference. + max_text_len = self.MAX_SYMBOLS_PER_PHRASE mel_lens = [len(x[1]) for x in batch] max_mel_len = max(mel_lens) texts = [] @@ -70,7 +71,8 @@ class GptTtsCollater(): text = F.pad(text, (0, max_text_len-len(text)), value=0) text = torch.where(text == 0, text_range_embedding, text) texts.append(text) - qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN)) + qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), + value=self.MEL_PAD_TOKEN)) filenames = [j[2] for j in batch] @@ -96,7 +98,7 @@ if __name__ == '__main__': 'batch_size': 16, 'mel_vocab_size': 512, } - from data import create_dataset, create_dataloader + from data import create_dataloader, create_dataset ds, c = create_dataset(params, return_collate=True) dl = create_dataloader(ds, params, collate_fn=c) @@ -107,5 +109,5 @@ if __name__ == '__main__': for b in tqdm(dl): max_mel = max(max_mel, b['padded_qmel'].shape[2]) max_text = max(max_text, b['padded_text'].shape[1]) - m=torch.stack(m) + m = torch.stack(m) print(m.mean(), m.std()) diff --git a/dlas/data/audio/grand_conjoined_dataset.py b/dlas/data/audio/grand_conjoined_dataset.py index 7e56b79d..7c24f6c0 100644 --- a/dlas/data/audio/grand_conjoined_dataset.py +++ b/dlas/data/audio/grand_conjoined_dataset.py @@ -7,14 +7,15 @@ import torchaudio from munch import munchify from tqdm import tqdm -from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset -from data.text.hf_datasets_wrapper import HfDataset -from utils.util import opt_get +from dlas.data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset +from dlas.data.text.hf_datasets_wrapper import HfDataset +from dlas.utils.util import opt_get def build_paired_voice_dataset(args): - from data.audio.paired_voice_audio_dataset import TextWavLoader as D from models.audio.tts.tacotron2 import create_hparams + + from dlas.data.audio.paired_voice_audio_dataset import TextWavLoader as D default_params = create_hparams() default_params.update(args) dataset_opt = munchify(default_params) @@ -33,6 +34,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset): Performs tokenization at this level, ignoring any tokenization performed by upstream datasets. """ + def __init__(self, opt): sample_rate = 22050 # Fixed. 
paired_dataset_args = opt['paired_dataset_args'] @@ -47,7 +49,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset): self.max_solo_text_length = opt['max_solo_text_length'] self.collate = opt_get(opt, ['needs_collate'], False) self.sample_rate = sample_rate - self.num_conditioning_candidates = opt_get(opt, ['num_conditioning_candidates'], 0) + self.num_conditioning_candidates = opt_get( + opt, ['num_conditioning_candidates'], 0) self.conditioning_length = opt_get(opt, ['conditioning_length'], 44000) load_conditioning = self.num_conditioning_candidates > 0 @@ -75,7 +78,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset): def fetch_text_at(self, i): try: txt = self.text[i % len(self.text)]['text'] - assert '*' not in txt # This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There really isn't a linguistic use for this character anyways. + # This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There really isn't a linguistic use for this character anyways. + assert '*' not in txt tok = self.speech_and_text.get_text(txt) padding_required = self.max_solo_text_length - tok.shape[0] if padding_required < 0: @@ -137,7 +141,8 @@ class GrandConjoinedDataset(torch.utils.data.Dataset): sp = self.speech[i % len(self.speech)] # Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise. sp['clip'] = sp['clip'][:, :self.max_solo_audio_length] - sp['clip_lengths'] = sp['clip_lengths'].clamp(0, self.max_solo_audio_length) + sp['clip_lengths'] = sp['clip_lengths'].clamp( + 0, self.max_solo_audio_length) return self.optionally_add_conditioning_candidates({ 'paired_audio': snt['wav'], 'paired_audio_lengths': snt['wav_lengths'], @@ -205,7 +210,7 @@ if __name__ == '__main__': 'use_bpe_tokenizer': False, }, } - from data import create_dataset, create_dataloader + from data import create_dataloader, create_dataset os.remove('test_cache_delete_me2.pth') ds, c = create_dataset(train_params, return_collate=True) @@ -213,7 +218,8 @@ if __name__ == '__main__': def save(b, i, ib, key, c=None): if c is not None: - torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050) + torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', + b[key][ib][c], 22050) else: torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050) @@ -224,16 +230,17 @@ if __name__ == '__main__': m = None for i, b in tqdm(enumerate(dl)): for ib in range(batch_sz): - #save(b, i, ib, 'paired_audio') - #save(b, i, ib, 'paired_audio_conditioning', 0) - #save(b, i, ib, 'paired_audio_conditioning', 1) - print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}') - print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}') - #save(b, i, ib, 'speech_audio') - #save(b, i, ib, 'speech_audio_conditioning', 0) - #save(b, i, ib, 'speech_audio_conditioning', 1) - #print(f'Text: {b["text_text"][ib]}') - #print(f'Text decoded: {decode(b, ib, "text_tokens")}') + # save(b, i, ib, 'paired_audio') + # save(b, i, ib, 'paired_audio_conditioning', 0) + # save(b, i, ib, 'paired_audio_conditioning', 1) + print( + f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}') + print( + f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}') + # save(b, i, ib, 'speech_audio') + # save(b, i, ib, 'speech_audio_conditioning', 0) + # save(b, i, ib, 'speech_audio_conditioning', 1) + # print(f'Text: {b["text_text"][ib]}') + # print(f'Text decoded: {decode(b, ib, "text_tokens")}') 
if i > 5: break - diff --git a/dlas/data/audio/nv_tacotron_dataset.py b/dlas/data/audio/nv_tacotron_dataset.py index f94268cb..e09158ff 100644 --- a/dlas/data/audio/nv_tacotron_dataset.py +++ b/dlas/data/audio/nv_tacotron_dataset.py @@ -7,32 +7,36 @@ import torch.utils.data import torchaudio from tqdm import tqdm -from data.audio.unsupervised_audio_dataset import load_audio -from data.util import find_files_of_type, is_audio_file -from models.audio.tts.tacotron2 import load_filepaths_and_text -from models.audio.tts.tacotron2 import text_to_sequence -from utils.util import opt_get +from dlas.data.audio.unsupervised_audio_dataset import load_audio +from dlas.data.util import find_files_of_type, is_audio_file +from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text, + text_to_sequence) +from dlas.utils.util import opt_get def load_tsv(filename): with open(filename, encoding='utf-8') as f: components = [line.strip().split('\t') for line in f] base = os.path.dirname(filename) - filepaths_and_text = [[os.path.join(base, f'{component[1]}'), component[0]] for component in components] + filepaths_and_text = [ + [os.path.join(base, f'{component[1]}'), component[0]] for component in components] return filepaths_and_text def load_mozilla_cv(filename): with open(filename, encoding='utf-8') as f: - components = [line.strip().split('\t') for line in f][1:] # First line is the header + components = [line.strip().split('\t') + for line in f][1:] # First line is the header base = os.path.dirname(filename) - filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2]] for component in components] + filepaths_and_text = [[os.path.join( + base, f'clips/{component[1]}'), component[2]] for component in components] return filepaths_and_text def load_voxpopuli(filename): with open(filename, encoding='utf-8') as f: - lines = [line.strip().split('\t') for line in f][1:] # First line is the header + lines = [line.strip().split('\t') + for line in f][1:] # First line is the header base = os.path.dirname(filename) filepaths_and_text = [] for line in lines: @@ -40,7 +44,8 @@ def load_voxpopuli(filename): continue file, raw_text, norm_text, speaker_id, split, gender = line year = file[:4] - filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text]) + filepaths_and_text.append( + [os.path.join(base, year, f'{file}.ogg.wav'), raw_text]) return filepaths_and_text @@ -56,8 +61,10 @@ class TextWavLoader(torch.utils.data.Dataset): assert len(self.path) == len(fetcher_mode) self.load_conditioning = opt_get(hparams, ['load_conditioning'], False) - self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 3) - self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100) + self.conditioning_candidates = opt_get( + hparams, ['num_conditioning_candidates'], 3) + self.conditioning_length = opt_get( + hparams, ['conditioning_length'], 44100) self.audiopaths_and_text = [] for p, fm in zip(self.path, fetcher_mode): if fm == 'lj' or fm == 'libritts': @@ -65,10 +72,12 @@ class TextWavLoader(torch.utils.data.Dataset): elif fm == 'tsv': fetcher_fn = load_tsv elif fm == 'mozilla_cv': - assert not self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv + # Conditioning inputs are incompatible with mozilla_cv + assert not self.load_conditioning fetcher_fn = load_mozilla_cv elif fm == 'voxpopuli': - assert not self.load_conditioning # Conditioning inputs are incompatible with voxpopuli + # Conditioning inputs are incompatible 
with voxpopuli + assert not self.load_conditioning fetcher_fn = load_voxpopuli else: raise NotImplementedError() @@ -96,10 +105,13 @@ class TextWavLoader(torch.utils.data.Dataset): return text_norm def load_conditioning_candidates(self, path): - candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0] - assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related. + candidates = find_files_of_type( + 'img', os.path.dirname(path), qualifier=is_audio_file)[0] + # Sanity check to ensure we aren't loading "related files" that aren't actually related. + assert len(candidates) < 50000 if len(candidates) == 0: - print(f"No conditioning candidates found for {path} (not even the clip itself??)") + print( + f"No conditioning candidates found for {path} (not even the clip itself??)") raise NotImplementedError() # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates. related_clips = [] @@ -110,25 +122,28 @@ class TextWavLoader(torch.utils.data.Dataset): rel_clip = F.pad(rel_clip, pad=(0, abs(gap))) elif gap > 0: rand_start = random.randint(0, gap) - rel_clip = rel_clip[:, rand_start:rand_start+self.conditioning_length] + rel_clip = rel_clip[:, rand_start:rand_start + + self.conditioning_length] related_clips.append(rel_clip) return torch.stack(related_clips, dim=0) def __getitem__(self, index): try: - tseq, wav, text, path = self.get_wav_text_pair(self.audiopaths_and_text[index]) - cond = self.load_conditioning_candidates(self.audiopaths_and_text[index][0]) if self.load_conditioning else None + tseq, wav, text, path = self.get_wav_text_pair( + self.audiopaths_and_text[index]) + cond = self.load_conditioning_candidates( + self.audiopaths_and_text[index][0]) if self.load_conditioning else None except: print(f"error loading {self.audiopaths_and_text[index][0]}") return self[index+1] if wav is None or \ (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ - (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): + (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): # Basically, this audio file is nonexistent or too long to be supported by the dataset. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. 
- #if wav is not None: + # if wav is not None: # print(f"Exception {index} wav_len:{wav.shape[-1]} text_len:{tseq.shape[0]} fname: {path}") - rv = random.randint(0,len(self)-1) + rv = random.randint(0, len(self)-1) return self[rv] orig_output = wav.shape[-1] orig_text_len = tseq.shape[0] @@ -157,6 +172,7 @@ class TextWavLoader(torch.utils.data.Dataset): class TextMelCollate(): """ Zero-pads model inputs and targets based on number of frames per step """ + def __call__(self, batch): """Collate's training batch from normalized text and wav PARAMS @@ -226,7 +242,7 @@ if __name__ == '__main__': 'num_conditioning_candidates': 3, 'conditioning_length': 44100, } - from data import create_dataset, create_dataloader + from data import create_dataloader, create_dataset ds, c = create_dataset(params, return_collate=True) dl = create_dataloader(ds, params, collate_fn=c) @@ -240,4 +256,5 @@ if __name__ == '__main__': print(f'{i} {ib} {b["real_text"][ib]}') torchaudio.save(f'{i}_clip_{ib}.wav', b['wav'][ib], ds.sample_rate) for c in range(3): - torchaudio.save(f'{i}_clip_{ib}_cond{c}.wav', b['conditioning'][ib, c], ds.sample_rate) + torchaudio.save(f'{i}_clip_{ib}_cond{c}.wav', + b['conditioning'][ib, c], ds.sample_rate) diff --git a/dlas/data/audio/paired_voice_audio_dataset.py b/dlas/data/audio/paired_voice_audio_dataset.py index b8e6f0a3..e3c87194 100644 --- a/dlas/data/audio/paired_voice_audio_dataset.py +++ b/dlas/data/audio/paired_voice_audio_dataset.py @@ -8,10 +8,13 @@ import torch.utils.data import torchaudio from tqdm import tqdm -from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips -from models.audio.tts.tacotron2 import load_filepaths_and_text, load_filepaths_and_text_type -from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text -from utils.util import opt_get +from dlas.data.audio.unsupervised_audio_dataset import (load_audio, + load_similar_clips) +from dlas.models.audio.tts.tacotron2 import (load_filepaths_and_text, + load_filepaths_and_text_type, + sequence_to_text, + text_to_sequence) +from dlas.utils.util import opt_get def load_tsv_type(filename, type): @@ -24,10 +27,12 @@ def load_tsv_type(filename, type): if len(components) < 2: bad_lines += 1 if bad_lines > 1000: - print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}') + print( + f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}') raise ValueError continue - filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]] + [type]) + filepaths_and_text.append( + [os.path.join(base, f'{components[1]}'), components[0]] + [type]) return filepaths_and_text @@ -41,10 +46,12 @@ def load_tsv(filename): if len(components) < 2: bad_lines += 1 if bad_lines > 1000: - print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}') + print( + f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}') raise ValueError continue - filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0]]) + filepaths_and_text.append( + [os.path.join(base, f'{components[1]}'), components[0]]) return filepaths_and_text @@ -67,10 +74,12 @@ def load_tsv_aligned_codes_type(filename, type): if len(components) < 3: bad_lines += 1 if bad_lines > 1000: - print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}') + print( + f'{filename} contains 1000+ bad entries. Failing. 
Sample last entry: {line}') raise ValueError continue - filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type]) + filepaths_and_text.append([os.path.join( + base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])] + [type]) return filepaths_and_text @@ -84,24 +93,29 @@ def load_tsv_aligned_codes(filename): if len(components) < 3: bad_lines += 1 if bad_lines > 1000: - print(f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}') + print( + f'{filename} contains 1000+ bad entries. Failing. Sample last entry: {line}') raise ValueError continue - filepaths_and_text.append([os.path.join(base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])]) + filepaths_and_text.append([os.path.join( + base, f'{components[1]}'), components[0], convert_string_list_to_tensor(components[2])]) return filepaths_and_text def load_mozilla_cv(filename, type): with open(filename, encoding='utf-8') as f: - components = [line.strip().split('\t') for line in f][1:] # First line is the header + components = [line.strip().split('\t') + for line in f][1:] # First line is the header base = os.path.dirname(filename) - filepaths_and_text = [[os.path.join(base, f'clips/{component[1]}'), component[2], type] for component in components] + filepaths_and_text = [[os.path.join( + base, f'clips/{component[1]}'), component[2], type] for component in components] return filepaths_and_text def load_voxpopuli(filename, type): with open(filename, encoding='utf-8') as f: - lines = [line.strip().split('\t') for line in f][1:] # First line is the header + lines = [line.strip().split('\t') + for line in f][1:] # First line is the header base = os.path.dirname(filename) filepaths_and_text = [] for line in lines: @@ -109,7 +123,8 @@ def load_voxpopuli(filename, type): continue file, raw_text, norm_text, speaker_id, split, gender = line year = file[:4] - filepaths_and_text.append([os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type]) + filepaths_and_text.append( + [os.path.join(base, year, f'{file}.ogg.wav'), raw_text, type]) return filepaths_and_text @@ -134,11 +149,16 @@ class TextWavLoader(torch.utils.data.Dataset): assert len(self.path) == len(fetcher_mode) self.load_conditioning = opt_get(hparams, ['load_conditioning'], False) - self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1) - self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100) - self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False) - self.load_aligned_codes = opt_get(hparams, ['load_aligned_codes'], False) - self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443) + self.conditioning_candidates = opt_get( + hparams, ['num_conditioning_candidates'], 1) + self.conditioning_length = opt_get( + hparams, ['conditioning_length'], 44100) + self.debug_failures = opt_get( + hparams, ['debug_loading_failures'], False) + self.load_aligned_codes = opt_get( + hparams, ['load_aligned_codes'], False) + self.aligned_codes_to_audio_ratio = opt_get( + hparams, ['aligned_codes_ratio'], 443) self.audiopaths_and_text = [] for p, fm, type in zip(self.path, fetcher_mode, self.types): if fm == 'lj' or fm == 'libritts': @@ -146,10 +166,12 @@ class TextWavLoader(torch.utils.data.Dataset): elif fm == 'tsv': fetcher_fn = load_tsv_aligned_codes_type if self.load_aligned_codes else load_tsv_type elif fm == 'mozilla_cv': - assert not 
self.load_conditioning # Conditioning inputs are incompatible with mozilla_cv + # Conditioning inputs are incompatible with mozilla_cv + assert not self.load_conditioning fetcher_fn = load_mozilla_cv elif fm == 'voxpopuli': - assert not self.load_conditioning # Conditioning inputs are incompatible with voxpopuli + # Conditioning inputs are incompatible with voxpopuli + assert not self.load_conditioning fetcher_fn = load_voxpopuli else: raise NotImplementedError() @@ -165,11 +187,13 @@ class TextWavLoader(torch.utils.data.Dataset): assert self.max_wav_len is not None and self.max_text_len is not None self.use_bpe_tokenizer = opt_get(hparams, ['use_bpe_tokenizer'], True) if self.use_bpe_tokenizer: - from data.audio.voice_tokenizer import VoiceBpeTokenizer - self.tokenizer = VoiceBpeTokenizer(opt_get(hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json')) + from dlas.data.audio.voice_tokenizer import VoiceBpeTokenizer + self.tokenizer = VoiceBpeTokenizer(opt_get( + hparams, ['tokenizer_vocab'], '../experiments/bpe_lowercase_asr_256.json')) else: self.tokenizer = CharacterTokenizer() - self.skipped_items = 0 # records how many items are skipped when accessing an index. + # records how many items are skipped when accessing an index. + self.skipped_items = 0 def get_wav_text_pair(self, audiopath_and_text): # separate filename and text @@ -191,19 +215,21 @@ class TextWavLoader(torch.utils.data.Dataset): def __getitem__(self, index): self.skipped_items += 1 try: - tseq, wav, text, path, type = self.get_wav_text_pair(self.audiopaths_and_text[index]) + tseq, wav, text, path, type = self.get_wav_text_pair( + self.audiopaths_and_text[index]) if text is None or len(text.strip()) == 0: raise ValueError if wav is None or wav.shape[-1] < (.6 * self.sample_rate): # Ultra short clips are also useless (and can cause problems within some models). raise ValueError cond, cond_is_self = load_similar_clips(self.audiopaths_and_text[index][0], self.conditioning_length, self.sample_rate, - n=self.conditioning_candidates) if self.load_conditioning else (None, False) + n=self.conditioning_candidates) if self.load_conditioning else (None, False) except: if self.skipped_items > 100: raise # Rethrow if we have nested too far. if self.debug_failures: - print(f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}") + print( + f"error loading {self.audiopaths_and_text[index][0]} {sys.exc_info()}") return self[(index+1) % len(self)] if self.load_aligned_codes: @@ -213,12 +239,13 @@ class TextWavLoader(torch.utils.data.Dataset): self.skipped_items = 0 if wav is None or \ (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ - (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): + (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): # Basically, this audio file is nonexistent or too long to be supported by the dataset. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. 
if self.debug_failures: - print(f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}") - rv = random.randint(0,len(self)-1) + print( + f"error loading {path}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}") + rv = random.randint(0, len(self)-1) return self[rv] orig_output = wav.shape[-1] orig_text_len = tseq.shape[0] @@ -226,7 +253,8 @@ class TextWavLoader(torch.utils.data.Dataset): wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1])) if self.load_aligned_codes: # These codes are aligned to audio inputs, so make sure to pad them as well. - aligned_codes = F.pad(aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0])) + aligned_codes = F.pad( + aligned_codes, (0, self.max_aligned_codes-aligned_codes.shape[0])) if tseq.shape[0] != self.max_text_len: tseq = F.pad(tseq, (0, self.max_text_len - tseq.shape[0])) res = { @@ -265,13 +293,15 @@ class PairedVoiceDebugger: if isinstance(state, dict): self.total_items = opt_get(state, ['total_items'], 0) self.loaded_items = opt_get(state, ['loaded_items'], 0) - self.self_conditioning_items = opt_get(state, ['self_conditioning_items'], 0) + self.self_conditioning_items = opt_get( + state, ['self_conditioning_items'], 0) def update(self, batch): self.total_items += batch['wav'].shape[0] self.loaded_items += batch['skipped_items'].sum().item() if 'conditioning' in batch.keys(): - self.self_conditioning_items += batch['conditioning_contains_self'].sum().item() + self.self_conditioning_items += batch['conditioning_contains_self'].sum( + ).item() def get_debugging_map(self): return { @@ -299,11 +329,12 @@ if __name__ == '__main__': 'use_bpe_tokenizer': True, 'load_aligned_codes': False, } - from data import create_dataset, create_dataloader + from data import create_dataloader, create_dataset def save(b, i, ib, key, c=None): if c is not None: - torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', b[key][ib][c], 22050) + torchaudio.save(f'{i}_clip_{ib}_{key}_{c}.wav', + b[key][ib][c], 22050) else: torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050) @@ -317,4 +348,3 @@ if __name__ == '__main__': save(b, i, ib, 'wav') if i > 5: break - diff --git a/dlas/data/audio/preprocessed_mel_dataset.py b/dlas/data/audio/preprocessed_mel_dataset.py index a88e704e..9a29bec2 100644 --- a/dlas/data/audio/preprocessed_mel_dataset.py +++ b/dlas/data/audio/preprocessed_mel_dataset.py @@ -9,14 +9,15 @@ import torchaudio import torchvision from tqdm import tqdm -from utils.util import opt_get +from dlas.utils.util import opt_get class PreprocessedMelDataset(torch.utils.data.Dataset): def __init__(self, opt): path = opt['path'] - cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case. + # Will fail when multiple paths specified, must be specified in this case. 
+ cache_path = opt['cache_path'] if os.path.exists(cache_path): self.paths = torch.load(cache_path) else: @@ -36,8 +37,8 @@ class PreprocessedMelDataset(torch.utils.data.Dataset): padding_needed = self.pad_to - mel.shape[-1] mask = torch.zeros_like(mel) if padding_needed > 0: - mel = F.pad(mel, (0,padding_needed)) - mask = F.pad(mask, (0,padding_needed), value=1) + mel = F.pad(mel, (0, padding_needed)) + mask = F.pad(mask, (0, padding_needed), value=1) output = { 'mel': mel, @@ -62,13 +63,13 @@ if __name__ == '__main__': 'n_workers': 0, 'batch_size': 16, } - from data import create_dataset, create_dataloader + from data import create_dataloader, create_dataset ds = create_dataset(params) dl = create_dataloader(ds, params) i = 0 for b in tqdm(dl): - #pass + # pass torchvision.utils.save_image((b['mel'].unsqueeze(1)+1)/2, f'{i}.png') i += 1 if i > 20: diff --git a/dlas/data/audio/unsupervised_audio_dataset.py b/dlas/data/audio/unsupervised_audio_dataset.py index 5118329a..79451b6b 100644 --- a/dlas/data/audio/unsupervised_audio_dataset.py +++ b/dlas/data/audio/unsupervised_audio_dataset.py @@ -3,15 +3,16 @@ import random import sys import torch -import torch.utils.data import torch.nn.functional as F +import torch.utils.data import torchaudio from audio2numpy import open_audio from tqdm import tqdm -from data.util import find_files_of_type, is_audio_file, load_paths_from_cache -from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch -from utils.util import opt_get +from dlas.data.util import (find_files_of_type, is_audio_file, + load_paths_from_cache) +from dlas.models.audio.tts.tacotron2.taco_utils import load_wav_to_torch +from dlas.utils.util import opt_get def load_audio(audiopath, sampling_rate): @@ -53,17 +54,21 @@ def load_similar_clips(path, sample_length, sample_rate, n=3, fallback_to_self=T similarities = torch.load(sim_path) fname = os.path.basename(path) if fname in similarities.keys(): - candidates = [os.path.join(os.path.dirname(path), s) for s in similarities[fname]] + candidates = [os.path.join(os.path.dirname(path), s) + for s in similarities[fname]] else: - print(f'Similarities list found for {path} but {fname} was not in that list.') - #candidates.append(path) # Always include self as a possible similar clip. + print( + f'Similarities list found for {path} but {fname} was not in that list.') + # candidates.append(path) # Always include self as a possible similar clip. if len(candidates) == 0: if fallback_to_self: candidates = [path] else: - candidates = find_files_of_type('img', os.path.dirname(path), qualifier=is_audio_file)[0] + candidates = find_files_of_type( + 'img', os.path.dirname(path), qualifier=is_audio_file)[0] - assert len(candidates) < 50000 # Sanity check to ensure we aren't loading "related files" that aren't actually related. + # Sanity check to ensure we aren't loading "related files" that aren't actually related. + assert len(candidates) < 50000 if len(candidates) == 0: print(f"No conditioning candidates found for {path}") raise NotImplementedError() @@ -92,7 +97,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset): def __init__(self, opt): path = opt['path'] - cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case. + # Will fail when multiple paths specified, must be specified in this case. 
+ cache_path = opt['cache_path'] exclusions = [] if 'exclusions' in opt.keys(): for exc in opt['exclusions']: @@ -102,7 +108,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset): assert isinstance(ew, list) not_ew = opt_get(opt, ['not_endswith'], []) assert isinstance(not_ew, list) - self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, endswith=ew, not_endswith=not_ew) + self.audiopaths = load_paths_from_cache( + path, cache_path, exclusions, endswith=ew, not_endswith=not_ew) # Parse options self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050) @@ -121,7 +128,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset): self.extra_samples = opt_get(opt, ['extra_samples'], 0) self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 44000) - self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True) + self.debug_loading_failures = opt_get( + opt, ['debug_loading_failures'], True) def get_audio_for_index(self, index): audiopath = self.audiopaths[index] @@ -144,8 +152,9 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset): alt_files, alt_is_self = self.get_related_audio_for_index(index) except: if self.debug_loading_failures: - print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}") - return self[random.randint(0,len(self))] + print( + f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}") + return self[random.randint(0, len(self))] # When generating resampled clips, skew is a bias that tries to spread them out from each other, reducing their # influence on one another. @@ -157,10 +166,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset): for sk in skew: if self.pad_to is not None: if audio_norm.shape[-1] <= self.pad_to: - clips.append(torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1]))) + clips.append(torch.nn.functional.pad( + audio_norm, (0, self.pad_to - audio_norm.shape[-1]))) else: gap = audio_norm.shape[-1] - self.pad_to - start = min(max(random.randint(0, gap-1) + sk * gap // 2, 0), gap-1) + start = min( + max(random.randint(0, gap-1) + sk * gap // 2, 0), gap-1) clips.append(audio_norm[:, start:start+self.pad_to]) else: clips.append(audio_norm) @@ -200,16 +211,18 @@ if __name__ == '__main__': 'n_workers': 1, 'batch_size': 16, } - from data import create_dataset, create_dataloader + from data import create_dataloader, create_dataset ds = create_dataset(params) dl = create_dataloader(ds, params) i = 0 for b in tqdm(dl): for b_ in range(b['clip'].shape[0]): - #pass - torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate) - torchaudio.save(f'{i}_alt_clip_{b_}.wav', b['alt_clips'][b_], ds.sampling_rate) + # pass + torchaudio.save(f'{i}_clip_{b_}.wav', + b['clip'][b_], ds.sampling_rate) + torchaudio.save(f'{i}_alt_clip_{b_}.wav', + b['alt_clips'][b_], ds.sampling_rate) i += 1 if i > 200: break diff --git a/dlas/data/audio/voice_tokenizer.py b/dlas/data/audio/voice_tokenizer.py index a1a7ef18..790a4e0f 100644 --- a/dlas/data/audio/voice_tokenizer.py +++ b/dlas/data/audio/voice_tokenizer.py @@ -1,16 +1,17 @@ +import json import re import torch -import json - from tokenizers import Tokenizer from tokenizers.models import BPE from tokenizers.pre_tokenizers import Whitespace from tokenizers.trainers import BpeTrainer -from data.audio.paired_voice_audio_dataset import load_mozilla_cv, load_voxpopuli, load_tsv -from models.audio.tts.tacotron2 import load_filepaths_and_text -from models.audio.tts.tacotron2.text.cleaners import english_cleaners 
+from dlas.data.audio.paired_voice_audio_dataset import (load_mozilla_cv, + load_tsv, + load_voxpopuli) +from dlas.models.audio.tts.tacotron2 import load_filepaths_and_text +from dlas.models.audio.tts.tacotron2.text.cleaners import english_cleaners def remove_extraneous_punctuation(word): @@ -21,7 +22,8 @@ def remove_extraneous_punctuation(word): '—': '-', '`': '\'', 'ʼ': '\'' } - replace = re.compile("|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL) + replace = re.compile("|".join([re.escape(k) for k in sorted( + replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL) word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word) # TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners. @@ -29,27 +31,32 @@ def remove_extraneous_punctuation(word): word = extraneous.sub('', word) return word + def expand_numbers(text): - return normalize_numbers(text) + return normalize_numbers(text) def lowercase(text): - return text.lower() + return text.lower() + _whitespace_re = re.compile(r'\s+') + def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) + return re.sub(_whitespace_re, ' ', text) def convert_to_ascii(text): - return unidecode(text) + return unidecode(text) + def basic_cleaners(text): - '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' - text = lowercase(text) - text = collapse_whitespace(text) - return text + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + class VoiceBpeTokenizer: def __init__(self, vocab_file, preprocess=None): @@ -72,23 +79,24 @@ class VoiceBpeTokenizer: kks = pykakasi.kakasi() results = kks.convert(txt) - txt = " ".join([ result['kana'] for result in results ]) + txt = " ".join([result['kana'] for result in results]) txt = basic_cleaners(txt) else: txt = english_cleaners(txt) - + return txt def encode(self, txt): if self.preprocess: - txt = self.preprocess_text(txt) + txt = self.preprocess_text(txt) txt = txt.replace(' ', '[SPACE]') return self.tokenizer.encode(txt).ids def decode(self, seq): if isinstance(seq, torch.Tensor): seq = seq.cpu().numpy() - txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '') + txt = self.tokenizer.decode( + seq, skip_special_tokens=False).replace(' ', '') txt = txt.replace('[SPACE]', ' ') txt = txt.replace('[STOP]', '') txt = txt.replace('[UNK]', '') @@ -117,10 +125,11 @@ def build_text_file_from_priors(priors, output): def train(): with open('all_texts.txt', 'r', encoding='utf-8') as at: ttsd = at.readlines() - #bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train'] + # bcd = datasets.load_dataset('bookcorpus', cache_dir='Z:\\huggingface_datasets\\cache')['train'] - #allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—–ʼ]+$') + # allowed_characters_re = re.compile(r'^[0-9a-z!@#%_=:;"/, \-\$\^&\*\(\)\+\{\[\]\}\\\.\'\?—–ʼ]+$') allowed_characters_re = re.compile(r'^[a-z!:;"/, \-\(\)\.\'\?ʼ]+$') + def preprocess_word(word, report=False): word = english_cleaners(word) word = remove_extraneous_punctuation(word) @@ -135,16 +144,19 @@ def train(): for i in range(0, len(ttsd), batch_size): yield [preprocess_word(t, True) for t in ttsd[i:i+batch_size]] - #print("Processing bookcorpus.") - #for i in range(0, len(bcd), batch_size): + # print("Processing bookcorpus.") + # for 
i in range(0, len(bcd), batch_size): # yield [preprocess_word(t) for t in bcd[i:i+batch_size]['text']] - trainer = BpeTrainer(special_tokens=['[STOP]', '[UNK]', '[SPACE]'], vocab_size=255) + trainer = BpeTrainer( + special_tokens=['[STOP]', '[UNK]', '[SPACE]'], vocab_size=255) tokenizer = Tokenizer(BPE(unk_token="[UNK]")) tokenizer.pre_tokenizer = Whitespace() - tokenizer.train_from_iterator(batch_iterator(), trainer, length=len(ttsd))#+len(bcd)) + tokenizer.train_from_iterator( + batch_iterator(), trainer, length=len(ttsd)) # +len(bcd)) - print(tokenizer.decode(tokenizer.encode("i was traveling throughhadslfghds the woods in 1235375t137{{}}").ids)) + print(tokenizer.decode(tokenizer.encode( + "i was traveling throughhadslfghds the woods in 1235375t137{{}}").ids)) tokenizer.save('gpt_tts_tokenizer.json') @@ -171,5 +183,5 @@ if __name__ == '__main__': ('Y:\\clips\\books2-transcribed.tsv', 'tsv'), ('Y:\\clips\\podcasts-0-transcribed.tsv', 'tsv')], 'all_texts.txt') ''' - #train() + # train() test() diff --git a/dlas/data/audio/wav_aug.py b/dlas/data/audio/wav_aug.py index cc21b972..5955b4f7 100644 --- a/dlas/data/audio/wav_aug.py +++ b/dlas/data/audio/wav_aug.py @@ -3,19 +3,19 @@ import random import torch import torchaudio.sox_effects -from models.audio.tts.tacotron2.taco_utils import load_wav_to_torch +from dlas.models.audio.tts.tacotron2.taco_utils import load_wav_to_torch # Returns random double on [l,h] as a string -def rdstr(l=0,h=1): +def rdstr(l=0, h=1): assert h > l - i=h-l + i = h-l return str(random.random() * i + l) # Returns a randint on [s,e] as a string def rdi(e, s=0): - return str(random.randint(s,e)) + return str(random.randint(s, e)) class WavAugmentor: @@ -43,12 +43,13 @@ class WavAugmentor: band_effect = random.choice(band_effects) ''' volume_effects = [ - ['loudness', rdi(10,-2)], - ['overdrive', rdi(20,0), rdi(20,0)], + ['loudness', rdi(10, -2)], + ['overdrive', rdi(20, 0), rdi(20, 0)], ] vol_effect = random.choice(volume_effects) effects = [speed_effect, vol_effect] - out, sr = torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects) + out, sr = torchaudio.sox_effects.apply_effects_tensor( + wav, sample_rate, effects) # Add a variable amount of noise out = out + torch.rand_like(out) * random.random() * .03 return out @@ -60,4 +61,4 @@ if __name__ == '__main__': aug = WavAugmentor() for j in range(10): out = aug.augment(sample, 24000) - torchaudio.save(f'out{j}.wav', out, 24000) \ No newline at end of file + torchaudio.save(f'out{j}.wav', out, 24000) diff --git a/dlas/data/combined_dataset.py b/dlas/data/combined_dataset.py index b91b1b94..598db78b 100644 --- a/dlas/data/combined_dataset.py +++ b/dlas/data/combined_dataset.py @@ -1,5 +1,6 @@ import torch -from data import create_dataset + +from dlas.data import create_dataset # Simple composite dataset that combines multiple other datasets. 
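The second hunk below carries the composite's length semantics: the combined dataset reports the size of its largest child, so shorter children are effectively re-sampled within an epoch. A minimal stand-alone sketch of that behaviour, using plain lists as stand-ins for the wrapped datasets (nothing here is taken from the class beyond the length expression in the hunk below):

    # Stand-ins for self.datasets; real usage wraps torch datasets built via create_dataset.
    datasets = {'images': list(range(10)), 'audio': list(range(4))}

    # Mirrors the __len__ expression below: epoch length follows the largest child.
    assert max(len(d) for d in datasets.values()) == 10
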
@@ -31,4 +32,4 @@ class CombinedDataset(torch.utils.data.Dataset): return output def __len__(self): - return max(len(d) for d in self.datasets.values()) \ No newline at end of file + return max(len(d) for d in self.datasets.values()) diff --git a/dlas/data/data_sampler.py b/dlas/data/data_sampler.py index 9c409418..43464d4f 100644 --- a/dlas/data/data_sampler.py +++ b/dlas/data/data_sampler.py @@ -4,9 +4,10 @@ Support enlarging the dataset for *iteration-oriented* training, for saving time dataloader after each epoch """ import math + import torch -from torch.utils.data.sampler import Sampler import torch.distributed as dist +from torch.utils.data.sampler import Sampler class DistIterSampler(Sampler): @@ -30,17 +31,20 @@ class DistIterSampler(Sampler): def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): if num_replicas is None: if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") + raise RuntimeError( + "Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") + raise RuntimeError( + "Requires distributed package to be available") rank = dist.get_rank() self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) + self.num_samples = int( + math.ceil(len(self.dataset) * ratio / self.num_replicas)) self.total_size = self.num_samples * self.num_replicas def __iter__(self): diff --git a/dlas/data/images/base_unsupervised_image_dataset.py b/dlas/data/images/base_unsupervised_image_dataset.py index ce8b8c3f..0eb8525e 100644 --- a/dlas/data/images/base_unsupervised_image_dataset.py +++ b/dlas/data/images/base_unsupervised_image_dataset.py @@ -1,10 +1,13 @@ -import torch -from torch.utils import data -from data.images.image_corruptor import ImageCorruptor -from data.images.chunk_with_reference import ChunkWithReference import os + import cv2 import numpy as np +import torch +from torch.utils import data + +from dlas.data.images.chunk_with_reference import ChunkWithReference +from dlas.data.images.image_corruptor import ImageCorruptor + # Class whose purpose is to hold as much logic as can possibly be shared between datasets that operate on raw image # data and nothing else (which also have a very specific directory structure being used, as dictated by @@ -13,13 +16,17 @@ class BaseUnsupervisedImageDataset(data.Dataset): def __init__(self, opt): self.opt = opt self.corruptor = ImageCorruptor(opt) - self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None - self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1 + self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys( + ) else None + self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys( + ) else 1 self.for_eval = opt['eval'] if 'eval' in opt.keys() else False self.scale = opt['scale'] if not self.for_eval else 1 self.paths = opt['paths'] - self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False - assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors. + self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys( + ) else False + # If we dont throw here, we get some really obscure errors. 
+ assert (self.target_hq_size // self.scale) % self.multiple == 0 if not isinstance(self.paths, list): self.paths = [self.paths] self.weights = [1] @@ -34,8 +41,10 @@ class BaseUnsupervisedImageDataset(data.Dataset): if os.path.exists(cache_path): chunks = torch.load(cache_path) else: - print("Building chunk cache, this can take some time for large datasets..") - chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()] + print( + "Building chunk cache, this can take some time for large datasets..") + chunks = [ChunkWithReference(opt, d) for d in sorted( + os.scandir(path), key=lambda e: e.name) if d.is_dir()] # Prune out chunks that have no images res = [] for c in chunks: @@ -60,7 +69,7 @@ class BaseUnsupervisedImageDataset(data.Dataset): for c in self.chunks: paths.extend(c.tiles) return paths - + # Utility method for translating a point when the dimensions of an image change. def resize_point(self, point, orig_dim, new_dim): oh, ow = orig_dim @@ -78,20 +87,27 @@ class BaseUnsupervisedImageDataset(data.Dataset): for hq, hq_ref, hq_mask, hq_center in zip(imgs_hq, refs_hq, masks_hq, centers_hq): # It is assumed that the target size is a square. target_size = (self.target_hq_size, self.target_hq_size) - hqs_adjusted.append(cv2.resize(hq, target_size, interpolation=cv2.INTER_AREA)) - hq_refs_adjusted.append(cv2.resize(hq_ref, target_size, interpolation=cv2.INTER_AREA)) - hq_masks_adjusted.append(cv2.resize(hq_mask, target_size, interpolation=cv2.INTER_AREA)) - hq_centers_adjusted.append(self.resize_point(hq_center, (h, w), target_size)) + hqs_adjusted.append(cv2.resize( + hq, target_size, interpolation=cv2.INTER_AREA)) + hq_refs_adjusted.append(cv2.resize( + hq_ref, target_size, interpolation=cv2.INTER_AREA)) + hq_masks_adjusted.append(cv2.resize( + hq_mask, target_size, interpolation=cv2.INTER_AREA)) + hq_centers_adjusted.append( + self.resize_point(hq_center, (h, w), target_size)) h, w = self.target_hq_size, self.target_hq_size else: hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted = imgs_hq, refs_hq, masks_hq, centers_hq - hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted] # This is done implicitly above.. - hq_multiple = self.multiple * self.scale # Multiple must apply to LQ image. + # This is done implicitly above.. + hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted] + # Multiple must apply to LQ image. 
+ hq_multiple = self.multiple * self.scale if h % hq_multiple != 0 or w % hq_multiple != 0: hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], [] for hq, hq_ref, hq_mask, hq_center in zip(hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted): h, w = (h - h % hq_multiple), (w - w % hq_multiple) - hq_centers_conformed.append(self.resize_point(hq_center, hq.shape[:2], (h, w))) + hq_centers_conformed.append( + self.resize_point(hq_center, hq.shape[:2], (h, w))) hqs_conformed.append(hq[:h, :w, :]) hq_refs_conformed.append(hq_ref[:h, :w, :]) hq_masks_conformed.append(hq_mask[:h, :w, :]) @@ -110,10 +126,14 @@ class BaseUnsupervisedImageDataset(data.Dataset): lms.append(hq_mask) lcs.append(hq_center) else: - ls.append(cv2.resize(hq, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA)) - lrs.append(cv2.resize(hq_ref, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA)) - lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA)) - lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2])) + ls.append(cv2.resize(hq, (h // self.scale, w // + self.scale), interpolation=cv2.INTER_AREA)) + lrs.append(cv2.resize(hq_ref, (h // self.scale, w // + self.scale), interpolation=cv2.INTER_AREA)) + lms.append(cv2.resize(hq_mask, (h // self.scale, w // + self.scale), interpolation=cv2.INTER_AREA)) + lcs.append(self.resize_point( + hq_center, (h, w), ls[0].shape[:2])) # Corrupt the LQ image (only in eval mode) if not self.corrupt_before_downsize and not self.for_eval: ls = self.corruptor.corrupt_images(ls) diff --git a/dlas/data/images/byol_attachment.py b/dlas/data/images/byol_attachment.py index 57b1fb70..64346373 100644 --- a/dlas/data/images/byol_attachment.py +++ b/dlas/data/images/byol_attachment.py @@ -3,22 +3,20 @@ from time import time import kornia import numpy as np - import torch -import torchvision -from torch.utils.data import Dataset -from kornia import augmentation as augs, geometry -from kornia import filters import torch.nn as nn import torch.nn.functional as F - +import torchvision +from kornia import augmentation as augs +from kornia import filters, geometry +from torch.utils.data import Dataset # Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq' # inputs. These are then outputted as 'aug1' and 'aug2'. from tqdm import tqdm -from data import create_dataset -from models.arch_util import PixelUnshuffle -from utils.util import opt_get +from dlas.data import create_dataset +from dlas.models.arch_util import PixelUnshuffle +from dlas.utils.util import opt_get class RandomApply(nn.Module): @@ -26,6 +24,7 @@ class RandomApply(nn.Module): super().__init__() self.fn = fn self.p = p + def forward(self, x): if random.random() > self.p: return x @@ -39,9 +38,10 @@ class ByolDatasetWrapper(Dataset): self.cropped_img_size = opt['crop_size'] self.key1 = opt_get(opt, ['key1'], 'hq') self.key2 = opt_get(opt, ['key2'], 'lq') - for_sr = opt_get(opt, ['for_sr'], False) # When set, color alterations and blurs are disabled. + # When set, color alterations and blurs are disabled. + for_sr = opt_get(opt, ['for_sr'], False) - augmentations = [ \ + augmentations = [ augs.RandomHorizontalFlip(), augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))] if not for_sr: @@ -51,12 +51,14 @@ class ByolDatasetWrapper(Dataset): if opt['normalize']: # The paper calls for normalization. 
Most datasets/models in this repo don't use this. # Recommend setting true if you want to train exactly like the paper. - augmentations.append(augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))) + augmentations.append(augs.Normalize(mean=torch.tensor( + [0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))) self.aug = nn.Sequential(*augmentations) def __getitem__(self, item): item = self.wrapped_dataset[item] - item.update({'aug1': self.aug(item[self.key1]).squeeze(dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)}) + item.update({'aug1': self.aug(item[self.key1]).squeeze( + dim=0), 'aug2': self.aug(item[self.key2]).squeeze(dim=0)}) return item def __len__(self): @@ -71,7 +73,7 @@ class DatasetRandomAugWrapper(Dataset): self.wrapped_dataset = create_dataset(opt['dataset']) self.cropped_img_size = opt['crop_size'] self.includes_labels = opt['includes_labels'] - augmentations = [ \ + augmentations = [ RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8), augs.RandomGrayscale(p=0.2), RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)] @@ -87,18 +89,19 @@ class DatasetRandomAugWrapper(Dataset): dtypes = [] for k in item.keys(): if 'label' in k and isinstance(item[k], torch.Tensor) and len(item[k].shape) == 3: - assert item[k].shape[0] == 1 # Only supports a channel dim of 1. + # Only supports a channel dim of 1. + assert item[k].shape[0] == 1 labels.append(k) dtypes.append(item[k].dtype) hq = torch.cat([hq, item[k].type(torch.float)], dim=0) hq = self.rrc(hq.unsqueeze(0)).squeeze(0) for i, k in enumerate(labels): # Strip out any label values that are not whole numbers. - item[k] = hq[3+i:3+i+1,:,:] + item[k] = hq[3+i:3+i+1, :, :] whole = (item[k].round() == item[k]) item[k] = item[k] * whole item[k] = item[k].type(dtypes[i]) - item['lq'] = hq[:3,:,:] + item['lq'] = hq[:3, :, :] item['hq'] = item['lq'] return item @@ -137,14 +140,16 @@ def test_dataset_random_aug_wrapper(): for k, v in o.items(): # 'lq', 'hq', 'aug1', 'aug2', if k in ['hq']: - torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) + torchvision.utils.save_image( + v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) masked = v * (o['labels_mask'] * .5 + .5) - #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) + # torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) # Pick a random (non-zero) label and spit it out with the textual label. if len(o['labels'].unique()) > 1: randlbl = np.random.choice(o['labels'].unique()[1:]) moremask = v * ((1*(o['labels'] == randlbl))*.5+.5) - torchvision.utils.save_image(moremask.unsqueeze(0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl])) + torchvision.utils.save_image(moremask.unsqueeze( + 0), "debug/%i_%s_%s.png" % (i, k, o['label_strings'][randlbl])) def no_batch_interpolate(i, size, mode): @@ -165,10 +170,10 @@ def snap(ref, other): # Pads a tensor with zeros so that it fits in a dxd square. 
def pad_to(im, d): if len(im.shape) == 3: - pd = torch.zeros((im.shape[0],d,d)) + pd = torch.zeros((im.shape[0], d, d)) pd[:, :im.shape[1], :im.shape[2]] = im else: - pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device) + pd = torch.zeros((im.shape[0], im.shape[1], d, d), device=im.device) pd[:, :, :im.shape[2], :im.shape[3]] = im return pd @@ -182,7 +187,8 @@ class RandomSharedRegionCrop(nn.Module): def __init__(self, multiple, jitter_range=0): super().__init__() self.multiple = multiple - self.jitter_range = jitter_range # When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range + # When specified, images are shifted an additional random([-j,j]) pixels where j=jitter_range + self.jitter_range = jitter_range def forward(self, i1, i2): assert i1.shape[-1] == i2.shape[-1] @@ -218,19 +224,25 @@ class RandomSharedRegionCrop(nn.Module): # Step 4 m = self.multiple - jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range) - jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative. - jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds. + jl, jt = random.randint(-self.jitter_range, + self.jitter_range), random.randint(-self.jitter_range, self.jitter_range) + # If the top of a patch is zero, a negative jitter will cause it to go negative. + jt = jt if base_t != 0 else abs(jt) + # Likewise, jitter shouldn't allow the patch to go over-bounds. + jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 jl = jl if base_l != 0 else abs(jl) jl = jl if (base_l+base_w)*m != i1.shape[1] else 0 - p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl] + p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, + base_l*m+jl:(base_l+base_w)*m+jl] p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear") - jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range) + jl, jt = random.randint(-self.jitter_range, + self.jitter_range), random.randint(-self.jitter_range, self.jitter_range) jt = jt if im2_t != 0 else abs(jt) jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0 jl = jl if im2_l != 0 else abs(jl) jl = jl if (im2_l+im2_w)*m != i2.shape[1] else 0 - p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl] + p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, + im2_l*m+jl:(im2_l+im2_w)*m+jl] p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear") # Step 5 @@ -246,14 +258,17 @@ class RandomSharedRegionCrop(nn.Module): i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l) ix_h = min(base_b, im2_b) - max(base_t, im2_t) ix_w = min(base_r, im2_r) - max(base_l, im2_l) - recompute_package = torch.tensor([d, base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long) + recompute_package = torch.tensor([d, base_h, base_w, i1_shared_t, i1_shared_l, im2_h, + im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long) # Step 7 mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5) - mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1 + mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, + i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1 masked1 = pad_to(p1 * mask1, d*m) mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5) - mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, 
i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1 + mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, + i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1 masked2 = pad_to(p2 * mask2, d*m) mask = torch.full((1, d*m, d*m), fill_value=.33) mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33 @@ -262,10 +277,13 @@ class RandomSharedRegionCrop(nn.Module): # Step 8 - Rebuild shared regions for testing purposes. p1_shuf, p2_shuf = PixelUnshuffle(self.multiple)(p1_resized.unsqueeze(0)), \ - PixelUnshuffle(self.multiple)(p2_resized.unsqueeze(0)) - i1_shared, i2_shared = reconstructed_shared_regions(p1_shuf, p2_shuf, recompute_package.unsqueeze(0)) - i1_shared = pad_to(nn.PixelShuffle(self.multiple)(i1_shared).squeeze(0), d * m) - i2_shared = pad_to(nn.PixelShuffle(self.multiple)(i2_shared).squeeze(0), d*m) + PixelUnshuffle(self.multiple)(p2_resized.unsqueeze(0)) + i1_shared, i2_shared = reconstructed_shared_regions( + p1_shuf, p2_shuf, recompute_package.unsqueeze(0)) + i1_shared = pad_to(nn.PixelShuffle(self.multiple) + (i1_shared).squeeze(0), d * m) + i2_shared = pad_to(nn.PixelShuffle(self.multiple) + (i2_shared).squeeze(0), d*m) return p1_resized, p2_resized, recompute_package, masked1, masked2, masked_dbg, i1_shared, i2_shared @@ -280,7 +298,8 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor): # It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside # of conforming the recompute_package across the entire batch. for b in range(package.shape[0]): - expected_dim, f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(package[b].tolist()) + expected_dim, f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple( + package[b].tolist()) # If you are hitting this assert, you specified `latent_multiple` in your dataset config wrong. assert expected_dim == fea1.shape[2] and expected_dim == fea2.shape[2] @@ -292,8 +311,10 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor): f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="nearest") f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="nearest") # Outputs must be padded so they can "get along" with each other. 
- res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim)) - res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim)) + res1.append( + pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim)) + res2.append( + pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim)) return torch.cat(res1, dim=0), torch.cat(res2, dim=0) @@ -308,10 +329,11 @@ class StructuredCropDatasetWrapper(Dataset): super().__init__() self.wrapped_dataset = create_dataset(opt['dataset']) augmentations = [RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), - augs.RandomGrayscale(p=0.2), - RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)] + augs.RandomGrayscale(p=0.2), + RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)] self.aug = nn.Sequential(*augmentations) - self.rrc = RandomSharedRegionCrop(opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0)) + self.rrc = RandomSharedRegionCrop( + opt['latent_multiple'], opt_get(opt, ['jitter_range'], 0)) def __getitem__(self, item): item = self.wrapped_dataset[item] @@ -332,17 +354,17 @@ def test_structured_crop_dataset_wrapper(): opt = { 'dataset': { - 'mode': 'imagefolder', - 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'], - 'weights': [1], - 'target_size': 256, - 'force_multiple': 32, - 'scale': 1, - 'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'], - 'random_corruptions': ['noise-5', 'none'], - 'num_corrupts_per_image': 1, - 'corrupt_before_downsize': True, + 'mode': 'imagefolder', + 'name': 'amalgam', + 'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'], + 'weights': [1], + 'target_size': 256, + 'force_multiple': 32, + 'scale': 1, + 'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'], + 'random_corruptions': ['noise-5', 'none'], + 'num_corrupts_per_image': 1, + 'corrupt_before_downsize': True, }, 'latent_multiple': 16, 'jitter_range': 0, @@ -353,16 +375,17 @@ def test_structured_crop_dataset_wrapper(): os.makedirs("debug", exist_ok=True) for i in tqdm(range(0, len(ds))): o = ds[random.randint(0, len(ds)-1)] - #for k, v in o.items(): - # 'lq', 'hq', 'aug1', 'aug2', - #if k in [ 'aug_shared_view', 'masked1', 'masked2']: - #torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) + # for k, v in o.items(): + # 'lq', 'hq', 'aug1', 'aug2', + # if k in [ 'aug_shared_view', 'masked1', 'masked2']: + # torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) rcpkg = o['similar_region_dimensions'] pixun = PixelUnshuffle(16) pixsh = nn.PixelShuffle(16) - rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0)) - #torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,)) - #torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,)) + rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze( + 0)), pixun(o['aug2'].unsqueeze(0)), rcpkg.unsqueeze(0)) + # torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,)) + # torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,)) if __name__ == '__main__': diff --git a/dlas/data/images/chunk_with_reference.py b/dlas/data/images/chunk_with_reference.py index c7bce448..f8341cb6 100644 --- a/dlas/data/images/chunk_with_reference.py +++ b/dlas/data/images/chunk_with_reference.py @@ -1,17 +1,19 @@ import os.path as osp -from data import util -import torch -import numpy as np +import numpy as np +import torch + +from dlas.data import util # Iterable that reads all the images in a 
directory that contains a reference image, tile images and center coordinates. -from utils.util import opt_get +from dlas.utils.util import opt_get class ChunkWithReference: def __init__(self, opt, path): self.path = path.path self.tiles, _ = util.find_files_of_type('img', self.path) - self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False) + self.need_metadata = opt_get(opt, ['strict'], False) or opt_get( + opt, ['needs_metadata'], False) self.need_ref = opt_get(opt, ['need_ref'], False) if 'ignore_first' in opt.keys(): self.tiles = self.tiles[opt['ignore_first']:] @@ -41,12 +43,14 @@ class ChunkWithReference: else: center = torch.tensor([128, 128], dtype=torch.long) tile_width = 256 - mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype) - mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1 + mask = np.full(tile.shape[:2] + (1,), + fill_value=.1, dtype=tile.dtype) + mask[center[0] - tile_width // 2:center[0] + tile_width // 2, + center[1] - tile_width // 2:center[1] + tile_width // 2] = 1 else: ref = np.zeros_like(tile) mask = np.zeros(tile.shape[:2] + (1,)) - center = (0,0) + center = (0, 0) return tile, ref, center, mask, self.tiles[item] diff --git a/dlas/data/images/cifar.py b/dlas/data/images/cifar.py index c8ff7c9c..538bc074 100644 --- a/dlas/data/images/cifar.py +++ b/dlas/data/images/cifar.py @@ -1,14 +1,15 @@ # A copy of the cifar dataset from torch which also returns coarse labels. -from PIL import Image import os import os.path -import numpy as np import pickle from typing import Any, Callable, Optional, Tuple +import numpy as np +from PIL import Image from torchvision.datasets import VisionDataset -from torchvision.datasets.utils import check_integrity, download_and_extract_archive +from torchvision.datasets.utils import (check_integrity, + download_and_extract_archive) class CIFAR10(VisionDataset): @@ -104,7 +105,8 @@ class CIFAR10(VisionDataset): with open(path, 'rb') as infile: data = pickle.load(infile, encoding='latin1') self.classes = data[self.meta['key']] - self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} + self.class_to_idx = {_class: i for i, + _class in enumerate(self.classes)} def __getitem__(self, index: int) -> Tuple[Any, Any]: """ @@ -147,7 +149,8 @@ class CIFAR10(VisionDataset): if self._check_integrity(): print('Files already downloaded and verified') return - download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) + download_and_extract_archive( + self.url, self.root, filename=self.filename, md5=self.tgz_md5) def extra_repr(self) -> str: return "Split: {}".format("Train" if self.train is True else "Test") diff --git a/dlas/data/images/full_image_dataset.py b/dlas/data/images/full_image_dataset.py index 92e14036..8b372818 100644 --- a/dlas/data/images/full_image_dataset.py +++ b/dlas/data/images/full_image_dataset.py @@ -1,12 +1,14 @@ import random -import numpy as np +from io import BytesIO + import cv2 +import numpy as np import torch import torch.utils.data as data -import data.util as util -from PIL import Image, ImageOps -from io import BytesIO import torchvision.transforms.functional as F +from PIL import Image, ImageOps + +import dlas.data.util as util # Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to @@ -16,6 +18,7 @@ class FullImageDataset(data.Dataset): Read LQ (Low Quality, e.g. 
LR (Low Resolution), blurry, etc) and GT image pairs. If only GT images are provided, generate LQ images on-the-fly. """ + def get_lq_path(self, i): which_lq = random.randint(0, len(self.paths_LQ)-1) return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])] @@ -27,19 +30,23 @@ class FullImageDataset(data.Dataset): self.paths_LQ, self.paths_GT = None, None self.sizes_LQ, self.sizes_GT = None, None self.LQ_env, self.GT_env = None, None - self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1 + self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys( + ) else 1 - self.paths_GT, self.sizes_GT = util.find_files_of_type(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights']) + self.paths_GT, self.sizes_GT = util.find_files_of_type( + self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights']) if 'dataroot_LQ' in opt.keys(): self.paths_LQ = [] if isinstance(opt['dataroot_LQ'], list): # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image and # we want the model to learn them all. for dr_lq in opt['dataroot_LQ']: - lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, dr_lq) + lq_path, self.sizes_LQ = util.find_files_of_type( + self.data_type, dr_lq) self.paths_LQ.append(lq_path) else: - lq_path, self.sizes_LQ = util.find_files_of_type(self.data_type, opt['dataroot_LQ']) + lq_path, self.sizes_LQ = util.find_files_of_type( + self.data_type, opt['dataroot_LQ']) self.paths_LQ.append(lq_path) assert self.paths_GT, 'Error: GT path is empty.' @@ -48,7 +55,8 @@ class FullImageDataset(data.Dataset): def motion_blur(self, image, size, angle): k = np.zeros((size, size), dtype=np.float32) k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32) - k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size)) + k = cv2.warpAffine(k, cv2.getRotationMatrix2D( + (size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size)) k = k * (1.0 / np.sum(k)) return cv2.filter2D(image, -1, k) @@ -100,8 +108,10 @@ class FullImageDataset(data.Dataset): if 'fixed_size' in self.opt.keys() and self.opt['fixed_size']: square_size = target_sz else: - tile_expansion_dev = self.opt['tile_scale_normal_stddev'] if 'tile_scale_normal_stddev' in self.opt.keys() else .17 - square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=tile_expansion_dev)), 1.0)) + tile_expansion_dev = self.opt['tile_scale_normal_stddev'] if 'tile_scale_normal_stddev' in self.opt.keys( + ) else .17 + square_size = int(target_sz + possible_sizes_above_target * + min(np.abs(np.random.normal(scale=tile_expansion_dev)), 1.0)) # Pick the left,top coords to draw the patch from left = self.pick_along_range(w, square_size, .3) @@ -110,11 +120,15 @@ class FullImageDataset(data.Dataset): mask = np.zeros((h, w, 1), dtype=image.dtype) mask[top:top+square_size, left:left+square_size] = 1 patch = image[top:top+square_size, left:left+square_size, :] - center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long) + center = torch.tensor( + [top + square_size // 2, left + square_size // 2], dtype=torch.long) - patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) - image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) - mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) + patch = cv2.resize(patch, (target_sz, target_sz), + 
interpolation=cv2.INTER_LINEAR) + image = cv2.resize(image, (target_sz, target_sz), + interpolation=cv2.INTER_LINEAR) + mask = cv2.resize(mask, (target_sz, target_sz), + interpolation=cv2.INTER_LINEAR) center = self.resize_point(center, (h, w), image.shape[:2]) return patch, image, mask, center @@ -127,19 +141,24 @@ class FullImageDataset(data.Dataset): assert H >= GT_size and W >= GT_size LQ_size = GT_size // scale - img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) + img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), + interpolation=cv2.INTER_LINEAR) + img_GT = cv2.resize(img_GT, (GT_size, GT_size), + interpolation=cv2.INTER_LINEAR) if self.opt['use_blurring']: # Pick randomly between gaussian, motion, or no blur. blur_det = random.randint(0, 100) - blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude'] + blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys( + ) else self.opt['blur_magnitude'] blur_magnitude = max(1, int(blur_magnitude*strength)) if blur_det < 40: blur_sig = int(random.randrange(0, int(blur_magnitude))) - img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig) + img_LQ = cv2.GaussianBlur( + img_LQ, (blur_magnitude, blur_magnitude), blur_sig) elif blur_det < 70: - img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360)) + img_LQ = self.motion_blur(img_LQ, random.randrange( + 1, int(blur_magnitude) * 3), random.randint(0, 360)) return img_GT, img_LQ @@ -174,13 +193,15 @@ class FullImageDataset(data.Dataset): # Gaussian Blur (point or motion) blur_magnitude = 3 blur_sig = int(random.randrange(0, int(blur_magnitude))) - image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig) + image = cv2.GaussianBlur( + image, (blur_magnitude, blur_magnitude), blur_sig) elif 2 in aug_code: # Median Blur image = cv2.medianBlur(image, 3) elif 3 in aug_code: # Motion blur - image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360)) + image = self.motion_blur( + image, random.randrange(1, 9), random.randint(0, 360)) elif 4 in aug_code: # Smooth blur image = cv2.blur(image, ksize=3) @@ -217,15 +238,19 @@ class FullImageDataset(data.Dataset): full_path = self.paths_GT[index % len(self.paths_GT)] LQ_path = full_path img_full = util.read_img(None, full_path, None) - img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0] + img_full = util.channel_convert( + img_full.shape[2], 'RGB', [img_full])[0] if self.opt['phase'] == 'train': - img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0] + img_full = util.augment( + [img_full], self.opt['use_flip'], self.opt['use_rot'])[0] img_full = self.get_square_image(img_full) - img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full) + img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile( + img_full) else: img_GT, gt_fullsize_ref = img_full, img_full gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype) - gt_center = torch.tensor([img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long) + gt_center = torch.tensor( + [img_full.shape[0] // 2, img_full.shape[1] // 2], dtype=torch.long) orig_gt_dim = gt_fullsize_ref.shape[:2] # get LQ image @@ -233,13 +258,17 @@ class FullImageDataset(data.Dataset): LQ_path = self.get_lq_path(index) img_lq_full = util.read_img(None, LQ_path, None) if self.opt['phase'] == 'train': - 
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0] + img_lq_full = util.augment( + [img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0] img_lq_full = self.get_square_image(img_lq_full) - img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True) + img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile( + img_lq_full, lq=True) else: img_LQ, lq_fullsize_ref = img_lq_full, img_lq_full - lq_mask = np.ones(img_lq_full.shape[:2], dtype=lq_fullsize_ref.dtype) - lq_center = torch.tensor([img_lq_full.shape[0] // 2, img_lq_full.shape[1] // 2], dtype=torch.long) + lq_mask = np.ones( + img_lq_full.shape[:2], dtype=lq_fullsize_ref.dtype) + lq_center = torch.tensor( + [img_lq_full.shape[0] // 2, img_lq_full.shape[1] // 2], dtype=torch.long) else: # down-sampling on-the-fly # randomly scale during training if self.opt['phase'] == 'train': @@ -258,7 +287,8 @@ class FullImageDataset(data.Dataset): H_s = _mod(H_s, random_scale, scale, GT_size) W_s = _mod(W_s, random_scale, scale, GT_size) - img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) + img_GT = cv2.resize(img_GT, (W_s, H_s), + interpolation=cv2.INTER_LINEAR) if img_GT.ndim == 2: img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) @@ -266,10 +296,12 @@ class FullImageDataset(data.Dataset): # using matlab imresize img_LQ = util.imresize_np(img_GT, 1 / scale, True) - lq_fullsize_ref = util.imresize_np(gt_fullsize_ref, 1 / scale, True) + lq_fullsize_ref = util.imresize_np( + gt_fullsize_ref, 1 / scale, True) if img_LQ.ndim == 2: img_LQ = np.expand_dims(img_LQ, axis=2) - lq_mask, lq_center = gt_mask, self.resize_point(gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2]) + lq_mask, lq_center = gt_mask, self.resize_point( + gt_center.clone(), orig_gt_dim, lq_fullsize_ref.shape[:2]) orig_lq_dim = lq_fullsize_ref.shape[:2] # Enforce force_resize constraints via clipping. @@ -285,15 +317,20 @@ class FullImageDataset(data.Dataset): if self.opt['phase'] == 'train': img_GT, img_LQ = self.augment_tile(img_GT, img_LQ) - gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2) + gt_fullsize_ref, lq_fullsize_ref = self.augment_tile( + gt_fullsize_ref, lq_fullsize_ref, strength=.2) # Scale masks. - lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) - gt_mask = cv2.resize(gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) + lq_mask = cv2.resize( + lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) + gt_mask = cv2.resize( + gt_mask, (gt_fullsize_ref.shape[1], gt_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) # Scale center coords - lq_center = self.resize_point(lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2]) - gt_center = self.resize_point(gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2]) + lq_center = self.resize_point( + lq_center, orig_lq_dim, lq_fullsize_ref.shape[:2]) + gt_center = self.resize_point( + gt_center, orig_gt_dim, gt_fullsize_ref.shape[:2]) # BGR to RGB, HWC to CHW, numpy to tensor if img_GT.shape[2] == 3: @@ -303,16 +340,20 @@ class FullImageDataset(data.Dataset): gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB) # LQ needs to go to a PIL image to perform the compression-artifact transformation. 
- #if self.opt['phase'] == 'train': - #img_LQ = self.pil_augment(img_LQ) - #lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2) + # if self.opt['phase'] == 'train': + # img_LQ = self.pil_augment(img_LQ) + # lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2) - img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() - gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float() + img_GT = torch.from_numpy(np.ascontiguousarray( + np.transpose(img_GT, (2, 0, 1)))).float() + gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray( + np.transpose(gt_fullsize_ref, (2, 0, 1)))).float() img_LQ = F.to_tensor(img_LQ) lq_fullsize_ref = F.to_tensor(lq_fullsize_ref) - lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0) - gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0) + lq_mask = torch.from_numpy( + np.ascontiguousarray(lq_mask)).unsqueeze(dim=0) + gt_mask = torch.from_numpy( + np.ascontiguousarray(gt_mask)).unsqueeze(dim=0) if 'lq_noise' in self.opt.keys(): lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255 @@ -331,6 +372,7 @@ class FullImageDataset(data.Dataset): def __len__(self): return len(self.paths_GT) + if __name__ == '__main__': ''' opt = { @@ -365,10 +407,10 @@ if __name__ == '__main__': o = ds[i] for k, v in o.items(): if 'path' not in k: - #if 'full' in k: - #masked = v[:3, :, :] * v[3] - #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) - #v = v[:3, :, :] - #import torchvision - #torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) + # if 'full' in k: + # masked = v[:3, :, :] * v[3] + # torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) + # v = v[:3, :, :] + # import torchvision + # torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) pass diff --git a/dlas/data/images/image_corruptor.py b/dlas/data/images/image_corruptor.py index f7000b29..726fabda 100644 --- a/dlas/data/images/image_corruptor.py +++ b/dlas/data/images/image_corruptor.py @@ -1,20 +1,18 @@ import functools import random +from io import BytesIO from math import cos, pi import cv2 import kornia import numpy as np import torch -from kornia.augmentation import ColorJitter - from data.util import read_img +from kornia.augmentation import ColorJitter from PIL import Image -from io import BytesIO - # Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data) -from utils.util import opt_get +from dlas.utils.util import opt_get ''' if __name__ == '__main__': @@ -29,9 +27,9 @@ if __name__ == '__main__': def kornia_color_jitter_numpy(img, setting): if setting * 255 > 1: # I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format. 
- img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0) + img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) img = ColorJitter(setting, setting, setting, setting)(img) - img = img.squeeze(0).permute(1,2,0).numpy() + img = img.squeeze(0).permute(1, 2, 0).numpy() return img @@ -41,14 +39,18 @@ class ImageCorruptor: def __init__(self, opt): self.opt = opt self.reset_random() - self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1 - self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else [] - self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0 + self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys( + ) else 1 + self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else [ + ] + self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys( + ) else 0 self.cosine_bias = opt_get(opt, ['cosine_bias'], True) if self.num_corrupts == 0: return else: - self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else [] + self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else [ + ] def reset_random(self): if 'random_seed' in self.opt.keys(): @@ -75,7 +77,8 @@ class ImageCorruptor: if self.num_corrupts == 0: augmentations = [] else: - augmentations = random.choices(self.random_corruptions, k=self.num_corrupts) + augmentations = random.choices( + self.random_corruptions, k=self.num_corrupts) # Sources of entropy corrupted_imgs = [] @@ -99,7 +102,6 @@ class ImageCorruptor: img = ufn(img) corrupted_imgs.append(img) - if return_entropy: return corrupted_imgs, entropy else: @@ -119,11 +121,11 @@ class ImageCorruptor: setting = rand_val * (hi_end - lo_end) + lo_end img = kornia_color_jitter_numpy(img, setting) elif 'gaussian_blur' in aug: - img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5) + img = cv2.GaussianBlur(img, (0, 0), self.blur_scale*rand_val*1.5) elif 'motion_blur' in aug: # Motion blur intensity = self.blur_scale*rand_val * 3 + 1 - angle = random.randint(0,360) + angle = random.randint(0, 360) k = np.zeros((intensity, intensity), dtype=np.float32) k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32) k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0), @@ -145,10 +147,13 @@ class ImageCorruptor: else: scale = 4 if scale > 1: - interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4] - mode = random.randint(0,4) % len(interpolation_modes) + interpolation_modes = [ + cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4] + mode = random.randint(0, 4) % len(interpolation_modes) # Downsample first, then upsample using the random mode. 
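# Hedged sketch (illustrative, not part of this diff): the low-quality resampling
# corruption in isolation. The image is downsampled by an integer factor with a
# randomly chosen interpolation mode and then upsampled back bilinearly, which
# mimics badly rescaled source material. Function name `resample_corrupt` is mine.
import random

import cv2
import numpy as np

def resample_corrupt(img: np.ndarray, scale: int = 2) -> np.ndarray:
    modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC,
             cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
    mode = random.choice(modes)
    h, w = img.shape[:2]
    # cv2.resize takes dsize as (width, height).
    small = cv2.resize(img, dsize=(w // scale, h // scale), interpolation=mode)
    # Upsample back to the original size; the bilinear undo keeps the damage.
    return cv2.resize(small, dsize=(w, h), interpolation=cv2.INTER_LINEAR)

corrupted = resample_corrupt(np.random.rand(64, 64, 3).astype(np.float32), scale=4)
assert corrupted.shape == (64, 64, 3)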
- img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=mode) + img = cv2.resize(img, dsize=( + img.shape[1]//scale, img.shape[0]//scale), interpolation=mode) + def lq_resampling_undo_fn(scale, img): return cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=cv2.INTER_LINEAR) undo_fn = functools.partial(lq_resampling_undo_fn, scale) @@ -171,22 +176,23 @@ class ImageCorruptor: elif 'jpeg' in aug: if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations: if aug == 'jpeg': - lo=10 - range=20 + lo = 10 + range = 20 elif aug == 'jpeg-low': - lo=15 - range=10 + lo = 15 + range = 10 elif aug == 'jpeg-medium': - lo=23 - range=25 + lo = 23 + range = 25 elif aug == 'jpeg-broad': - lo=15 - range=60 + lo = 15 + range = 60 elif aug == 'jpeg-normal': - lo=47 - range=35 + lo = 47 + range = 35 else: - raise NotImplementedError("specified jpeg corruption doesn't exist") + raise NotImplementedError( + "specified jpeg corruption doesn't exist") # JPEG compression qf = (int((1-rand_val)*range) + lo) # Use PIL to perform a mock compression to a data buffer, then swap back to cv2. @@ -195,14 +201,15 @@ class ImageCorruptor: buffer = BytesIO() img.save(buffer, "JPEG", quality=qf, optimize=True) buffer.seek(0) - jpeg_img_bytes = np.asarray(bytearray(buffer.read()), dtype="uint8") + jpeg_img_bytes = np.asarray( + bytearray(buffer.read()), dtype="uint8") img = read_img("buffer", jpeg_img_bytes, rgb=True) elif 'saturation' in aug: # Lightening / saturation saturation = rand_val * .3 img = np.clip(img + saturation, a_max=1, a_min=0) elif 'greyscale' in aug: - img = np.tile(np.mean(img, axis=2, keepdims=True), [1,1,3]) + img = np.tile(np.mean(img, axis=2, keepdims=True), [1, 1, 3]) elif 'none' not in aug: raise NotImplementedError("Augmentation doesn't exist") diff --git a/dlas/data/images/image_folder_dataset.py b/dlas/data/images/image_folder_dataset.py index 87f9bd24..572d6279 100644 --- a/dlas/data/images/image_folder_dataset.py +++ b/dlas/data/images/image_folder_dataset.py @@ -1,21 +1,21 @@ import functools +import os import random import cv2 import numpy as np import torch -import os - import torchvision +from data import util from torch.utils.data import DataLoader from torchvision.transforms import Normalize from tqdm import tqdm -from data import util # Builds a dataset created from a simple folder containing a list of training/test/validation images. 
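# Hedged sketch (illustrative, not part of this diff): how the JPEG corruption
# above works in isolation. A float RGB image in [0, 1] is round-tripped through
# an in-memory JPEG encode/decode at a chosen quality factor, which injects
# realistic compression artifacts. The helper name `jpeg_roundtrip` is my own and
# the decode path here uses cv2 directly instead of data.util.read_img.
from io import BytesIO

import cv2
import numpy as np
from PIL import Image

def jpeg_roundtrip(img: np.ndarray, quality: int) -> np.ndarray:
    # Encode to an in-memory JPEG buffer with PIL...
    pil = Image.fromarray((img * 255).astype(np.uint8))
    buffer = BytesIO()
    pil.save(buffer, "JPEG", quality=quality, optimize=True)
    buffer.seek(0)
    # ...then decode it back and return to float [0, 1].
    raw = np.frombuffer(buffer.read(), dtype=np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)        # BGR, uint8
    decoded = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
    return decoded.astype(np.float32) / 255.0

# Quality factors in the ~10-47 range used by the 'jpeg*' settings above produce
# visibly blocky outputs; quality=90 is nearly lossless.
noisy = jpeg_roundtrip(np.random.rand(64, 64, 3).astype(np.float32), quality=15)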
-from data.images.image_corruptor import ImageCorruptor, kornia_color_jitter_numpy -from data.images.image_label_parser import VsNetImageLabeler -from utils.util import opt_get +from dlas.data.images.image_corruptor import (ImageCorruptor, + kornia_color_jitter_numpy) +from dlas.data.images.image_label_parser import VsNetImageLabeler +from dlas.utils.util import opt_get def ndarray_center_crop(crop, img): diff --git a/dlas/data/images/image_label_parser.py b/dlas/data/images/image_label_parser.py index 89a79132..7a114fc8 100644 --- a/dlas/data/images/image_label_parser.py +++ b/dlas/data/images/image_label_parser.py @@ -50,15 +50,15 @@ class VsNetImageLabeler: def get_labels_as_tensor(self, hq, img_key, resize_factor): _, h, w = hq.shape - labels = torch.zeros((1,h,w), dtype=torch.long) - mask = torch.zeros((1,h,w), dtype=torch.float) + labels = torch.zeros((1, h, w), dtype=torch.long) + mask = torch.zeros((1, h, w), dtype=torch.float) lbl_list = self.labeled_images[img_key] for patch_lbl in lbl_list: t, l, h, w = patch_lbl['patch_top'] // resize_factor, patch_lbl['patch_left'] // resize_factor, \ - patch_lbl['patch_height'] // resize_factor, patch_lbl['patch_width'] // resize_factor + patch_lbl['patch_height'] // resize_factor, patch_lbl['patch_width'] // resize_factor val = patch_lbl['labelValue'] - labels[:,t:t+h,l:l+w] = val - mask[:,t:t+h,l:l+w] = 1.0 + labels[:, t:t+h, l:l+w] = val + mask[:, t:t+h, l:l+w] = 1.0 return labels, mask, self.str_labels def add_label(self, binding, img_name, top, left, dim): @@ -100,22 +100,23 @@ class CompactJsonLabeler: assert self.config == parsed['config'] assert self.labels == parsed['labels'] assert self.label_map == parsed['label_map'] - self.images.update(parsed['images']) # This will overwrite existing images, which is acceptable. + # This will overwrite existing images, which is acceptable. + self.images.update(parsed['images']) def get_labeled_paths(self, base_path): return [os.path.join(base_path, pth) for pth in self.images.keys()] def get_labels_as_tensor(self, hq, img_key, resize_factor): _, h, w = hq.shape - labels = torch.zeros((1,h,w), dtype=torch.long) - mask = torch.zeros((1,h,w), dtype=torch.float) + labels = torch.zeros((1, h, w), dtype=torch.long) + mask = torch.zeros((1, h, w), dtype=torch.float) lbl_list = self.images[img_key] for patch_lbl in lbl_list: t, l, h, w = patch_lbl['top'] // resize_factor, patch_lbl['left'] // resize_factor, \ - self.config['dim'] // resize_factor, self.config['dim'] // resize_factor + self.config['dim'] // resize_factor, self.config['dim'] // resize_factor val = patch_lbl['labelValue'] - labels[:,t:t+h,l:l+w] = val - mask[:,t:t+h,l:l+w] = 1.0 + labels[:, t:t+h, l:l+w] = val + mask[:, t:t+h, l:l+w] = 1.0 return labels, mask, self.str_labels def add_label(self, binding, img_name, top, left, dim): diff --git a/dlas/data/images/image_pair_with_corresponding_points_dataset.py b/dlas/data/images/image_pair_with_corresponding_points_dataset.py index 1d2d35c0..c01e671d 100644 --- a/dlas/data/images/image_pair_with_corresponding_points_dataset.py +++ b/dlas/data/images/image_pair_with_corresponding_points_dataset.py @@ -1,13 +1,12 @@ -import torch import os +import torch import torchvision from PIL import Image from torch.utils.data import DataLoader, Dataset from torchvision import transforms from tqdm import tqdm - # Builds a dataset created from a simple folder containing a list of training/test/validation images. 
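# Hedged sketch (illustrative, not from this diff): what get_labels_as_tensor in
# the two labelers above is doing - rasterizing a list of labeled rectangles into
# a dense (1, H, W) label map plus a float mask marking which pixels carry
# supervision. The patch dict keys mirror the compact-JSON layout but are assumptions.
import torch

def rasterize_patch_labels(h, w, patches, resize_factor=1):
    labels = torch.zeros((1, h, w), dtype=torch.long)
    mask = torch.zeros((1, h, w), dtype=torch.float)
    for p in patches:
        t = p['top'] // resize_factor
        l = p['left'] // resize_factor
        ph = p['height'] // resize_factor
        pw = p['width'] // resize_factor
        labels[:, t:t+ph, l:l+pw] = p['labelValue']   # integer class id
        mask[:, t:t+ph, l:l+pw] = 1.0                 # labeled region
    return labels, mask

labels, mask = rasterize_patch_labels(
    64, 64, [{'top': 8, 'left': 8, 'height': 16, 'width': 16, 'labelValue': 3}])
assert labels[0, 10, 10] == 3 and mask.sum() == 16 * 16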
@@ -15,33 +14,42 @@ class ImagePairWithCorrespondingPointsDataset(Dataset): def __init__(self, opt): self.opt = opt self.path = opt['path'] - self.pairs = list(filter(lambda f: not os.path.isdir(f), os.listdir(self.path))) + self.pairs = list( + filter(lambda f: not os.path.isdir(f), os.listdir(self.path))) self.transforms = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + transforms.Normalize( + (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) self.size = opt['size'] - def __getitem__(self, item): dir = self.pairs[item] - img1 = self.transforms(Image.open(os.path.join(self.path, dir, "1.jpg"))) - img2 = self.transforms(Image.open(os.path.join(self.path, dir, "2.jpg"))) - coords1, coords2 = torch.load(os.path.join(self.path, dir, "coords.pth")) + img1 = self.transforms(Image.open( + os.path.join(self.path, dir, "1.jpg"))) + img2 = self.transforms(Image.open( + os.path.join(self.path, dir, "2.jpg"))) + coords1, coords2 = torch.load( + os.path.join(self.path, dir, "coords.pth")) assert img1.shape[-2] == img1.shape[-1] assert img2.shape[-2] == img2.shape[-1] if img1.shape[-1] != self.size: scale = img1.shape[-1] / self.size - assert(int(scale) == scale) # We will only downsample to even resolutions. + # We will only downsample to even resolutions. + assert (int(scale) == scale) scale = 1 / scale - img1 = torch.nn.functional.interpolate(img1.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) + img1 = torch.nn.functional.interpolate(img1.unsqueeze( + 0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) coords1 = [int(c * scale) for c in coords1] if img2.shape[-1] != self.size: scale = img2.shape[-1] / self.size - assert(int(scale) == scale) # We will only downsample to even resolutions. + # We will only downsample to even resolutions. + assert (int(scale) == scale) scale = 1 / scale - img2 = torch.nn.functional.interpolate(img2.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) + img2 = torch.nn.functional.interpolate(img2.unsqueeze( + 0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) coords2 = [int(c * scale) for c in coords2] - coords1 = (coords1[1], coords1[0]) # The UI puts these out backwards (x,y). Flip them. + # The UI puts these out backwards (x,y). Flip them. + coords1 = (coords1[1], coords1[0]) coords2 = (coords2[1], coords2[0]) return { 'img1': img1, @@ -53,6 +61,7 @@ class ImagePairWithCorrespondingPointsDataset(Dataset): def __len__(self): return len(self.pairs) + if __name__ == '__main__': opt = { 'path': 'F:\\dlas\\codes\\scripts\\ui\\image_pair_labeler\\results', @@ -60,13 +69,14 @@ if __name__ == '__main__': } output_path = '..' 
- ds = DataLoader(ImagePairWithCorrespondingPointsDataset(opt), shuffle=True, num_workers=0) + ds = DataLoader(ImagePairWithCorrespondingPointsDataset( + opt), shuffle=True, num_workers=0) for i, d in tqdm(enumerate(ds)): i1 = d['img1'] i2 = d['img2'] c1 = d['coords1'] c2 = d['coords2'] - i1[:,:,c1[0]-3:c1[0]+3,c1[1]-3:c1[1]+3] = 0 - i2[:,:,c2[0]-3:c2[0]+3,c2[1]-3:c2[1]+3] = 0 + i1[:, :, c1[0]-3:c1[0]+3, c1[1]-3:c1[1]+3] = 0 + i2[:, :, c2[0]-3:c2[0]+3, c2[1]-3:c2[1]+3] = 0 torchvision.utils.save_image(i1, f'{output_path}\\{i}_1.png') - torchvision.utils.save_image(i2, f'{output_path}\\{i}_2.png') \ No newline at end of file + torchvision.utils.save_image(i2, f'{output_path}\\{i}_2.png') diff --git a/dlas/data/images/multi_frame_dataset.py b/dlas/data/images/multi_frame_dataset.py index 76c4b127..6603f2b0 100644 --- a/dlas/data/images/multi_frame_dataset.py +++ b/dlas/data/images/multi_frame_dataset.py @@ -1,8 +1,12 @@ -from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset +import os.path as osp +from bisect import bisect_left + import numpy as np import torch -from bisect import bisect_left -import os.path as osp + +from dlas.data.images.base_unsupervised_image_dataset import \ + BaseUnsupervisedImageDataset + class MultiFrameDataset(BaseUnsupervisedImageDataset): def __init__(self, opt): @@ -30,7 +34,8 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset): for i in range(self.num_frames): idx = search_idx + i if idx < 0 or idx >= len(self.chunks) or chunk_offset < 0 or chunk_offset >= len(self.chunks[idx]): - print("Chunk reference indexing failed for %s." % (im_name,), search_idx, i, chunk_offset, self.num_frames) + print("Chunk reference indexing failed for %s." % + (im_name,), search_idx, i, chunk_offset, self.num_frames) h, r, c, m, p = self.chunks[search_idx + i][chunk_offset] hqs.append(h) refs.append(r) @@ -41,24 +46,32 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset): def __getitem__(self, item): chunk_ind = bisect_left(self.starting_indices, item) - chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 - hqs, refs, masks, centers, path = self.get_sequential_image_paths_from(chunk_ind, item-self.starting_indices[chunk_ind]) + chunk_ind = chunk_ind if chunk_ind < len( + self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 + hqs, refs, masks, centers, path = self.get_sequential_image_paths_from( + chunk_ind, item-self.starting_indices[chunk_ind]) hs, hrs, hms, hcs = self.resize_hq(hqs, refs, masks, centers) ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs) # Convert to torch tensor - hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float() - hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float() - hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1) + hq = torch.from_numpy(np.ascontiguousarray( + np.transpose(np.stack(hs), (0, 3, 1, 2)))).float() + hq_ref = torch.from_numpy(np.ascontiguousarray( + np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float() + hq_mask = torch.from_numpy( + np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1) hq_ref = torch.cat([hq_ref, hq_mask], dim=1) - lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float() - lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float() - lq_mask = 
torch.from_numpy(np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1) + lq = torch.from_numpy(np.ascontiguousarray( + np.transpose(np.stack(ls), (0, 3, 1, 2)))).float() + lq_ref = torch.from_numpy(np.ascontiguousarray( + np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float() + lq_mask = torch.from_numpy( + np.ascontiguousarray(np.stack(lms))).unsqueeze(dim=1) lq_ref = torch.cat([lq_ref, lq_mask], dim=1) return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, - 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)} + 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)} if __name__ == '__main__': @@ -84,7 +97,7 @@ if __name__ == '__main__': for i in range(len(ds)): import random k = 'lq' - element = ds[random.randint(0,len(ds))] + element = ds[random.randint(0, len(ds))] base_file = osp.basename(element["GT_path"]) o = element[k].unsqueeze(0) if bs < 32: @@ -99,8 +112,9 @@ if __name__ == '__main__': b, fr, f, h, w = batch.shape for j in range(fr): import torchvision - base=osp.basename(base_file) - torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base)) + base = osp.basename(base_file) + torchvision.utils.save_image( + batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base)) bs = 0 batch = None diff --git a/dlas/data/images/multiscale_dataset.py b/dlas/data/images/multiscale_dataset.py index 540d3076..ef65b58d 100644 --- a/dlas/data/images/multiscale_dataset.py +++ b/dlas/data/images/multiscale_dataset.py @@ -1,12 +1,13 @@ import random -import numpy as np + import cv2 +import data.util as util +import numpy as np import torch import torch.utils.data as data -import data.util as util # Reads full-quality images and pulls tiles at regular zoom intervals from them. Only usable for training purposes. -from data.images.image_corruptor import ImageCorruptor +from dlas.data.images.image_corruptor import ImageCorruptor # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. 
The cropping @@ -27,6 +28,7 @@ def get_square_image(image): left = max(int(center + offset * (center - 2)), 0) return image[:, left:left + h, :] + class MultiScaleDataset(data.Dataset): def __init__(self, opt): super(MultiScaleDataset, self).__init__() @@ -36,10 +38,10 @@ class MultiScaleDataset(data.Dataset): self.num_scales = self.opt['num_scales'] self.hq_size_cap = self.tile_size * 2 ** self.num_scales self.scale = self.opt['scale'] - self.paths_hq, self.sizes_hq = util.find_files_of_type(self.data_type, opt['paths'], [1 for _ in opt['paths']]) + self.paths_hq, self.sizes_hq = util.find_files_of_type( + self.data_type, opt['paths'], [1 for _ in opt['paths']]) self.corruptor = ImageCorruptor(opt) - def recursively_extract_patches(self, input_img, result_list, depth): if depth >= self.num_scales: return @@ -49,7 +51,8 @@ class MultiScaleDataset(data.Dataset): input_img[:patch_size, patch_size:], input_img[patch_size:, :patch_size], input_img[patch_size:, patch_size:]] - result_list.extend([cv2.resize(p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA) for p in patches]) + result_list.extend([cv2.resize( + p, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA) for p in patches]) for p in patches: self.recursively_extract_patches(p, result_list, depth+1) @@ -57,30 +60,41 @@ class MultiScaleDataset(data.Dataset): # get full size image full_path = self.paths_hq[index % len(self.paths_hq)] loaded_img = util.read_img(None, full_path, None) - img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0] + img_full1 = util.channel_convert( + loaded_img.shape[2], 'RGB', [loaded_img])[0] img_full2 = util.augment([img_full1], True, True)[0] img_full3 = get_square_image(img_full2) # This error crops up from time to time. I suspect an issue with util.read_img. if img_full3.shape[0] == 0 or img_full3.shape[1] == 0: - print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape)) + print("Error with image: %s. Loaded image shape: %s" % (full_path, str( + loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape)) # Attempt to recover by just using a fixed array of zeros, which the downstream networks should be fine training against, within reason. - img_full3 = np.zeros((1024,1024,3), dtype=np.int) - img_full = cv2.resize(img_full3, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA) - patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] + img_full3 = np.zeros((1024, 1024, 3), dtype=np.int) + img_full = cv2.resize( + img_full3, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA) + patches_hq = [cv2.resize( + img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] self.recursively_extract_patches(img_full, patches_hq, 1) # Image corruption is applied against the full size image for this dataset. 
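# Hedged sketch (illustrative, not part of this diff): a self-contained variant of
# the quad-tree splitting that recursively_extract_patches performs; the depth
# bookkeeping is simplified here. Each level splits the image into four quadrants,
# resizes every quadrant back down to tile_size, and recurses, so a square image
# of side tile_size * 2**num_scales yields 1 + 4 + 16 + ... tiles, with the
# deepest quadrants landing at native tile resolution.
import cv2
import numpy as np

def extract_multiscale_patches(img, tile_size, num_scales, depth=0, out=None):
    if out is None:
        # Level 0 is the whole image resized down to a single tile.
        out = [cv2.resize(img, (tile_size, tile_size),
                          interpolation=cv2.INTER_AREA)]
    if depth >= num_scales:
        return out
    half = img.shape[0] // 2
    quads = [img[:half, :half], img[:half, half:],
             img[half:, :half], img[half:, half:]]
    out.extend(cv2.resize(q, (tile_size, tile_size),
                          interpolation=cv2.INTER_AREA) for q in quads)
    for q in quads:
        extract_multiscale_patches(q, tile_size, num_scales, depth + 1, out)
    return out

# A 512px square with tile_size=128 and num_scales=2 produces 1 + 4 + 16 = 21 tiles.
img = np.random.rand(512, 512, 3).astype(np.float32)
assert len(extract_multiscale_patches(img, 128, 2)) == 21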
img_corrupted = self.corruptor.corrupt_images([img_full])[0] - patches_hq_corrupted = [cv2.resize(img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] - self.recursively_extract_patches(img_corrupted, patches_hq_corrupted, 1) + patches_hq_corrupted = [cv2.resize( + img_corrupted, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] + self.recursively_extract_patches( + img_corrupted, patches_hq_corrupted, 1) # BGR to RGB, HWC to CHW, numpy to tensor if patches_hq[0].shape[2] == 3: - patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq] - patches_hq_corrupted = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted] - patches_hq = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq] + patches_hq = [cv2.cvtColor(p, cv2.COLOR_BGR2RGB) + for p in patches_hq] + patches_hq_corrupted = [cv2.cvtColor( + p, cv2.COLOR_BGR2RGB) for p in patches_hq_corrupted] + patches_hq = [torch.from_numpy(np.ascontiguousarray( + np.transpose(p, (2, 0, 1)))).float() for p in patches_hq] patches_hq = torch.stack(patches_hq, dim=0) - patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray(np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted] - patches_lq = [torch.nn.functional.interpolate(p.unsqueeze(0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted] + patches_hq_corrupted = [torch.from_numpy(np.ascontiguousarray( + np.transpose(p, (2, 0, 1)))).float() for p in patches_hq_corrupted] + patches_lq = [torch.nn.functional.interpolate(p.unsqueeze( + 0), scale_factor=1/self.scale, mode='area').squeeze() for p in patches_hq_corrupted] patches_lq = torch.stack(patches_lq, dim=0) d = {'lq': patches_lq, 'hq': patches_hq, 'GT_path': full_path} @@ -89,6 +103,7 @@ class MultiScaleDataset(data.Dataset): def __len__(self): return len(self.paths_hq) + class MultiscaleTreeNode: def __init__(self, index, parent, i): self.index = index @@ -117,7 +132,8 @@ def build_multiscale_patch_index_map(depth): def _build_multiscale_patch_index_map(depth, ind, node, leaves): - subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i)) for i in range(4)] + subnodes = [node.add_child(MultiscaleTreeNode(ind+i, node, i)) + for i in range(4)] ind += 4 if depth == 1: leaves.extend(subnodes) @@ -146,7 +162,7 @@ if __name__ == '__main__': os.makedirs("debug", exist_ok=True) multiscale_tree = build_multiscale_patch_index_map(4) for i in range(500, len(ds)): - quadrant=2 + quadrant = 2 print(i) o = ds[random.randint(0, len(ds))] tree_ind = random.randint(0, len(multiscale_tree)) @@ -155,9 +171,10 @@ if __name__ == '__main__': continue depth = 0 node = multiscale_tree[tree_ind] - #for j, img in enumerate(v): + # for j, img in enumerate(v): # torchvision.utils.save_image(img.unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j)) while node is not None: - torchvision.utils.save_image(v[node.index].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, depth)) + torchvision.utils.save_image(v[node.index].unsqueeze( + 0), "debug/%i_%s_%i.png" % (i, k, depth)) depth += 1 - node = node.parent \ No newline at end of file + node = node.parent diff --git a/dlas/data/images/paired_frame_dataset.py b/dlas/data/images/paired_frame_dataset.py index 09605995..3ee41a1f 100644 --- a/dlas/data/images/paired_frame_dataset.py +++ b/dlas/data/images/paired_frame_dataset.py @@ -1,8 +1,12 @@ -from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset +import os.path as osp +from bisect import bisect_left + 
import numpy as np import torch -from bisect import bisect_left -import os.path as osp + +from dlas.data.images.base_unsupervised_image_dataset import \ + BaseUnsupervisedImageDataset + class PairedFrameDataset(BaseUnsupervisedImageDataset): def __init__(self, opt): @@ -26,24 +30,32 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset): def __getitem__(self, item): chunk_ind = bisect_left(self.starting_indices, item) - chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 - hqs, refs, masks, centers, path = self.get_pair(chunk_ind, item-self.starting_indices[chunk_ind]) + chunk_ind = chunk_ind if chunk_ind < len( + self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 + hqs, refs, masks, centers, path = self.get_pair( + chunk_ind, item-self.starting_indices[chunk_ind]) hs, hrs, hms, hcs = self.resize_hq(hqs, refs, masks, centers) ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs) # Convert to torch tensor - hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float() - hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float() - hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).squeeze().unsqueeze(dim=1) + hq = torch.from_numpy(np.ascontiguousarray( + np.transpose(np.stack(hs), (0, 3, 1, 2)))).float() + hq_ref = torch.from_numpy(np.ascontiguousarray( + np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float() + hq_mask = torch.from_numpy(np.ascontiguousarray( + np.stack(hms))).squeeze().unsqueeze(dim=1) hq_ref = torch.cat([hq_ref, hq_mask], dim=1) - lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float() - lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float() - lq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(lms))).squeeze().unsqueeze(dim=1) + lq = torch.from_numpy(np.ascontiguousarray( + np.transpose(np.stack(ls), (0, 3, 1, 2)))).float() + lq_ref = torch.from_numpy(np.ascontiguousarray( + np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float() + lq_mask = torch.from_numpy(np.ascontiguousarray( + np.stack(lms))).squeeze().unsqueeze(dim=1) lq_ref = torch.cat([lq_ref, lq_mask], dim=1) return {'GT_path': path, 'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, - 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)} + 'lq_center': torch.tensor(lcs, dtype=torch.long), 'gt_center': torch.tensor(hcs, dtype=torch.long)} if __name__ == '__main__': @@ -51,7 +63,7 @@ if __name__ == '__main__': 'name': 'amalgam', 'paths': ['F:\\4k6k\\datasets\\ns_images\\vr\\validation'], 'weights': [1], - #'target_size': 128, + # 'target_size': 128, 'force_multiple': 32, 'scale': 2, 'eval': False, @@ -69,7 +81,7 @@ if __name__ == '__main__': for i in range(len(ds)): import random k = 'lq' - element = ds[random.randint(0,len(ds))] + element = ds[random.randint(0, len(ds))] base_file = osp.basename(element["GT_path"]) o = element[k].unsqueeze(0) if bs < 2: @@ -84,8 +96,9 @@ if __name__ == '__main__': b, fr, f, h, w = batch.shape for j in range(fr): import torchvision - base=osp.basename(base_file) - torchvision.utils.save_image(batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base)) + base = osp.basename(base_file) + torchvision.utils.save_image( + batch[:, j], "debug/%i_%s_%i__%s.png" % (i, k, j, base)) bs = 0 batch = None diff --git 
a/dlas/data/images/single_image_dataset.py b/dlas/data/images/single_image_dataset.py index 3aaebb56..38507e16 100644 --- a/dlas/data/images/single_image_dataset.py +++ b/dlas/data/images/single_image_dataset.py @@ -1,8 +1,11 @@ import random from bisect import bisect_left + import numpy as np import torch -from data.images.base_unsupervised_image_dataset import BaseUnsupervisedImageDataset + +from dlas.data.images.base_unsupervised_image_dataset import \ + BaseUnsupervisedImageDataset # Builds a dataset composed of a set of folders. Each folder represents a single high resolution image that has been @@ -14,30 +17,40 @@ class SingleImageDataset(BaseUnsupervisedImageDataset): def get_paths(self): for i in range(len(self)): chunk_ind = bisect_left(self.starting_indices, i) - chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1 + chunk_ind = chunk_ind if chunk_ind < len( + self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1 yield self.chunks[chunk_ind].tiles[i-self.starting_indices[chunk_ind]] def __getitem__(self, item): chunk_ind = bisect_left(self.starting_indices, item) - chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 - hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item-self.starting_indices[chunk_ind]] + chunk_ind = chunk_ind if chunk_ind < len( + self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 + hq, hq_ref, hq_center, hq_mask, path = self.chunks[chunk_ind][item - + self.starting_indices[chunk_ind]] - hs, hrs, hms, hcs = self.resize_hq([hq], [hq_ref], [hq_mask], [hq_center]) + hs, hrs, hms, hcs = self.resize_hq( + [hq], [hq_ref], [hq_mask], [hq_center]) ls, lrs, lms, lcs = self.synthesize_lq(hs, hrs, hms, hcs) # Convert to torch tensor - hq = torch.from_numpy(np.ascontiguousarray(np.transpose(hs[0], (2, 0, 1)))).float() - hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(hrs[0], (2, 0, 1)))).float() - hq_mask = torch.from_numpy(np.ascontiguousarray(hms[0])).unsqueeze(dim=0) + hq = torch.from_numpy(np.ascontiguousarray( + np.transpose(hs[0], (2, 0, 1)))).float() + hq_ref = torch.from_numpy(np.ascontiguousarray( + np.transpose(hrs[0], (2, 0, 1)))).float() + hq_mask = torch.from_numpy( + np.ascontiguousarray(hms[0])).unsqueeze(dim=0) hq_ref = torch.cat([hq_ref, hq_mask], dim=0) - lq = torch.from_numpy(np.ascontiguousarray(np.transpose(ls[0], (2, 0, 1)))).float() - lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(lrs[0], (2, 0, 1)))).float() - lq_mask = torch.from_numpy(np.ascontiguousarray(lms[0])).unsqueeze(dim=0) + lq = torch.from_numpy(np.ascontiguousarray( + np.transpose(ls[0], (2, 0, 1)))).float() + lq_ref = torch.from_numpy(np.ascontiguousarray( + np.transpose(lrs[0], (2, 0, 1)))).float() + lq_mask = torch.from_numpy( + np.ascontiguousarray(lms[0])).unsqueeze(dim=0) lq_ref = torch.cat([lq_ref, lq_mask], dim=0) return {'lq': lq, 'hq': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref, - 'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long), - 'LQ_path': path, 'GT_path': path} + 'lq_center': torch.tensor(lcs[0], dtype=torch.long), 'gt_center': torch.tensor(hcs[0], dtype=torch.long), + 'LQ_path': path, 'GT_path': path} if __name__ == '__main__': @@ -60,13 +73,14 @@ if __name__ == '__main__': os.makedirs("debug", exist_ok=True) for i in range(0, len(ds)): o = ds[random.randint(0, 
len(ds))] - #for k, v in o.items(): + # for k, v in o.items(): k = 'lq' v = o[k] - #if 'LQ' in k and 'path' not in k and 'center' not in k: - #if 'full' in k: - #masked = v[:3, :, :] * v[3] - #torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) - #v = v[:3, :, :] + # if 'LQ' in k and 'path' not in k and 'center' not in k: + # if 'full' in k: + # masked = v[:3, :, :] * v[3] + # torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k)) + # v = v[:3, :, :] import torchvision - torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) \ No newline at end of file + torchvision.utils.save_image( + v.unsqueeze(0), "debug/%i_%s.png" % (i, k)) diff --git a/dlas/data/images/stylegan2_dataset.py b/dlas/data/images/stylegan2_dataset.py index 7f4946a9..67557faf 100644 --- a/dlas/data/images/stylegan2_dataset.py +++ b/dlas/data/images/stylegan2_dataset.py @@ -1,15 +1,15 @@ from functools import partial +from pathlib import Path from random import random import torch +import torch.nn as nn import torchvision from PIL import Image from torch.utils import data from torchvision import transforms -import torch.nn as nn -from pathlib import Path -import models.image_generation.stylegan.stylegan2_lucidrains as sg2 +import dlas.models.image_generation.stylegan.stylegan2_lucidrains as sg2 def convert_transparent_to_rgb(image): @@ -29,6 +29,7 @@ def resize_to_minimum_size(min_size, image): return torchvision.transforms.functional.resize(image, min_size) return image + class RandomApply(nn.Module): def __init__(self, prob, fn, fn_else=lambda x: x): super().__init__() @@ -59,7 +60,8 @@ class expand_greyscale(object): color = tensor[:1].expand(3, -1, -1) alpha = tensor[1:] else: - raise Exception(f'image with invalid number of channels given {channels}') + raise Exception( + f'image with invalid number of channels given {channels}') if not sg2.exists(alpha) and self.transparent: alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device) @@ -73,17 +75,21 @@ class Stylegan2Dataset(data.Dataset): EXTS = ['jpg', 'jpeg', 'png', 'webp'] self.folder = opt['path'] self.image_size = opt['target_size'] - self.paths = [p for ext in EXTS for p in Path(f'{self.folder}').glob(f'**/*.{ext}')] + self.paths = [p for ext in EXTS for p in Path( + f'{self.folder}').glob(f'**/*.{ext}')] aug_prob = opt['aug_prob'] - transparent = opt['transparent'] if 'transparent' in opt.keys() else False - assert len(self.paths) > 0, f'No images were found in {self.folder} for training' + transparent = opt['transparent'] if 'transparent' in opt.keys( + ) else False + assert len( + self.paths) > 0, f'No images were found in {self.folder} for training' convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent num_channels = 3 if not transparent else 4 self.transform = transforms.Compose([ transforms.Lambda(convert_image_fn), - transforms.Lambda(partial(resize_to_minimum_size, self.image_size)), + transforms.Lambda( + partial(resize_to_minimum_size, self.image_size)), transforms.Resize(self.image_size), RandomApply(aug_prob, transforms.RandomResizedCrop(self.image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(self.image_size)), diff --git a/dlas/data/images/zip_file_dataset.py b/dlas/data/images/zip_file_dataset.py index ef9ae50a..80a2c549 100644 --- a/dlas/data/images/zip_file_dataset.py +++ b/dlas/data/images/zip_file_dataset.py @@ -1,9 +1,10 @@ -import PIL.Image import zipfile + +import PIL.Image import torch import 
torchvision from torch.utils.data import DataLoader -from torchvision.transforms import Compose, ToTensor, Normalize, Resize +from torchvision.transforms import Compose, Normalize, Resize, ToTensor class ZipFileDataset(torch.utils.data.Dataset): @@ -14,9 +15,10 @@ class ZipFileDataset(torch.utils.data.Dataset): self.resolution = opt['resolution'] self.paired_mode = opt['paired_mode'] self.transforms = Compose([ToTensor(), - Resize(self.resolution), - Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) - ]) + Resize(self.resolution), + Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + ]) self.zip = None def __len__(self): @@ -49,10 +51,12 @@ class ZipFileDataset(torch.utils.data.Dataset): aname = fname.replace('1.jpg', '0.jpg') out['alt_hq'] = self.load_image(aname) except: - print(f"Error loading {fname} from zipfile. Attempting to recover by loading next element.") + print( + f"Error loading {fname} from zipfile. Attempting to recover by loading next element.") return self[i+1] return out + if __name__ == '__main__': opt = { 'path': 'E:\\4k6k\\datasets\\images\\youtube-imagenet-paired\\output.zip', @@ -65,4 +69,3 @@ if __name__ == '__main__': for i, d in enumerate(loader): torchvision.utils.save_image(d['hq'], f'{i}_hq.png') torchvision.utils.save_image(d['alt_hq'], f'{i}_althq.png') - diff --git a/dlas/data/text/__init__.py b/dlas/data/text/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dlas/data/text/hf_datasets_wrapper.py b/dlas/data/text/hf_datasets_wrapper.py index dfed9d7b..64efc833 100644 --- a/dlas/data/text/hf_datasets_wrapper.py +++ b/dlas/data/text/hf_datasets_wrapper.py @@ -1,18 +1,20 @@ -from torch.utils.data import Dataset import datasets +from torch.utils.data import Dataset class HfDataset(Dataset): """ Simple wrapper for a HuggingFace dataset that can re-map keys if desired. """ + def __init__(self, corpi, cache_path=None, key_maps=None, dataset_spec_key='train'): self.hfd = [] for corpus in corpi: dataset_name, config = corpus if config == '' or config == 'None': config = None - self.hfd.append(datasets.load_dataset(dataset_name, config, cache_dir=cache_path)[dataset_spec_key]) + self.hfd.append(datasets.load_dataset( + dataset_name, config, cache_dir=cache_path)[dataset_spec_key]) self.key_maps = key_maps def __getitem__(self, item): @@ -32,5 +34,6 @@ class HfDataset(Dataset): if __name__ == '__main__': - d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']], dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache') + d = HfDataset([['wikipedia', '20200501.en'], ['bookcorpus', '']], + dataset_spec_key='train', cache_path='Z:\\huggingface_datasets\\cache') print(d[5]) diff --git a/dlas/data/torch_dataset.py b/dlas/data/torch_dataset.py index 7ad90650..a413505e 100644 --- a/dlas/data/torch_dataset.py +++ b/dlas/data/torch_dataset.py @@ -1,10 +1,10 @@ -from torch.utils.data import Dataset import torchvision.transforms as T +from torch.utils.data import Dataset from torchvision import datasets # Wrapper for basic pytorch datasets which re-wraps them into a format usable by ExtensibleTrainer. 
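# Hedged sketch (illustrative, not part of this diff): the lazy-open pattern that
# ZipFileDataset above relies on. The zipfile handle is created on first access
# rather than in __init__, so each DataLoader worker process opens its own handle
# instead of inheriting a shared one. Paths, keys and the '.jpg' filter are assumptions.
import zipfile
from io import BytesIO

import PIL.Image
import torch
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

class LazyZipImageDataset(torch.utils.data.Dataset):
    def __init__(self, zip_path, resolution=224):
        self.zip_path = zip_path
        self.zip = None          # opened lazily, once per worker
        self.names = None
        self.transforms = Compose([
            ToTensor(),
            Resize(resolution),
            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def _ensure_open(self):
        if self.zip is None:
            self.zip = zipfile.ZipFile(self.zip_path)
            self.names = [n for n in self.zip.namelist() if n.endswith('.jpg')]

    def __len__(self):
        self._ensure_open()
        return len(self.names)

    def __getitem__(self, i):
        self._ensure_open()
        with self.zip.open(self.names[i]) as f:
            img = PIL.Image.open(BytesIO(f.read())).convert('RGB')
        return {'hq': self.transforms(img)}

# Usage (hypothetical archive):
# loader = DataLoader(LazyZipImageDataset('images.zip'), batch_size=8, num_workers=4)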
-from data.images.cifar import CIFAR100, CIFAR10 -from utils.util import opt_get +from dlas.data.images.cifar import CIFAR10, CIFAR100 +from dlas.utils.util import opt_get class TorchDataset(Dataset): @@ -17,7 +17,8 @@ class TorchDataset(Dataset): "imagenet": datasets.ImageNet, "imagefolder": datasets.ImageFolder } - normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[ + 0.229, 0.224, 0.225]) if opt_get(opt, ['random_crop'], False): transforms = [ T.RandomResizedCrop(opt['image_size']), @@ -34,7 +35,8 @@ class TorchDataset(Dataset): normalize, ] transforms = T.Compose(transforms) - self.dataset = DATASET_MAP[opt['dataset']](transform=transforms, **opt['kwargs']) + self.dataset = DATASET_MAP[opt['dataset']]( + transform=transforms, **opt['kwargs']) self.len = opt_get(opt, ['fixed_len'], len(self.dataset)) self.offset = opt_get(opt, ['offset'], 0) @@ -53,6 +55,7 @@ class TorchDataset(Dataset): def __len__(self): return self.len-self.offset + if __name__ == '__main__': opt = { 'flip': True, diff --git a/dlas/data/util.py b/dlas/data/util.py index a73eecb6..b31177d4 100644 --- a/dlas/data/util.py +++ b/dlas/data/util.py @@ -1,21 +1,22 @@ -import os +import glob import math +import os import pickle import random +import cv2 import numpy import numpy as np -import glob import torch import torchvision -import cv2 #################### # Files & IO #################### ###################### get image path list ###################### -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.webp', '.WEBP'] +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', + '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.webp', '.WEBP'] def torch2cv(tensor): @@ -30,7 +31,8 @@ def torch2cv(tensor): def cv2torch(cv, batchify=True): cv = cv2.cvtColor(cv, cv2.COLOR_BGR2RGB) - tens = torch.from_numpy(np.ascontiguousarray(np.transpose(cv, (2, 0, 1)))).float() + tens = torch.from_numpy(np.ascontiguousarray( + np.transpose(cv, (2, 0, 1)))).float() if batchify: tens = tens.unsqueeze(0) return tens @@ -65,7 +67,8 @@ def _get_paths_from_images(path, qualifier=is_image_file): def _get_paths_from_lmdb(dataroot): """get image path list from lmdb meta info""" - meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) + meta_info = pickle.load( + open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) paths = meta_info['keys'] sizes = meta_info['resolution'] if len(sizes) == 1: @@ -123,7 +126,8 @@ def read_img(env, path, size=None, rgb=False): # Indirect open then process to support unicode files. 
stream = open(path, "rb") bytes = bytearray(stream.read()) - img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED) + img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), + cv2.IMREAD_UNCHANGED) elif env == 'lmdb': img = _read_img_lmdb(env, path, size) elif env == 'buffer': @@ -161,7 +165,8 @@ def read_img_seq(path): # stack to Torch tensor imgs = np.stack(img_l, axis=0) imgs = imgs[:, :, :, [2, 1, 0]] - imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() + imgs = torch.from_numpy(np.ascontiguousarray( + np.transpose(imgs, (0, 3, 1, 2)))).float() return imgs @@ -478,9 +483,12 @@ def imresize(img, scale, antialiasing=True): kernel_width = weights_H.size(1) for i in range(out_H): idx = int(indices_H[i][0]) - out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) - out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) - out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, + :].transpose(0, 1).mv(weights_H[i]) + out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, + :].transpose(0, 1).mv(weights_H[i]) + out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, + :].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -501,9 +509,12 @@ def imresize(img, scale, antialiasing=True): kernel_width = weights_W.size(1) for i in range(out_W): idx = int(indices_W[i][0]) - out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) - out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) - out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[0, :, i] = out_1_aug[0, :, idx:idx + + kernel_width].mv(weights_W[i]) + out_2[1, :, i] = out_1_aug[1, :, idx:idx + + kernel_width].mv(weights_W[i]) + out_2[2, :, i] = out_1_aug[2, :, idx:idx + + kernel_width].mv(weights_W[i]) return out_2 @@ -548,9 +559,12 @@ def imresize_np(img, scale, antialiasing=True): kernel_width = weights_H.size(1) for i in range(out_H): idx = int(indices_H[i][0]) - out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) - out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) - out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 0] = img_aug[idx:idx + kernel_width, + :, 0].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 1] = img_aug[idx:idx + kernel_width, + :, 1].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 2] = img_aug[idx:idx + kernel_width, + :, 2].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -571,9 +585,12 @@ def imresize_np(img, scale, antialiasing=True): kernel_width = weights_W.size(1) for i in range(out_W): idx = int(indices_W[i][0]) - out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) - out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) - out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) + out_2[:, i, 0] = out_1_aug[:, idx:idx + + kernel_width, 0].mv(weights_W[i]) + out_2[:, i, 1] = out_1_aug[:, idx:idx + + kernel_width, 1].mv(weights_W[i]) + out_2[:, i, 2] = out_1_aug[:, idx:idx + + kernel_width, 2].mv(weights_W[i]) return out_2.numpy() @@ -587,7 +604,8 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not print(f"Building cache for contents of {paths}..") output = [] for p in paths: - 
output.extend(find_files_of_type('img', p, qualifier=is_audio_file)[0]) + output.extend(find_files_of_type( + 'img', p, qualifier=is_audio_file)[0]) if exclusion_list is not None and len(exclusion_list) > 0: print(f"Removing exclusion lists..") before = len(output) @@ -597,6 +615,7 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not print(f"Excluded {before-len(output)} files.") if endswith is not None: before = len(output) + def filter_fn(p): for e in endswith: if not p.endswith(e): @@ -606,7 +625,8 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not return False return True output = list(filter(filter_fn, output)) - print(f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files") + print( + f"!!Excluded {before-len(output)} files with endswith mask. For total of {len(output)} files") print("Done.") torch.save(output, cache_path) return output @@ -617,7 +637,8 @@ if __name__ == '__main__': # read images img = cv2.imread('test.png') img = img * 1.0 / 255 - img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + img = torch.from_numpy(np.transpose( + img[:, :, [2, 1, 0]], (2, 0, 1))).float() # imresize scale = 1 / 4 import time diff --git a/dlas/data/zero_pad_dict_collate.py b/dlas/data/zero_pad_dict_collate.py index 09b3a6fa..a7c8e241 100644 --- a/dlas/data/zero_pad_dict_collate.py +++ b/dlas/data/zero_pad_dict_collate.py @@ -7,12 +7,14 @@ class ZeroPadDictCollate(): Given a list of dictionary outputs with torch.Tensors from a Dataset, iterates through each one, finds the longest tensor, and zero pads all the other tensors together. """ + def collate_tensors(self, batch, key): result = [] largest_dims = [0 for _ in range(len(batch[0][key].shape))] for elem in batch: result.append(elem[key]) - largest_dims = [max(current_largest, new_consideration) for current_largest, new_consideration in zip(largest_dims, elem[key].shape)] + largest_dims = [max(current_largest, new_consideration) + for current_largest, new_consideration in zip(largest_dims, elem[key].shape)] # Now pad each tensor by the largest dimension. 
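# Hedged sketch (illustrative, not part of this diff): the zero-padding collation
# idea behind ZeroPadDictCollate, shown for a single key. Every tensor in the
# batch is right-padded with zeros up to the largest size seen in each dimension,
# so ragged samples (variable-length audio, text, etc.) can still be stacked.
import torch
import torch.nn.functional as F

def zero_pad_and_stack(tensors):
    largest = [max(sizes) for sizes in zip(*[t.shape for t in tensors])]
    padded = []
    for t in tensors:
        # F.pad takes (left, right) pairs starting from the *last* dimension,
        # so build the tuple back-to-front and pad only on the right side.
        pad = []
        for dim_size, target in zip(reversed(t.shape), reversed(largest)):
            pad.extend([0, target - dim_size])
        padded.append(F.pad(t, tuple(pad)))
    return torch.stack(padded, dim=0)

batch = zero_pad_and_stack([torch.ones(2, 5), torch.ones(2, 3)])
assert batch.shape == (2, 2, 5) and batch[1, :, 3:].sum() == 0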
for i in range(len(result)): padding_tuple = () @@ -24,7 +26,6 @@ class ZeroPadDictCollate(): return torch.stack(result, dim=0) - def collate_into_list(self, batch, key): result = [] for elem in batch: @@ -42,4 +43,4 @@ class ZeroPadDictCollate(): collated[key] = torch.stack([b[key] for b in batch]) else: collated[key] = self.collate_into_list(batch, key) - return collated \ No newline at end of file + return collated diff --git a/dlas/models/arch_util.py b/dlas/models/arch_util.py index 4c9caeb0..1528e394 100644 --- a/dlas/models/arch_util.py +++ b/dlas/models/arch_util.py @@ -3,13 +3,11 @@ from abc import abstractmethod import torch import torch.nn as nn -import torch.nn.init as init import torch.nn.functional as F -import torch.nn.utils.spectral_norm as SpectralNorm -from math import sqrt +import torch.nn.init as init -from utils.util import checkpoint -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.utils.util import checkpoint def exists(val): @@ -21,14 +19,14 @@ def default(val, d): def l2norm(t): - return F.normalize(t, p = 2, dim = -1) + return F.normalize(t, p=2, dim=-1) def ema_inplace(moving_avg, new, decay): - moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) -def laplace_smoothing(x, n_categories, eps = 1e-5): +def laplace_smoothing(x, n_categories, eps=1e-5): return (x + eps) / (x.sum() + n_categories * eps) @@ -36,9 +34,9 @@ def sample_vectors(samples, num): num_samples, device = samples.shape[0], samples.device if num_samples >= num: - indices = torch.randperm(num_samples, device = device)[:num] + indices = torch.randperm(num_samples, device=device)[:num] else: - indices = torch.randint(0, num_samples, (num,), device = device) + indices = torch.randint(0, num_samples, (num,), device=device) return samples[indices] @@ -239,7 +237,8 @@ class AttentionPool2d(nn.Module): b, c, *_spatial = x.shape x = x.reshape(b, c, -1) # NC(HW) x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype) # NC(HW+1) + x = x + self.positional_embedding[None, + :, :x.shape[-1]].to(x.dtype) # NC(HW+1) x = self.qkv_proj(x) x = self.attention(x) x = self.c_proj(x) @@ -296,7 +295,8 @@ class Upsample(nn.Module): if dims == 1: ksize = 5 pad = 2 - self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad) + self.conv = conv_nd(dims, self.channels, + self.out_channels, ksize, padding=pad) def forward(self, x): assert x.shape[1] == self.channels @@ -346,6 +346,7 @@ class cGLU(nn.Module): """ Gated GELU for channel-first architectures. 
""" + def __init__(self, dim_in, dim_out=None): super().__init__() dim_out = dim_in if dim_out is None else dim_out @@ -395,7 +396,8 @@ class ResBlock(nn.Module): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + conv_nd(dims, channels, self.out_channels, + kernel_size, padding=padding), ) self.updown = up or down @@ -414,7 +416,8 @@ class ResBlock(nn.Module): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) @@ -425,7 +428,8 @@ class ResBlock(nn.Module): dims, channels, self.out_channels, kernel_size, padding=padding ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 1) def forward(self, x): """ @@ -466,10 +470,10 @@ def build_local_attention_mask(n, l, fixed_region=0): A mask that can be applied to AttentionBlock to achieve local attention. """ assert l*2 < n, f'Local context must be less than global context. {l}, {n}' - o = torch.arange(0,n) - c = o.unsqueeze(-1).repeat(1,n) - r = o.unsqueeze(0).repeat(n,1) - localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1) + o = torch.arange(0, n) + c = o.unsqueeze(-1).repeat(1, n) + r = o.unsqueeze(0).repeat(n, 1) + localized = ((-(r-c).abs())+l).clamp(0, l-1) / (l-1) localized[:fixed_region] = 1 localized[:, :fixed_region] = 1 mask = localized > 0 @@ -477,7 +481,7 @@ def build_local_attention_mask(n, l, fixed_region=0): def test_local_attention_mask(): - print(build_local_attention_mask(9,4,1)) + print(build_local_attention_mask(9, 4, 1)) class RelativeQKBias(nn.Module): @@ -487,17 +491,18 @@ class RelativeQKBias(nn.Module): If symmetric=False, a different bias is applied to each side of the input element, otherwise the bias is symmetric. 
""" + def __init__(self, l, max_positions=4000, symmetric=True): super().__init__() if symmetric: self.emb = nn.Parameter(torch.randn(l+1) * .01) - o = torch.arange(0,max_positions) - c = o.unsqueeze(-1).repeat(1,max_positions) - r = o.unsqueeze(0).repeat(max_positions,1) - M = ((-(r-c).abs())+l).clamp(0,l) + o = torch.arange(0, max_positions) + c = o.unsqueeze(-1).repeat(1, max_positions) + r = o.unsqueeze(0).repeat(max_positions, 1) + M = ((-(r-c).abs())+l).clamp(0, l) else: self.emb = nn.Parameter(torch.randn(l*2+2) * .01) - a = torch.arange(0,max_positions) + a = torch.arange(0, max_positions) c = a.unsqueeze(-1) - a m = (c >= -l).logical_and(c <= l) M = (l+c+1)*m @@ -508,7 +513,7 @@ class RelativeQKBias(nn.Module): # return self.emb[self.M[:n, :n]].view(1,n,n) # However, indexing operations like this have horrible efficiency on GPUs: https://github.com/pytorch/pytorch/issues/15245 # So, enter this horrible, equivalent mess: - return torch.gather(self.emb.unsqueeze(-1).repeat(1,n), 0, self.M[:n,:n]).view(1,n,n) + return torch.gather(self.emb.unsqueeze(-1).repeat(1, n), 0, self.M[:n, :n]).view(1, n, n) class AttentionBlock(nn.Module): @@ -550,7 +555,8 @@ class AttentionBlock(nn.Module): # split heads before split qkv self.attention = QKVAttentionLegacy(self.num_heads) - self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1) + self.x_proj = nn.Identity() if out_channels == channels else conv_nd( + 1, channels, out_channels, 1) self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1)) def forward(self, x, mask=None, qk_bias=None): @@ -572,7 +578,7 @@ class AttentionBlock(nn.Module): b, c, *spatial = x.shape if mask is not None: if len(mask.shape) == 2: - mask = mask.unsqueeze(0).repeat(x.shape[0],1,1) + mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1) if mask.shape[1] != x.shape[-1]: mask = mask[:, :x.shape[-1], :x.shape[-1]] @@ -606,7 +612,8 @@ class QKVAttentionLegacy(nn.Module): bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, + length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = torch.einsum( "bct,bcs->bts", q * scale, k * scale @@ -651,7 +658,8 @@ class QKVAttention(nn.Module): mask = mask.repeat(self.n_heads, 1, 1) weight[mask.logical_not()] = -torch.inf weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + a = torch.einsum("bts,bcs->bct", weight, + v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) @@ -678,7 +686,8 @@ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) - output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) + output = F.grid_sample( + x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output @@ -690,7 +699,8 @@ class PixelUnshuffle(nn.Module): def forward(self, x): (b, f, w, h) = x.shape x = x.contiguous().view(b, f, w // self.r, self.r, h // self.r, self.r) - x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, f * (self.r ** 2), w // self.r, h // self.r) + x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view( + b, f * (self.r ** 2), w // self.r, h // self.r) return 
x @@ -704,6 +714,8 @@ def silu(input): # create a class wrapper from PyTorch nn.Module, so # the function now can be easily used in models + + class SiLU(nn.Module): ''' Applies the Sigmoid Linear Unit (SiLU) function element-wise: @@ -720,11 +732,12 @@ class SiLU(nn.Module): >>> input = torch.randn(2) >>> output = m(input) ''' + def __init__(self): ''' Init method. ''' - super().__init__() # init the base class + super().__init__() # init the base class def forward(self, input): ''' @@ -735,12 +748,15 @@ class SiLU(nn.Module): ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard kernel sizes. ''' + + class ConvBnRelu(nn.Module): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True): super(ConvBnRelu, self).__init__() padding_map = {1: 0, 3: 1, 5: 2, 7: 3} assert kernel_size in padding_map.keys() - self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, + stride, padding_map[kernel_size], bias=bias) if norm: self.bn = nn.BatchNorm2d(filters_out) else: @@ -753,7 +769,8 @@ class ConvBnRelu(nn.Module): # Init params. for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear') + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -770,12 +787,15 @@ class ConvBnRelu(nn.Module): ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard kernel sizes. ''' + + class ConvBnSilu(nn.Module): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1): super(ConvBnSilu, self).__init__() padding_map = {1: 0, 3: 1, 5: 2, 7: 3} assert kernel_size in padding_map.keys() - self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, + stride, padding_map[kernel_size], bias=bias) if norm: self.bn = nn.BatchNorm2d(filters_out) else: @@ -788,7 +808,8 @@ class ConvBnSilu(nn.Module): # Init params. for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear') + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear') m.weight.data *= weight_init_factor if m.bias is not None: m.bias.data.zero_() @@ -808,12 +829,15 @@ class ConvBnSilu(nn.Module): ''' Convenience class with Conv->BN->LeakyReLU. Includes weight initialization and auto-padding for standard kernel sizes. 
''' + + class ConvBnLelu(nn.Module): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1): super(ConvBnLelu, self).__init__() padding_map = {1: 0, 3: 1, 5: 2, 7: 3} assert kernel_size in padding_map.keys() - self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, + stride, padding_map[kernel_size], bias=bias) if norm: self.bn = nn.BatchNorm2d(filters_out) else: @@ -847,12 +871,15 @@ class ConvBnLelu(nn.Module): ''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard kernel sizes. ''' + + class ConvGnLelu(nn.Module): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1): super(ConvGnLelu, self).__init__() padding_map = {1: 0, 3: 1, 5: 2, 7: 3} assert kernel_size in padding_map.keys() - self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, + stride, padding_map[kernel_size], bias=bias) if norm: self.gn = nn.GroupNorm(num_groups, filters_out) else: @@ -886,12 +913,15 @@ class ConvGnLelu(nn.Module): ''' Convenience class with Conv->BN->SiLU. Includes weight initialization and auto-padding for standard kernel sizes. ''' + + class ConvGnSilu(nn.Module): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, num_groups=8, weight_init_factor=1, convnd=nn.Conv2d): super(ConvGnSilu, self).__init__() padding_map = {1: 0, 3: 1, 5: 2, 7: 3} assert kernel_size in padding_map.keys() - self.conv = convnd(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + self.conv = convnd(filters_in, filters_out, kernel_size, + stride, padding_map[kernel_size], bias=bias) if norm: self.gn = nn.GroupNorm(num_groups, filters_out) else: @@ -904,7 +934,8 @@ class ConvGnSilu(nn.Module): # Init params. for m in self.modules(): if isinstance(m, convnd): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear') + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' if self.silu else 'linear') m.weight.data *= weight_init_factor if m.bias is not None: m.bias.data.zero_() @@ -924,12 +955,15 @@ class ConvGnSilu(nn.Module): ''' Convenience class with Conv->BN->ReLU. Includes weight initialization and auto-padding for standard kernel sizes. ''' + + class ConvBnRelu(nn.Module): def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, activation=True, norm=True, bias=True, weight_init_factor=1): super(ConvBnRelu, self).__init__() padding_map = {1: 0, 3: 1, 5: 2, 7: 3} assert kernel_size in padding_map.keys() - self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias) + self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, + stride, padding_map[kernel_size], bias=bias) if norm: self.bn = nn.BatchNorm2d(filters_out) else: @@ -942,7 +976,8 @@ class ConvBnRelu(nn.Module): # Init params. 
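# (A note on the init loop below, which all of the ConvXxYy convenience blocks in this
# file repeat with minor variations.) kaiming_normal_ is given a gain matched to what
# follows the conv: the 'relu' gain of sqrt(2) when the block ends in an activation,
# and the 'linear' gain of 1.0 when it does not. weight_init_factor then rescales the
# freshly initialized weights; values well below 1 are used by the residual-style
# blocks further down (e.g. MultiConvBlock) to start those branches close to zero.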
for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear') + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu' if self.relu else 'linear') m.weight.data *= weight_init_factor if m.bias is not None: m.bias.data.zero_() @@ -969,7 +1004,8 @@ class MultiConvBlock(nn.Module): self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor)] + [ConvBnLelu(filters_mid, filters_mid, kernel_size, norm=norm, bias=False, weight_init_factor=weight_init_factor) for i in range(depth - 2)] + [ConvBnLelu(filters_mid, filters_out, kernel_size, activation=False, norm=False, bias=False, weight_init_factor=weight_init_factor)]) - self.scale = nn.Parameter(torch.full((1,), fill_value=scale_init, dtype=torch.float)) + self.scale = nn.Parameter(torch.full( + (1,), fill_value=scale_init, dtype=torch.float)) self.bias = nn.Parameter(torch.zeros(1)) def forward(self, x, noise=None): @@ -988,10 +1024,14 @@ class ExpansionBlock(nn.Module): super(ExpansionBlock, self).__init__() if filters_out is None: filters_out = filters_in // 2 - self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True) - self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True) - self.conjoin = block(filters_out*2, filters_out, kernel_size=3, bias=False, activation=True, norm=False) - self.process = block(filters_out, filters_out, kernel_size=3, bias=False, activation=True, norm=True) + self.decimate = block( + filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True) + self.process_passthrough = block( + filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True) + self.conjoin = block(filters_out*2, filters_out, + kernel_size=3, bias=False, activation=True, norm=False) + self.process = block(filters_out, filters_out, + kernel_size=3, bias=False, activation=True, norm=True) # input is the feature signal with shape (b, f, w, h) # passthrough is the structure signal with shape (b, f/2, w*2, h*2) @@ -1012,10 +1052,14 @@ class ExpansionBlock2(nn.Module): super(ExpansionBlock2, self).__init__() if filters_out is None: filters_out = filters_in // 2 - self.decimate = block(filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True) - self.process_passthrough = block(filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True) - self.conjoin = block(filters_out*2, filters_out*2, kernel_size=3, bias=False, activation=True, norm=False) - self.reduce = block(filters_out*2, filters_out, kernel_size=3, bias=False, activation=True, norm=True) + self.decimate = block( + filters_in, filters_out, kernel_size=1, bias=False, activation=False, norm=True) + self.process_passthrough = block( + filters_out, filters_out, kernel_size=3, bias=True, activation=False, norm=True) + self.conjoin = block(filters_out*2, filters_out*2, + kernel_size=3, bias=False, activation=True, norm=False) + self.reduce = block(filters_out*2, filters_out, + kernel_size=3, bias=False, activation=True, norm=True) # input is the feature signal with shape (b, f, w, h) # passthrough is the structure signal with shape (b, f/2, w*2, h*2) @@ -1036,8 +1080,10 @@ class ConjoinBlock(nn.Module): filters_out = filters_in if filters_pt is None: filters_pt = filters_in - self.process = block(filters_in + filters_pt, filters_in + 
filters_pt, kernel_size=3, bias=False, activation=True, norm=norm) - self.decimate = block(filters_in + filters_pt, filters_out, kernel_size=1, bias=False, activation=False, norm=norm) + self.process = block(filters_in + filters_pt, filters_in + filters_pt, + kernel_size=3, bias=False, activation=True, norm=norm) + self.decimate = block(filters_in + filters_pt, filters_out, + kernel_size=1, bias=False, activation=False, norm=norm) def forward(self, input, passthrough): x = torch.cat([input, passthrough], dim=1) @@ -1053,7 +1099,8 @@ class ReferenceJoinBlock(nn.Module): scale_init=residual_weight_init_factor, norm=False, weight_init_factor=residual_weight_init_factor) if join: - self.join_conv = block(nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True) + self.join_conv = block( + nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True) else: self.join_conv = None @@ -1070,7 +1117,8 @@ class ReferenceJoinBlock(nn.Module): class UpconvBlock(nn.Module): def __init__(self, filters_in, filters_out=None, block=ConvGnSilu, norm=True, activation=True, bias=False): super(UpconvBlock, self).__init__() - self.process = block(filters_in, filters_out, kernel_size=3, bias=bias, activation=activation, norm=norm) + self.process = block(filters_in, filters_out, kernel_size=3, + bias=bias, activation=activation, norm=norm) def forward(self, x): x = F.interpolate(x, scale_factor=2, mode="nearest") @@ -1083,21 +1131,29 @@ class FinalUpsampleBlock2x(nn.Module): super(FinalUpsampleBlock2x, self).__init__() if scale == 2: self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True), - UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True), - block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True), + UpconvBlock( + nf, nf // 2, block=block, norm=False, activation=True, bias=True), + block(nf // 2, nf // 2, kernel_size=3, + norm=False, activation=False, bias=True), block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False)) else: self.chain = nn.Sequential(block(nf, nf, kernel_size=3, norm=False, activation=True, bias=True), - UpconvBlock(nf, nf, block=block, norm=False, activation=True, bias=True), - block(nf, nf, kernel_size=3, norm=False, activation=False, bias=True), - UpconvBlock(nf, nf // 2, block=block, norm=False, activation=True, bias=True), - block(nf // 2, nf // 2, kernel_size=3, norm=False, activation=False, bias=True), + UpconvBlock( + nf, nf, block=block, norm=False, activation=True, bias=True), + block(nf, nf, kernel_size=3, norm=False, + activation=False, bias=True), + UpconvBlock( + nf, nf // 2, block=block, norm=False, activation=True, bias=True), + block(nf // 2, nf // 2, kernel_size=3, + norm=False, activation=False, bias=True), block(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False)) def forward(self, x): return self.chain(x) # torch.gather() which operates as it always fucking should have: pulling indexes from the input. 
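# The motivation is the same as in RelativeQKBias above: advanced indexing such as
# input[b, c, y, x] is slow on GPU, whereas torch.gather over a flattened spatial dim
# is cheap. A minimal illustrative sketch of the trick (the index layout assumed here,
# one (y, x) pair per batch element, is an assumption and may differ from gather_2d):
import torch

def gather_2d_sketch(inp, index):
    # inp: (b, c, h, w); index: (b, 2) long tensor holding (y, x) per batch item.
    b, c, h, w = inp.shape
    flat = inp.reshape(b, c, h * w)                                    # (b, c, h*w)
    flat_idx = (index[:, 0] * w + index[:, 1]).view(b, 1, 1).expand(b, c, 1)
    return torch.gather(flat, 2, flat_idx).squeeze(-1)                 # == inp[b, c, y, x]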
+ + def gather_2d(input, index): b, c, h, w = input.shape nodim = input.view(b, c, h * w) diff --git a/dlas/models/audio/asr/w2v_wrapper.py b/dlas/models/audio/asr/w2v_wrapper.py index f25c562e..398fc444 100644 --- a/dlas/models/audio/asr/w2v_wrapper.py +++ b/dlas/models/audio/asr/w2v_wrapper.py @@ -3,13 +3,14 @@ from itertools import groupby import torch import torch.nn as nn from transformers import Wav2Vec2ForCTC -from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention, Wav2Vec2Model +from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Attention, + Wav2Vec2Model) -from data.audio.unsupervised_audio_dataset import load_audio -from models.audio.tts.tacotron2.text import sequence_to_text -from trainer.networks import register_model -from utils.util import opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.data.audio.unsupervised_audio_dataset import load_audio +from dlas.models.audio.tts.tacotron2.text import sequence_to_text +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def only_letters(string): @@ -22,6 +23,7 @@ class Wav2VecFeatureExtractor(nn.Module): Basic wrapper that only does feature extraction. Useful to build out this portion of the model so it can be operated through DDP. """ + def __init__(self, basis_model='facebook/wav2vec2-large'): super().__init__() w2v = Wav2Vec2ForCTC.from_pretrained(basis_model) @@ -34,7 +36,8 @@ class Wav2VecFeatureExtractor(nn.Module): def forward(self, audio, wav_lengths): with torch.no_grad(): audio = audio[:, :, :wav_lengths.max()] - audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) + audio_norm = (audio - audio.mean()) / \ + torch.sqrt(audio.var() + 1e-7) return self.extractor(audio_norm.squeeze(1)) @@ -42,13 +45,14 @@ class Wav2VecWrapper(nn.Module): """ Basic wrapper class that makes Wav2Vec2 usable by DLAS. """ + def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True, remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000, ramp_dropout_min=.1, ramp_dropout_max=.5, layer_drop_pct=.1): super().__init__() self.provide_attention_mask = provide_attention_mask - + self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model) # Perform some surgery to get the model we actually want. 
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled @@ -99,7 +103,8 @@ class Wav2VecWrapper(nn.Module): unaligned_tokens[b, text_lengths[b]:] = -100 model_inp = fea_extractor if self.remove_feature_extractor else audio - outputs = self.w2v(input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens) + outputs = self.w2v( + input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens) if self.output_wer: self.last_pred.append(torch.argmax(outputs.logits, dim=-1)) @@ -126,11 +131,15 @@ class Wav2VecWrapper(nn.Module): pred_strings = [] for last_labels, last_pred in zip(self.last_labels, self.last_pred): last_labels[last_labels == -100] = 0 - label_strings.extend([only_letters(sequence_to_text(lbl)) for lbl in last_labels]) - pred_strings.extend([only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in last_pred]) - wer = wer_metric.compute(predictions=pred_strings, references=label_strings) + label_strings.extend( + [only_letters(sequence_to_text(lbl)) for lbl in last_labels]) + pred_strings.extend([only_letters(sequence_to_text( + self.decode_ctc(pred))) for pred in last_pred]) + wer = wer_metric.compute( + predictions=pred_strings, references=label_strings) res['wer'] = wer - print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}") + print( + f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}") if self.ramp_dropout_mode: res['dropout_rate'] = self.current_dropout_rate return res @@ -149,7 +158,8 @@ class Wav2VecWrapper(nn.Module): def update_for_step(self, step, *args): if self.ramp_dropout_mode and step % 10 == 0: dropout_gap = self.ramp_dropout_max - self.ramp_dropout_min - new_dropout_rate = self.ramp_dropout_min + dropout_gap * min(step / self.ramp_dropout_end, 1) + new_dropout_rate = self.ramp_dropout_min + \ + dropout_gap * min(step / self.ramp_dropout_end, 1) self.current_dropout_rate = new_dropout_rate for name, module in self.w2v.named_modules(): if isinstance(module, nn.Dropout): @@ -187,14 +197,18 @@ def register_wav2vec2(opt_net, opt): if __name__ == '__main__': fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h') - w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True) + w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', + freeze_transformer=True, remove_feature_extractor=True, ramp_dropout_mode=True) w2v.update_for_step(8000) - fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000])) - loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea) - w2v.get_debug_values(0,"") + fea = fe(torch.randn(2, 1, 50000), torch.tensor([20000, 30000])) + loss = w2v(torch.randn(2, 1, 50000), torch.randint(0, 40, (2, 70)), + torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea) + w2v.get_debug_values(0, "") - sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth') + sd = torch.load( + '../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth') w2v.load_state_dict(sd) - pred = w2v.inference(load_audio('Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000).unsqueeze(0)) + pred = w2v.inference(load_audio( + 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000).unsqueeze(0)) res = sequence_to_text(pred[0]) print(res) diff --git a/dlas/models/audio/audio_resnet.py 
b/dlas/models/audio/audio_resnet.py index 59e8b2fa..f6591535 100644 --- a/dlas/models/audio/audio_resnet.py +++ b/dlas/models/audio/audio_resnet.py @@ -1,12 +1,12 @@ +from typing import Any, Callable, List, Optional, Type, Union + import torch -from torch import Tensor import torch.nn as nn +from torch import Tensor -from trainer.networks import register_model -from utils.util import opt_get -from typing import Type, Any, Callable, Union, List, Optional -import torch_intermediary as ml - +import dlas.torch_intermediary as ml +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', @@ -42,9 +42,11 @@ class BasicBlock(nn.Module): if norm_layer is None: norm_layer = nn.BatchNorm1d if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) @@ -177,7 +179,8 @@ class ResNet(nn.Module): for m in self.modules(): if isinstance(m, nn.Conv1d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -188,9 +191,11 @@ class ResNet(nn.Module): if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + # type: ignore[arg-type] + nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + # type: ignore[arg-type] + nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False) -> nn.Sequential: @@ -386,5 +391,5 @@ def register_audio_resnet(opt_net, opt): if __name__ == '__main__': m = resnet34() - o = m(torch.randn((1,1,48000))) - print(o.shape) \ No newline at end of file + o = m(torch.randn((1, 1, 48000))) + print(o.shape) diff --git a/dlas/models/audio/mel2vec.py b/dlas/models/audio/mel2vec.py index 25a00a81..f35d4138 100644 --- a/dlas/models/audio/mel2vec.py +++ b/dlas/models/audio/mel2vec.py @@ -9,13 +9,14 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import distributed -from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices from transformers.deepspeed import is_deepspeed_zero3_enabled +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + _compute_mask_indices, _sample_negative_indices) -from models.arch_util import ResBlock -from trainer.networks import register_model -from utils.util import checkpoint -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.arch_util import ResBlock +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint class Mel2Vec2FeatureProjection(nn.Module): @@ -243,6 +244,8 @@ class Wav2Vec2SamePadLayer(nn.Module): from torch.nn.utils.weight_norm import WeightNorm + 
+ def __deepcopy__(self, memo): # save and delete all weightnorm weights on self weights = {} diff --git a/dlas/models/audio/music/cheater_gen_ar.py b/dlas/models/audio/music/cheater_gen_ar.py index 00222217..61e37317 100644 --- a/dlas/models/audio/music/cheater_gen_ar.py +++ b/dlas/models/audio/music/cheater_gen_ar.py @@ -2,13 +2,13 @@ import torch import torch.nn.functional as F from torch import nn from transformers import GPT2Config, GPT2Model -import torch_intermediary as ml -from models.arch_util import AttentionBlock, ResBlock -from models.audio.tts.lucidrains_dvae import DiscreteVAE -from models.lucidrains.x_transformers import Encoder -from trainer.networks import register_model -from utils.util import opt_get, ceil_multiple, print_network +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock, ResBlock +from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE +from dlas.models.lucidrains.x_transformers import Encoder +from dlas.trainer.networks import register_model +from dlas.utils.util import ceil_multiple, opt_get, print_network class ConditioningEncoder(nn.Module): @@ -22,23 +22,23 @@ class ConditioningEncoder(nn.Module): super().__init__() self.init = nn.Conv1d(cond_dim, embedding_dim, kernel_size=1) self.attn = Encoder( - dim=embedding_dim, - depth=attn_blocks, - heads=num_attn_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=2, - do_checkpointing=do_checkpointing - ) + dim=embedding_dim, + depth=attn_blocks, + heads=num_attn_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + zero_init_branch_output=True, + ff_mult=2, + do_checkpointing=do_checkpointing + ) self.dim = embedding_dim def forward(self, x): - h = self.init(x).permute(0,2,1) - h = self.attn(h).permute(0,2,1) + h = self.init(x).permute(0, 2, 1) + h = self.attn(h).permute(0, 2, 1) return h.mean(-1) @@ -46,7 +46,7 @@ class ConditioningAR(nn.Module): def __init__(self, dim, layers, dropout=0, num_vectors=8192, cond_free_percent=.15, fp16=False): super().__init__() self.cond_encoder = ConditioningEncoder(256, dim) - self.cond_free_emb = nn.Parameter(torch.randn(1,dim)) + self.cond_free_emb = nn.Parameter(torch.randn(1, dim)) self.unconditioned_percentage = cond_free_percent self.fp16 = fp16 @@ -65,13 +65,16 @@ class ConditioningAR(nn.Module): cond = self.cond_encoder(conditioning) if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((cond.shape[0],1), device=cond.device) < self.unconditioned_percentage - cond = torch.where(unconditioned_batches, self.cond_free_emb.repeat(cond.shape[0],1), cond) + unconditioned_batches = torch.rand( + (cond.shape[0], 1), device=cond.device) < self.unconditioned_percentage + cond = torch.where(unconditioned_batches, + self.cond_free_emb.repeat(cond.shape[0], 1), cond) unused_params.append(self.cond_free_emb) h = self.embeddings(cheater_codes) h = torch.cat([cond.unsqueeze(1), h], dim=1) - targets = cheater_codes # Since we padded above by 1, the input alignment works. + # Since we padded above by 1, the input alignment works. 
+ targets = cheater_codes with torch.autocast(cheater_codes.device.type, enabled=self.fp16): h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state @@ -79,12 +82,13 @@ class ConditioningAR(nn.Module): if return_latent: return h.float() - logits = self.head(h[:,:-1]).permute(0,2,1) + logits = self.head(h[:, :-1]).permute(0, 2, 1) loss = F.cross_entropy(logits, targets, reduction="none") # Perform masking if code_lengths is not None: - mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze(0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1) + mask = torch.arange(0, loss.shape[1], device=h.device).unsqueeze( + 0).repeat(loss.shape[0], 1) < code_lengths.unsqueeze(1) loss = loss * mask loss = loss.mean() @@ -114,14 +118,13 @@ def test_ar(): model = ConditioningAR(512, 8, cond_free_percent=.5) print_network(model) - codes = torch.randint(0,8192, (2,400)) - cond = torch.randn(2,256,400) - cl = torch.tensor([200,10]) - codes[1,10:] = 2 + codes = torch.randint(0, 8192, (2, 400)) + cond = torch.randn(2, 256, 400) + cl = torch.tensor([200, 10]) + codes[1, 10:] = 2 model(codes, cond, cl) pg = model.get_grad_norm_parameter_groups() - if __name__ == '__main__': test_ar() diff --git a/dlas/models/audio/music/diffwave.py b/dlas/models/audio/music/diffwave.py index 71cf765c..1fa9e8a4 100644 --- a/dlas/models/audio/music/diffwave.py +++ b/dlas/models/audio/music/diffwave.py @@ -13,17 +13,16 @@ # limitations under the License. # ============================================================================== +from math import sqrt + import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml - -from math import sqrt - from torch.utils.checkpoint import checkpoint -from trainer.networks import register_model +import dlas.torch_intermediary as ml +from dlas.trainer.networks import register_model Linear = ml.Linear ConvTranspose2d = nn.ConvTranspose2d @@ -43,7 +42,8 @@ def silu(x): class DiffusionEmbedding(nn.Module): def __init__(self, max_steps): super().__init__() - self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False) + self.register_buffer('embedding', self._build_embedding( + max_steps), persistent=False) self.projection1 = Linear(128, 512) self.projection2 = Linear(512, 512) @@ -76,8 +76,10 @@ class DiffusionEmbedding(nn.Module): class SpectrogramUpsampler(nn.Module): def __init__(self, n_mels): super().__init__() - self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8]) - self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8]) + self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[ + 1, 16], padding=[1, 8]) + self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[ + 1, 16], padding=[1, 8]) def forward(self, x): x = torch.unsqueeze(x, 1) @@ -98,27 +100,32 @@ class ResidualBlock(nn.Module): :param uncond: disable spectrogram conditional ''' super().__init__() - self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) + self.dilated_conv = Conv1d( + residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) self.diffusion_projection = Linear(512, residual_channels) if not uncond: # conditional model - self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1) + self.conditioner_projection = Conv1d( + n_mels, 2 * residual_channels, 1) else: # unconditional model self.conditioner_projection = None - self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) 
+ self.output_projection = Conv1d( + residual_channels, 2 * residual_channels, 1) def forward(self, x, diffusion_step, conditioner=None): assert (conditioner is None and self.conditioner_projection is None) or \ (conditioner is not None and self.conditioner_projection is not None) - diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + diffusion_step = self.diffusion_projection( + diffusion_step).unsqueeze(-1) y = x + diffusion_step if self.conditioner_projection is None: # using a unconditional model y = self.dilated_conv(y) else: y = self.dilated_conv(y) conditioner = self.conditioner_projection(conditioner) - conditioner = F.interpolate(conditioner, size=y.shape[-1], mode='nearest') + conditioner = F.interpolate( + conditioner, size=y.shape[-1], mode='nearest') y = y + conditioner gate, filter = torch.chunk(y, 2, dim=1) @@ -141,7 +148,8 @@ class DiffWave(nn.Module): self.spectrogram_upsampler = SpectrogramUpsampler(n_mels) self.residual_layers = nn.ModuleList([ - ResidualBlock(n_mels, residual_channels, 2 ** (i % dilation_cycle_length), uncond=unconditional) + ResidualBlock(n_mels, residual_channels, 2 ** (i % + dilation_cycle_length), uncond=unconditional) for i in range(residual_layers) ]) self.skip_projection = Conv1d(residual_channels, residual_channels, 1) @@ -177,4 +185,5 @@ def register_diffwave(opt_net, opt): if __name__ == '__main__': model = DiffWave() - model(torch.randn(2,1,65536), torch.tensor([500,3999]), torch.randn(2,128,256)) \ No newline at end of file + model(torch.randn(2, 1, 65536), torch.tensor( + [500, 3999]), torch.randn(2, 128, 256)) diff --git a/dlas/models/audio/music/encoders.py b/dlas/models/audio/music/encoders.py index c2fd3a66..3654e909 100644 --- a/dlas/models/audio/music/encoders.py +++ b/dlas/models/audio/music/encoders.py @@ -3,10 +3,10 @@ import torch.nn.functional as F from torch import nn from transformers import GPT2Config, GPT2Model -from models.arch_util import AttentionBlock, ResBlock -from models.audio.tts.lucidrains_dvae import DiscreteVAE -from trainer.networks import register_model -from utils.util import opt_get, ceil_multiple, print_network +from dlas.models.arch_util import AttentionBlock, ResBlock +from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE +from dlas.trainer.networks import register_model +from dlas.utils.util import ceil_multiple, opt_get, print_network class ResEncoder16x(nn.Module): @@ -18,20 +18,29 @@ class ResEncoder16x(nn.Module): ): super().__init__() attn = [] + def edim(m): dd = min(spec_dim + m * 128, hidden_dim) return ceil_multiple(dd, 8) self.downsampler = nn.Sequential( - ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled), - ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled), - ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), - ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled), - ResBlock(edim(4), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), + ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1, + down=True, checkpointing_enabled=checkpointing_enabled), + ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, + down=True, checkpointing_enabled=checkpointing_enabled), + ResBlock(edim(3), out_channels=edim(3), use_conv=True, + dims=1, 
checkpointing_enabled=checkpointing_enabled), + ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, + down=True, checkpointing_enabled=checkpointing_enabled), + ResBlock(edim(4), out_channels=edim(4), use_conv=True, + dims=1, checkpointing_enabled=checkpointing_enabled), ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled)) self.encoder = nn.Sequential( - ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), - ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), - ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, + dims=1, checkpointing_enabled=checkpointing_enabled), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, + dims=1, checkpointing_enabled=checkpointing_enabled), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, + dims=1, checkpointing_enabled=checkpointing_enabled), nn.GroupNorm(8, hidden_dim), nn.SiLU(), nn.Conv1d(hidden_dim, embedding_dim, 1), diff --git a/dlas/models/audio/music/flat_diffusion.py b/dlas/models/audio/music/flat_diffusion.py index 9b8d897b..5f91d386 100644 --- a/dlas/models/audio/music/flat_diffusion.py +++ b/dlas/models/audio/music/flat_diffusion.py @@ -4,18 +4,22 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast -import torch_intermediary as ml -from models.arch_util import ResBlock -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock -from trainer.networks import register_model -from utils.util import checkpoint +import dlas.torch_intermediary as ml +from dlas.models.arch_util import ResBlock +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (AttentionBlock, + TimestepBlock, + TimestepEmbedSequential) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint def is_latent(t): return t.dtype == torch.float + def is_sequence(t): return t.dtype == torch.long @@ -24,7 +28,8 @@ class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() # nn.Embedding - self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)]) + self.m = nn.ModuleList( + [ml.Embedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -56,7 +61,8 @@ class TimestepResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding), + conv_nd(dims, channels, self.out_channels, + eff_kernel, padding=eff_padding), ) self.emb_layers = nn.Sequential( @@ -71,14 +77,16 @@ class TimestepResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, 
padding=eff_padding) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, eff_kernel, padding=eff_padding) def forward(self, x, emb): """ @@ -111,8 +119,10 @@ class TimestepResBlock(TimestepBlock): class DiffusionLayer(TimestepBlock): def __init__(self, model_channels, dropout, num_heads): super().__init__() - self.resblk = TimestepResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True) - self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) + self.resblk = TimestepResBlock( + model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True) + self.attn = AttentionBlock( + model_channels, num_heads, relative_pos_embeddings=True) def forward(self, x, time_emb): y = self.resblk(x, time_emb) @@ -134,7 +144,8 @@ class FlatDiffusion(nn.Module): num_heads=8, # Parameters for regularization. layer_drop=.1, - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, train_mel_head=False, ): super().__init__() @@ -163,37 +174,52 @@ class FlatDiffusion(nn.Module): # nn.Embedding self.embeddings = ml.Embedding(token_count, model_channels) else: - self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) + self.embeddings = MultiGroupEmbedding( + token_count, in_groups, model_channels) self.latent_conditioner = nn.Sequential( nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), ) self.code_converter = nn.Sequential( ResBlock(dims=1, channels=model_channels, dropout=dropout), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), ResBlock(dims=1, channels=model_channels, dropout=dropout), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), ResBlock(dims=1, channels=model_channels, dropout=dropout), ) self.code_norm = normalization(model_channels) - self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2), - nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2), + nn.Conv1d( + model_channels, model_channels*2, 3, padding=1, stride=2), + AttentionBlock( + 
model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock( + model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock( + model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock( + model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False)) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1)) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, model_channels, 1)) self.conditioning_timestep_integrator = TimestepEmbedSequential( DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads), ) - self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1) - self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) + self.integrating_conv = nn.Conv1d( + model_channels*2, model_channels, kernel_size=1) + self.mel_head = nn.Conv1d( + model_channels, in_channels, kernel_size=3, padding=1) self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + [TimestepResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)]) @@ -201,7 +227,8 @@ class FlatDiffusion(nn.Module): self.out = nn.Sequential( normalization(model_channels), nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(1, model_channels, + out_channels, 3, padding=1)), ) if train_mel_head: @@ -231,7 +258,8 @@ class FlatDiffusion(nn.Module): conditioning_input.shape) == 3 else conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) + conds.append(self.contextual_embedder( + speech_conditioning_input[:, j])) conds = torch.cat(conds, dim=-1) cond_emb = conds.mean(dim=-1) cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1) @@ -239,18 +267,21 @@ class FlatDiffusion(nn.Module): code_emb = self.latent_conditioner(aligned_conditioning) else: code_emb = self.embeddings(aligned_conditioning) - code_emb = code_emb.permute(0,2,1) + code_emb = code_emb.permute(0, 2, 1) - unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) + unconditioned_batches = torch.zeros( + (code_emb.shape[0], 1, 1), device=code_emb.device) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. 
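# i.e. with probability `unconditioned_percentage`, an entire batch element has its
# conditioning replaced by a single learned "null" embedding, so the network also
# learns an unconditional distribution. The same trick in isolation (the names here
# are illustrative, not this module's attributes):
#
#   drop = torch.rand(cond.shape[0], 1, 1, device=cond.device) < p      # per-sample coin flip
#   cond = torch.where(drop, null_embedding.expand_as(cond), cond)      # swap in the null token
#
# At sampling time such models are usually evaluated twice, with and without
# conditioning, and the two outputs blended (uncond + scale * (cond - uncond));
# that inference-time combination is not part of this diff.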
if self.training and self.unconditioned_percentage > 0: unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), code_emb) - expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest') + expanded_code_emb = F.interpolate( + code_emb, size=expected_seq_len, mode='nearest') expanded_code_emb = self.code_converter(expanded_code_emb) - expanded_code_emb = self.code_norm(expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) + expanded_code_emb = self.code_norm( + expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) if not return_code_pred: return expanded_code_emb @@ -260,7 +291,6 @@ class FlatDiffusion(nn.Module): mel_pred = mel_pred * unconditioned_batches.logical_not() return expanded_code_emb, mel_pred - def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False): """ Apply the model to an input batch. @@ -273,27 +303,36 @@ class FlatDiffusion(nn.Module): :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. :return: an [N x C x ...] Tensor of outputs. """ - assert precomputed_aligned_embeddings is not None or (codes is not None and conditioning_input is not None) - assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive. + assert precomputed_aligned_embeddings is not None or ( + codes is not None and conditioning_input is not None) + # These two are mutually exclusive. + assert not ( + return_code_pred and precomputed_aligned_embeddings is not None) unused_params = [] if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + code_emb = self.unconditioned_embedding.repeat( + x.shape[0], 1, x.shape[-1]) + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) unused_params.extend(list(self.latent_conditioner.parameters())) else: if precomputed_aligned_embeddings is not None: code_emb = precomputed_aligned_embeddings else: - code_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], True) + code_emb, mel_pred = self.timestep_independent( + codes, conditioning_input, x.shape[-1], True) if is_latent(codes): - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) else: - unused_params.extend(list(self.latent_conditioner.parameters())) + unused_params.extend( + list(self.latent_conditioner.parameters())) unused_params.append(self.unconditioned_embedding) - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + time_emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels)) code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) x = self.inp_block(x) x = torch.cat([x, code_emb], dim=1) @@ -325,10 +364,12 @@ class FlatDiffusion(nn.Module): conditioning_input.shape) == 3 else conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - 
conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) + conds.append(self.contextual_embedder( + speech_conditioning_input[:, j])) conds = torch.cat(conds, dim=-1) return conds.mean(dim=-1) + @register_model def register_flat_diffusion(opt_net, opt): return FlatDiffusion(**opt_net['kwargs']) @@ -336,13 +377,13 @@ def register_flat_diffusion(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 256, 400) - aligned_latent = torch.randn(2,388,512) - aligned_sequence = torch.randint(0,8,(2,100,8)) + aligned_latent = torch.randn(2, 388, 512) + aligned_sequence = torch.randint(0, 8, (2, 100, 8)) cond = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) - model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True) + model = FlatDiffusion( + 512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True) # Test with latent aligned conditioning - #o = model(clip, ts, aligned_latent, cond) + # o = model(clip, ts, aligned_latent, cond) # Test with sequence aligned conditioning o = model(clip, ts, aligned_sequence, cond, return_code_pred=True) - diff --git a/dlas/models/audio/music/gpt_music.py b/dlas/models/audio/music/gpt_music.py index 03322df9..1603f48a 100644 --- a/dlas/models/audio/music/gpt_music.py +++ b/dlas/models/audio/music/gpt_music.py @@ -1,17 +1,17 @@ import torch -from torch import nn import torch.nn.functional as F +from torch import nn from transformers import GPT2Config, GPT2Model -import torch_intermediary as ml -from models.arch_util import AttentionBlock, ResBlock -from models.audio.music.music_quantizer import MusicQuantizer -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.audio.tts.lucidrains_dvae import DiscreteVAE -from models.lucidrains.x_transformers import Encoder -from models.vqvae.vqvae import Quantize -from trainer.networks import register_model -from utils.util import opt_get, checkpoint, ceil_multiple, print_network +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock, ResBlock +from dlas.models.audio.music.music_quantizer import MusicQuantizer +from dlas.models.audio.music.music_quantizer2 import MusicQuantizer2 +from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE +from dlas.models.lucidrains.x_transformers import Encoder +from dlas.models.vqvae.vqvae import Quantize +from dlas.trainer.networks import register_model +from dlas.utils.util import ceil_multiple, checkpoint, opt_get, print_network class ConditioningEncoder(nn.Module): @@ -22,9 +22,11 @@ class ConditioningEncoder(nn.Module): num_attn_heads=4): super().__init__() attn = [] - self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=3, stride=2, padding=1) + self.init = nn.Conv1d(spec_dim, embedding_dim, + kernel_size=3, stride=2, padding=1) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_activation=True)) + attn.append(AttentionBlock(embedding_dim, + num_attn_heads, do_activation=True)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim @@ -43,14 +45,20 @@ class UpperConditioningEncoder(nn.Module): super().__init__() attn = [] self.init = nn.Sequential(nn.Conv1d(spec_dim, min(spec_dim+128, embedding_dim), kernel_size=3, stride=2, padding=1), - nn.Conv1d(min(spec_dim+128, embedding_dim), min(spec_dim+256, embedding_dim), kernel_size=3, stride=2, padding=1), - nn.Conv1d(min(spec_dim+256, embedding_dim), min(spec_dim+384, embedding_dim), kernel_size=3, stride=2, padding=1), - 
nn.Conv1d(min(spec_dim+384, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1), - ResBlock(min(spec_dim+512, embedding_dim), dims=1), - nn.Conv1d(min(spec_dim+512, embedding_dim), min(spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1), + nn.Conv1d(min(spec_dim+128, embedding_dim), min( + spec_dim+256, embedding_dim), kernel_size=3, stride=2, padding=1), + nn.Conv1d(min(spec_dim+256, embedding_dim), min( + spec_dim+384, embedding_dim), kernel_size=3, stride=2, padding=1), + nn.Conv1d(min(spec_dim+384, embedding_dim), min( + spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1), + ResBlock( + min(spec_dim+512, embedding_dim), dims=1), + nn.Conv1d(min(spec_dim+512, embedding_dim), min( + spec_dim+512, embedding_dim), kernel_size=3, stride=2, padding=1), ResBlock(min(spec_dim+512, embedding_dim), dims=1)) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_activation=True)) + attn.append(AttentionBlock(embedding_dim, + num_attn_heads, do_activation=True)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim @@ -67,21 +75,31 @@ class UpperQuantizer(nn.Module): num_tokens): super().__init__() attn = [] + def edim(m): dd = max(embedding_dim//m, 128, spec_dim) return ceil_multiple(dd, 8) self.encoder = nn.Sequential( - ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True), - ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True), - ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True), - ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True), + ResBlock(spec_dim, out_channels=edim(6), + use_conv=True, dims=1, down=True), + ResBlock(edim(6), out_channels=edim(5), + use_conv=True, dims=1, down=True), + ResBlock(edim(5), out_channels=edim(4), + use_conv=True, dims=1, down=True), + ResBlock(edim(4), out_channels=edim(3), + use_conv=True, dims=1, down=True), ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1), - ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True), + ResBlock(edim(3), out_channels=edim(2), + use_conv=True, dims=1, down=True), ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1), - ResBlock(edim(2), out_channels=embedding_dim, use_conv=True, dims=1, down=True), - ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1), - ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1), - ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1), + ResBlock(edim(2), out_channels=embedding_dim, + use_conv=True, dims=1, down=True), + ResBlock(embedding_dim, out_channels=embedding_dim, + use_conv=True, dims=1), + ResBlock(embedding_dim, out_channels=embedding_dim, + use_conv=True, dims=1), + ResBlock(embedding_dim, out_channels=embedding_dim, + use_conv=True, dims=1), nn.GroupNorm(8, embedding_dim) ) self.quantizer = Quantize(embedding_dim, num_tokens) @@ -95,7 +113,7 @@ class UpperQuantizer(nn.Module): h = x for lyr in self.encoder: h = lyr(h) - h = h.permute(0,2,1) + h = h.permute(0, 2, 1) h_quant, commitment_loss, codes = self.quantizer(h) self.log_codes(codes) return h_quant, commitment_loss @@ -105,7 +123,8 @@ class UpperQuantizer(nn.Module): if self.internal_step % 10 == 0: codes = codes.flatten() l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + i = self.code_ind if ( + self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l 
self.codes[i:i+l] = codes.cpu() self.code_ind = self.code_ind + l if self.code_ind >= self.codes.shape[0]: @@ -122,7 +141,8 @@ class GptMusicLower(nn.Module): self.freeze_upper_until = freeze_upper_until self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64, n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False) - self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)]) + self.target_quantizers = nn.ModuleList( + [DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)]) self.upper_quantizer = UpperQuantizer(256, dim, num_upper_vectors) self.fp16 = fp16 self.internal_step = 0 @@ -132,14 +152,17 @@ class GptMusicLower(nn.Module): p.DO_NOT_TRAIN = True p.requires_grad = False - self.conditioning_encoder = ConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64) + self.conditioning_encoder = ConditioningEncoder( + 256, dim, attn_blocks=4, num_attn_heads=dim//64) self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) - self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) + self.embeddings = nn.ModuleList( + [ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) + self.heads = nn.ModuleList( + [ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) def forward(self, mel, conditioning, return_latent=False): unused_params = [] @@ -159,14 +182,17 @@ class GptMusicLower(nn.Module): unused_params.extend(list(self.upper_quantizer.parameters())) else: self.upper_quantizer = self.upper_quantizer.train() - upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True) - upper_vector = F.interpolate(upper_vector.permute(0,2,1), size=codes.shape[1], mode='linear') - upper_vector = upper_vector.permute(0,2,1) + upper_vector, upper_diversity = self.upper_quantizer( + mel, return_decoder_latent=True) + upper_vector = F.interpolate(upper_vector.permute( + 0, 2, 1), size=codes.shape[1], mode='linear') + upper_vector = upper_vector.permute(0, 2, 1) inputs = codes[:, :-1] targets = codes upper_vector = upper_vector[:, :-1] - h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] + h = [embedding(inputs[:, :, i]) + for i, embedding in enumerate(self.embeddings)] h = torch.cat(h, dim=-1) + upper_vector with torch.autocast(mel.device.type, enabled=self.fp16): @@ -183,8 +209,8 @@ class GptMusicLower(nn.Module): losses = 0 for i, head in enumerate(self.heads): - logits = head(h).permute(0,2,1) - loss = F.cross_entropy(logits, targets[:,:,i]) + logits = head(h).permute(0, 2, 1) + loss = F.cross_entropy(logits, targets[:, :, i]) losses = losses + loss unused_adder = 0 @@ -221,11 +247,15 @@ class GptMusicUpper(nn.Module): n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False) self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim, - max(512,dim-128), - max(512,dim-256), - max(512,dim-384), - max(512,dim-512), - max(512,dim-512)], codevector_dim=dim, + max(512, + dim-128), + max(512, + dim-256), + max(512, + dim-384), + max(512, + dim-512), + max(512, dim-512)], codevector_dim=dim, codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True) # Following are unused quantizer constructs we delete to avoid DDP 
errors (and to be efficient.. of course..) @@ -235,15 +265,17 @@ class GptMusicUpper(nn.Module): p.DO_NOT_TRAIN = True p.requires_grad = False - self.conditioning_encoder = UpperConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64) + self.conditioning_encoder = UpperConditioningEncoder( + 256, dim, attn_blocks=4, num_attn_heads=dim//64) self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = nn.ModuleList([ml.Embedding(num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)]) - self.heads = nn.ModuleList([ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)]) - + self.embeddings = nn.ModuleList([ml.Embedding( + num_upper_vectors, dim // num_upper_groups) for _ in range(num_upper_groups)]) + self.heads = nn.ModuleList( + [ml.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)]) def forward(self, mel, conditioning, return_latent=False): with torch.no_grad(): @@ -252,7 +284,8 @@ class GptMusicUpper(nn.Module): inputs = codes[:, :-1] targets = codes - h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] + h = [embedding(inputs[:, :, i]) + for i, embedding in enumerate(self.embeddings)] h = torch.cat(h, dim=-1) with torch.autocast(mel.device.type, enabled=self.fp16): @@ -269,8 +302,8 @@ class GptMusicUpper(nn.Module): losses = 0 for i, head in enumerate(self.heads): - logits = head(h).permute(0,2,1) - loss = F.cross_entropy(logits, targets[:,:,i]) + logits = head(h).permute(0, 2, 1) + loss = F.cross_entropy(logits, targets[:, :, i]) losses = losses + loss return losses / self.num_groups @@ -293,6 +326,7 @@ class GptMusicUpper(nn.Module): def register_music_gpt_lower(opt_net, opt): return GptMusicLower(**opt_get(opt_net, ['kwargs'], {})) + @register_model def register_music_gpt_upper(opt_net, opt): return GptMusicUpper(**opt_get(opt_net, ['kwargs'], {})) @@ -301,11 +335,11 @@ def register_music_gpt_upper(opt_net, opt): def test_lower(): model = GptMusicLower(dim=512, layers=12, fp16=False, freeze_upper_until=1000, num_target_vectors=8192, num_upper_vectors=8192, num_vaes=4, - vqargs= { - 'positional_dims': 1, 'channels': 64, - 'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192, - 'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False, - }) + vqargs={ + 'positional_dims': 1, 'channels': 64, + 'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192, + 'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False, + }) quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth', 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth', 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth', @@ -316,7 +350,7 @@ def test_lower(): torch.save(model.state_dict(), 'sample.pth') print_network(model) - mel = torch.randn(2,256,400) + mel = torch.randn(2, 256, 400) model(mel, mel) pg = model.get_grad_norm_parameter_groups() @@ -335,14 +369,15 @@ def test_lower(): def test_upper(): lower = GptMusicLower(512, 12) - lower.load_state_dict(torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')) + lower.load_state_dict(torch.load( + 'D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')) model = GptMusicUpper(512, 12) 
model.upper_quantizer.load_state_dict(lower.upper_quantizer.state_dict()) torch.save(model.state_dict(), 'sample.pth') - mel = torch.randn(2,256,2500) + mel = torch.randn(2, 256, 2500) model(mel, mel) model.get_grad_norm_parameter_groups() if __name__ == '__main__': - test_lower() \ No newline at end of file + test_lower() diff --git a/dlas/models/audio/music/gpt_music2.py b/dlas/models/audio/music/gpt_music2.py index d508eb70..eaa4bf79 100644 --- a/dlas/models/audio/music/gpt_music2.py +++ b/dlas/models/audio/music/gpt_music2.py @@ -2,12 +2,12 @@ import torch import torch.nn.functional as F from torch import nn from transformers import GPT2Config, GPT2Model -import torch_intermediary as ml -from models.arch_util import AttentionBlock, ResBlock -from models.audio.tts.lucidrains_dvae import DiscreteVAE -from trainer.networks import register_model -from utils.util import opt_get, ceil_multiple, print_network +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock, ResBlock +from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE +from dlas.trainer.networks import register_model +from dlas.utils.util import ceil_multiple, opt_get, print_network class UpperEncoder(nn.Module): @@ -19,22 +19,30 @@ class UpperEncoder(nn.Module): ): super().__init__() attn = [] + def edim(m): dd = min(spec_dim + m * 128, hidden_dim) return ceil_multiple(dd, 8) self.downsampler = nn.Sequential( - ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled), - ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled), - ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled), - ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), + ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, + down=True, checkpointing_enabled=checkpointing_enabled), + ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, + down=True, checkpointing_enabled=checkpointing_enabled), + ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, + down=True, checkpointing_enabled=checkpointing_enabled), + ResBlock(edim(3), out_channels=edim(4), use_conv=True, + dims=1, checkpointing_enabled=checkpointing_enabled), ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled)) self.encoder = nn.Sequential( AttentionBlock(hidden_dim, 4, do_activation=True), - ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, + dims=1, checkpointing_enabled=checkpointing_enabled), AttentionBlock(hidden_dim, 4, do_activation=True), - ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, + dims=1, checkpointing_enabled=checkpointing_enabled), AttentionBlock(hidden_dim, 4, do_activation=True), - ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, + dims=1, checkpointing_enabled=checkpointing_enabled), nn.GroupNorm(8, hidden_dim), nn.SiLU(), nn.Conv1d(hidden_dim, embedding_dim, 1), @@ -47,8 +55,6 @@ class UpperEncoder(nn.Module): return h - - 
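(Illustrative aside.) In the UpperEncoder above, edim() widens the channel count by 128 per downsampling stage, caps it at hidden_dim, and rounds up to a multiple of 8 so the GroupNorm(8, ...) layers inside the ResBlocks divide evenly. A standalone sketch of that schedule, using spec_dim=256 and hidden_dim=1024 as in the test_lower() configuration; the local ceil_multiple is assumed to behave like dlas.utils.util.ceil_multiple:

import math

def ceil_multiple(x, base):
    # Round x up to the nearest multiple of base (assumed equivalent to the repo helper).
    return int(math.ceil(x / base) * base)

def edim(m, spec_dim=256, hidden_dim=1024):
    dd = min(spec_dim + m * 128, hidden_dim)
    return ceil_multiple(dd, 8)

# Channel widths fed to the successive downsampling ResBlock stages.
print([edim(m) for m in range(1, 5)])  # [384, 512, 640, 768]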
class GptMusicLower(nn.Module): def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}): super().__init__() @@ -58,7 +64,8 @@ class GptMusicLower(nn.Module): n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False) - self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)]) + self.target_quantizers = nn.ModuleList( + [DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)]) self.upper_encoder = UpperEncoder(256, dim, encoder_out_dim) self.encoder_projector = nn.Conv1d(encoder_out_dim, dim, 1) self.fp16 = fp16 @@ -75,8 +82,10 @@ class GptMusicLower(nn.Module): del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = nn.ModuleList([ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) - self.heads = nn.ModuleList([ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) + self.embeddings = nn.ModuleList( + [ml.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) + self.heads = nn.ModuleList( + [ml.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) def forward(self, mel, return_latent=False): unused_params = [] @@ -92,20 +101,23 @@ class GptMusicLower(nn.Module): upper_vector = self.upper_encoder(mel) upper_vector = self.encoder_projector(upper_vector) # WTB slerp - upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear') - upper_vector = upper_vector.permute(0,2,1) + upper_vector = F.interpolate( + upper_vector, size=codes.shape[1], mode='linear') + upper_vector = upper_vector.permute(0, 2, 1) inputs = codes[:, :-1] targets = codes upper_vector = upper_vector[:, :-1] - h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] + h = [embedding(inputs[:, :, i]) + for i, embedding in enumerate(self.embeddings)] h = torch.cat(h, dim=-1) + upper_vector with torch.autocast(mel.device.type, enabled=self.fp16): # Stick the conditioning embedding on the front of the input sequence. # The transformer will learn how to integrate it. # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token. 
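(Illustrative aside.) To make the offset described in that comment concrete: the model embeds codes[:, :-1], prepends a single learned START vector, and the heads are then scored against the full, unshifted codes, so output position t only ever sees codes 0..t-1. A minimal shape check with toy sizes, using plain nn modules as stand-ins for the repo's wrappers:

import torch
from torch import nn

b, s, dim = 2, 10, 32
codes = torch.randint(0, 100, (b, s))
emb = nn.Embedding(100, dim)
start_token = nn.Parameter(torch.randn(1, 1, dim))

h = emb(codes[:, :-1])                                   # (b, s-1, dim)
h = torch.cat([start_token.repeat(b, 1, 1), h], dim=1)   # (b, s, dim)
# After the transformer, a head over h[:, t] is trained against codes[:, t]:
# h[:, 0] sees only START, and h[:, t] sees codes[:, :t].
assert h.shape[1] == codes.shape[1]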
- h = torch.cat([self.start_token.repeat(h.shape[0], 1, 1), h], dim=1) + h = torch.cat( + [self.start_token.repeat(h.shape[0], 1, 1), h], dim=1) h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state @@ -114,8 +126,8 @@ class GptMusicLower(nn.Module): losses = 0 for i, head in enumerate(self.heads): - logits = head(h).permute(0,2,1) - loss = F.cross_entropy(logits, targets[:,:,i]) + logits = head(h).permute(0, 2, 1) + loss = F.cross_entropy(logits, targets[:, :, i]) losses = losses + loss unused_adder = 0 @@ -143,10 +155,10 @@ def register_music_gpt_lower2(opt_net, opt): def test_lower(): model = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4, - vqargs= {'positional_dims': 1, 'channels': 64, - 'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192, - 'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False, - }) + vqargs={'positional_dims': 1, 'channels': 64, + 'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192, + 'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False, + }) quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth', 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth', 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth', @@ -157,7 +169,7 @@ def test_lower(): torch.save(model.state_dict(), 'sample.pth') print_network(model) - mel = torch.randn(2,256,400) + mel = torch.randn(2, 256, 400) model(mel) pg = model.get_grad_norm_parameter_groups() diff --git a/dlas/models/audio/music/instrument_quantizer.py b/dlas/models/audio/music/instrument_quantizer.py index 42b5706a..ffbd917d 100644 --- a/dlas/models/audio/music/instrument_quantizer.py +++ b/dlas/models/audio/music/instrument_quantizer.py @@ -3,13 +3,15 @@ import functools import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml -from models.diffusion.nn import timestep_embedding -from models.lucidrains.vq import VectorQuantize -from models.lucidrains.x_transformers import FeedForward, Attention, Decoder, RMSScaleShiftNorm -from trainer.networks import register_model -from utils.util import checkpoint +import dlas.torch_intermediary as ml +from dlas.models.diffusion.nn import timestep_embedding +from dlas.models.lucidrains.vq import VectorQuantize +from dlas.models.lucidrains.x_transformers import (Attention, Decoder, + FeedForward, + RMSScaleShiftNorm) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint class SelfClassifyingHead(nn.Module): @@ -19,7 +21,7 @@ class SelfClassifyingHead(nn.Module): self.num_classes = classes self.temperature = init_temperature self.dec = Decoder(dim=dim, depth=head_depth, heads=4, ff_dropout=dropout, ff_mult=2, attn_dropout=dropout, - use_rmsnorm=True, ff_glu=True, do_checkpointing=False) + use_rmsnorm=True, ff_glu=True, do_checkpointing=False) self.quantizer = VectorQuantize(out_dim, classes, use_cosine_sim=False, threshold_ema_dead_code=2, sample_codebook_temp=init_temperature) self.to_output = ml.Linear(dim, out_dim) @@ -39,7 +41,8 @@ class SelfClassifyingHead(nn.Module): codes = [] q_reg = 0 for i in range(self.seq_len): - q, c = checkpoint(functools.partial(self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1)) + q, c = checkpoint(functools.partial( + self.do_ar_step, used_codes=codes), 
torch.stack(stack, dim=1)) q_reg = q_reg + (q ** 2).mean() s = torch.sigmoid(q) @@ -48,10 +51,13 @@ class SelfClassifyingHead(nn.Module): # If the addition would strictly make the result worse, set it to 0. Sometimes. if len(results) > 0: - worsen = (F.mse_loss(outputs[-1], target, reduction='none').sum(-1) < F.mse_loss(output, target, reduction='none').sum(-1)).float() + worsen = (F.mse_loss(outputs[-1], target, reduction='none').sum(-1) < F.mse_loss( + output, target, reduction='none').sum(-1)).float() probabilistic_worsen = torch.rand_like(worsen) * worsen > .5 - output = output * probabilistic_worsen.unsqueeze(-1) # This is non-differentiable, but still deterministic. - c[probabilistic_worsen] = -1 # Code of -1 means the code was unused. + # This is non-differentiable, but still deterministic. + output = output * probabilistic_worsen.unsqueeze(-1) + # Code of -1 means the code was unused. + c[probabilistic_worsen] = -1 s = s * probabilistic_worsen.unsqueeze(-1) outputs[-1] = s @@ -65,7 +71,8 @@ class VectorResBlock(nn.Module): def __init__(self, dim, dropout): super().__init__() self.norm = nn.BatchNorm1d(dim) - self.ff = FeedForward(dim, mult=2, glu=True, dropout=dropout, zero_init_output=True) + self.ff = FeedForward(dim, mult=2, glu=True, + dropout=dropout, zero_init_output=True) def forward(self, x): h = self.norm(x.unsqueeze(-1)).squeeze(-1) @@ -92,8 +99,10 @@ class InstrumentQuantizer(nn.Module): super().__init__() self.op_dim = op_dim self.proj = ml.Linear(op_dim, dim) - self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)]) - self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp) + self.encoder = nn.ModuleList( + [VectorResBlock(dim, dropout) for _ in range(enc_depth)]) + self.heads = SelfClassifyingHead( + dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp) self.min_gumbel_temperature = min_temp self.max_gumbel_temperature = max_temp self.gumbel_temperature_decay = temp_decay @@ -109,16 +118,19 @@ class InstrumentQuantizer(nn.Module): x = (x + 1) / 2 b, c, s = x.shape - px = x.permute(0,2,1) # B,S,C shape + px = x.permute(0, 2, 1) # B,S,C shape f = px.reshape(-1, self.op_dim) h = self.proj(f) for lyr in self.encoder: h = lyr(h) reconstructions, codes, q_reg = self.heads(h, f) - reconstruction_losses = torch.stack([F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions]) - r_follow = torch.arange(1, reconstruction_losses.shape[0]+1, device=x.device) - reconstruction_losses = (reconstruction_losses * r_follow / r_follow.shape[0]) + reconstruction_losses = torch.stack( + [F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions]) + r_follow = torch.arange( + 1, reconstruction_losses.shape[0]+1, device=x.device) + reconstruction_losses = ( + reconstruction_losses * r_follow / r_follow.shape[0]) self.log_codes(codes) return reconstruction_losses, q_reg @@ -126,7 +138,8 @@ class InstrumentQuantizer(nn.Module): def log_codes(self, codes): if self.internal_step % 5 == 0: l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + i = self.code_ind if ( + self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l self.codes[i:i+l] = codes.cpu() self.code_ind = self.code_ind + l if self.code_ind >= self.codes.shape[0]: @@ -143,9 +156,9 @@ class InstrumentQuantizer(nn.Module): def update_for_step(self, step, *args): self.internal_step = step self.heads.quantizer._codebook.sample_codebook_temp = max( - 
self.max_gumbel_temperature * self.gumbel_temperature_decay**step, - self.min_gumbel_temperature, - ) + self.max_gumbel_temperature * self.gumbel_temperature_decay**step, + self.min_gumbel_temperature, + ) def get_grad_norm_parameter_groups(self): groups = { @@ -162,6 +175,6 @@ def register_instrument_quantizer(opt_net, opt): if __name__ == '__main__': - inp = torch.randn((4,256,200)).clamp(-1,1) + inp = torch.randn((4, 256, 200)).clamp(-1, 1) model = InstrumentQuantizer(256, 512, 4096, 8, 3) model(inp) diff --git a/dlas/models/audio/music/m2v_code_to_mel.py b/dlas/models/audio/music/m2v_code_to_mel.py index 91e25363..08153cea 100644 --- a/dlas/models/audio/music/m2v_code_to_mel.py +++ b/dlas/models/audio/music/m2v_code_to_mel.py @@ -2,10 +2,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.arch_util import ResBlock, AttentionBlock -from models.audio.music.flat_diffusion import MultiGroupEmbedding -from trainer.networks import register_model -from utils.util import checkpoint +from dlas.models.arch_util import AttentionBlock, ResBlock +from dlas.models.audio.music.flat_diffusion import MultiGroupEmbedding +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint class Code2Mel(nn.Module): @@ -13,27 +13,33 @@ class Code2Mel(nn.Module): super().__init__() self.emb = MultiGroupEmbedding(num_tokens, num_groups, base_dim) self.base_blocks = nn.Sequential(ResBlock(base_dim, dropout, dims=1), - AttentionBlock(base_dim, num_heads=base_dim//64), + AttentionBlock( + base_dim, num_heads=base_dim//64), ResBlock(base_dim, dropout, dims=1)) l2dim = base_dim-256 self.l2_up_block = nn.Conv1d(base_dim, l2dim, kernel_size=5, padding=2) self.l2_blocks = nn.Sequential(ResBlock(l2dim, dropout, kernel_size=5, dims=1), - AttentionBlock(l2dim, num_heads=base_dim//64), - ResBlock(l2dim, dropout, kernel_size=5, dims=1), - AttentionBlock(l2dim, num_heads=base_dim//64), - ResBlock(l2dim, dropout, dims=1), - ResBlock(l2dim, dropout, dims=1)) + AttentionBlock( + l2dim, num_heads=base_dim//64), + ResBlock(l2dim, dropout, + kernel_size=5, dims=1), + AttentionBlock( + l2dim, num_heads=base_dim//64), + ResBlock(l2dim, dropout, dims=1), + ResBlock(l2dim, dropout, dims=1)) l3dim = l2dim-256 self.l3_up_block = nn.Conv1d(l2dim, l3dim, kernel_size=5, padding=2) self.l3_blocks = nn.Sequential(ResBlock(l3dim, dropout, kernel_size=5, dims=1), - AttentionBlock(l3dim, num_heads=base_dim//64), - ResBlock(l3dim, dropout, kernel_size=5, dims=1), + AttentionBlock( + l3dim, num_heads=base_dim//64), + ResBlock(l3dim, dropout, + kernel_size=5, dims=1), ResBlock(l3dim, dropout, dims=1)) self.final_block = nn.Conv1d(l3dim, out_dim, kernel_size=3, padding=1) def forward(self, codes, target): with torch.autocast(codes.device.type): - h = self.emb(codes).permute(0,2,1) + h = self.emb(codes).permute(0, 2, 1) h = checkpoint(self.base_blocks, h) h = F.interpolate(h, scale_factor=2, mode='linear') h = self.l2_up_block(h) @@ -52,6 +58,6 @@ def register_code2mel(opt_net, opt): if __name__ == '__main__': model = Code2Mel() - codes = torch.randint(0,16, (2,200,4)) - target = torch.randn(2,256,804) - model(codes, target) \ No newline at end of file + codes = torch.randint(0, 16, (2, 200, 4)) + target = torch.randn(2, 256, 804) + model(codes, target) diff --git a/dlas/models/audio/music/mel2vec_codes_gpt.py b/dlas/models/audio/music/mel2vec_codes_gpt.py index 62c60768..820beafa 100644 --- a/dlas/models/audio/music/mel2vec_codes_gpt.py +++ b/dlas/models/audio/music/mel2vec_codes_gpt.py @@ 
-1,11 +1,11 @@ import torch -from torch import nn import torch.nn.functional as F -import torch_intermediary as ml +from torch import nn from transformers import GPT2Config, GPT2Model -from trainer.networks import register_model -from utils.util import opt_get +import dlas.torch_intermediary as ml +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get class Mel2VecCodesGpt(nn.Module): @@ -19,8 +19,10 @@ class Mel2VecCodesGpt(nn.Module): self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. # nn.Embedding - self.embeddings = nn.ModuleList([ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)]) - self.heads = nn.ModuleList([ml.Linear(dim, num_vectors) for _ in range(num_groups)]) + self.embeddings = nn.ModuleList( + [ml.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)]) + self.heads = nn.ModuleList( + [ml.Linear(dim, num_vectors) for _ in range(num_groups)]) def forward(self, codes): assert codes.shape[-1] == self.num_groups @@ -28,14 +30,15 @@ class Mel2VecCodesGpt(nn.Module): inputs = codes[:, :-1] targets = codes[:, 1:] - h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] + h = [embedding(inputs[:, :, i]) + for i, embedding in enumerate(self.embeddings)] h = torch.cat(h, dim=-1) h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state losses = 0 for i, head in enumerate(self.heads): - logits = head(h).permute(0,2,1) - loss = F.cross_entropy(logits, targets[:,:,i]) + logits = head(h).permute(0, 2, 1) + loss = F.cross_entropy(logits, targets[:, :, i]) losses = losses + loss return losses / self.num_groups @@ -48,5 +51,5 @@ def register_music_gpt(opt_net, opt): if __name__ == '__main__': model = Mel2VecCodesGpt(512, 8) - codes = torch.randint(0,8, (2,300,8)) - model(codes) \ No newline at end of file + codes = torch.randint(0, 8, (2, 300, 8)) + model(codes) diff --git a/dlas/models/audio/music/music_quantizer.py b/dlas/models/audio/music/music_quantizer.py index 3d3b461b..0128071a 100644 --- a/dlas/models/audio/music/music_quantizer.py +++ b/dlas/models/audio/music/music_quantizer.py @@ -1,14 +1,14 @@ import functools import torch -from torch import nn import torch.nn.functional as F -import torch_intermediary as ml +from torch import nn -from models.arch_util import zero_module -from models.vqvae.vqvae import Quantize -from trainer.networks import register_model -from utils.util import checkpoint, ceil_multiple, print_network +import dlas.torch_intermediary as ml +from dlas.models.arch_util import zero_module +from dlas.models.vqvae.vqvae import Quantize +from dlas.trainer.networks import register_model +from dlas.utils.util import ceil_multiple, checkpoint, print_network class Downsample(nn.Module): @@ -37,13 +37,13 @@ class ResBlock(nn.Module): def __init__(self, chan): super().__init__() self.net = nn.Sequential( - nn.Conv1d(chan, chan, 3, padding = 1), + nn.Conv1d(chan, chan, 3, padding=1), nn.GroupNorm(8, chan), nn.SiLU(), - nn.Conv1d(chan, chan, 3, padding = 1), + nn.Conv1d(chan, chan, 3, padding=1), nn.GroupNorm(8, chan), nn.SiLU(), - zero_module(nn.Conv1d(chan, chan, 3, padding = 1)), + zero_module(nn.Conv1d(chan, chan, 3, padding=1)), ) def forward(self, x): @@ -74,7 +74,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): # storage for codebook variables (codewords) self.codevectors = nn.Parameter( - torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) + torch.FloatTensor(1, self.num_groups * self.num_vars, 
+ codevector_dim // self.num_groups) ) self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars) @@ -95,7 +96,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): else: marginal_probs = probs.mean(dim=0) - perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + perplexity = torch.exp(-torch.sum(marginal_probs * + torch.log(marginal_probs + 1e-7), dim=-1)).sum() return perplexity def get_codes(self, hidden_states): @@ -103,9 +105,11 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): # project to codevector dim hidden_states = self.weight_proj(hidden_states) - hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + hidden_states = hidden_states.view( + batch_size * sequence_length * self.num_groups, -1) codevector_idx = hidden_states.argmax(dim=-1) - idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups) + idxs = codevector_idx.view( + batch_size, sequence_length, self.num_groups) return idxs def forward(self, hidden_states, mask_time_indices=None, return_probs=False): @@ -113,7 +117,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): # project to codevector dim hidden_states = self.weight_proj(hidden_states) - hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + hidden_states = hidden_states.view( + batch_size * sequence_length * self.num_groups, -1) if self.training: # sample code vector probs via gumbel in differentiable way @@ -125,7 +130,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): codevector_soft_dist = torch.softmax( hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 ) - perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) + perplexity = self._compute_perplexity( + codevector_soft_dist, mask_time_indices) else: # take argmax in non-differentiable way # compute hard codevector distribution (one hot) @@ -133,15 +139,20 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( -1, codevector_idx.view(-1, 1), 1.0 ) - codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + codevector_probs = codevector_probs.view( + batch_size * sequence_length, self.num_groups, -1) - perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) + perplexity = self._compute_perplexity( + codevector_probs, mask_time_indices) - codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + codevector_probs = codevector_probs.view( + batch_size * sequence_length, -1) # use probs to retrieve codevectors - codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors_per_group = codevector_probs.unsqueeze( + -1) * self.codevectors codevectors = ( - codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors_per_group.view( + batch_size * sequence_length, self.num_groups, self.num_vars, -1) .sum(-2) .view(batch_size, sequence_length, -1) ) @@ -164,7 +175,8 @@ class MusicQuantizer(nn.Module): self.use_vqvae_quantizer = use_vqvae_quantizer if use_vqvae_quantizer: self.quantizer = Quantize(inner_dim[0], codebook_size) - assert codevector_dim == inner_dim[0] # Because this quantizer doesn't support different sizes. + # Because this quantizer doesn't support different sizes. 
+ assert codevector_dim == inner_dim[0] else: self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim[0], codevector_dim=codevector_dim, num_codevector_groups=codebook_groups, @@ -174,8 +186,10 @@ class MusicQuantizer(nn.Module): self.num_losses_record = [] if down_steps == 0: - self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1) - self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1) + self.down = nn.Conv1d( + inp_channels, inner_dim[0], kernel_size=3, padding=1) + self.up = nn.Conv1d( + inner_dim[0], inp_channels, kernel_size=3, padding=1) elif down_steps == 2: self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1), Downsample(inner_dim[-1], inner_dim[-2]), @@ -201,25 +215,27 @@ class MusicQuantizer(nn.Module): def get_codes(self, mel): h = self.down(mel) h = self.encoder(h) - h = self.enc_norm(h.permute(0,2,1)) + h = self.enc_norm(h.permute(0, 2, 1)) return self.quantizer.get_codes(h) def forward(self, mel, return_decoder_latent=False): orig_mel = mel cm = ceil_multiple(mel.shape[-1], 4) if cm != 0: - mel = F.pad(mel, (0,cm-mel.shape[-1])) + mel = F.pad(mel, (0, cm-mel.shape[-1])) h = self.down(mel) h = self.encoder(h) - h = self.enc_norm(h.permute(0,2,1)) + h = self.enc_norm(h.permute(0, 2, 1)) if self.use_vqvae_quantizer: codevectors, diversity, codes = self.quantizer(h) else: - codevectors, perplexity, codes = self.quantizer(h, return_probs=True) - diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors + codevectors, perplexity, codes = self.quantizer( + h, return_probs=True) + diversity = (self.quantizer.num_codevectors - + perplexity) / self.quantizer.num_codevectors self.log_codes(codes) - h = self.decoder(codevectors.permute(0,2,1)) + h = self.decoder(codevectors.permute(0, 2, 1)) if return_decoder_latent: return h, diversity @@ -233,13 +249,14 @@ class MusicQuantizer(nn.Module): if self.internal_step % 5 == 0: if not self.use_vqvae_quantizer: codes = torch.argmax(codes, dim=-1) - ccodes = codes[:,:,0] - for j in range(1,codes.shape[-1]): - ccodes += codes[:,:,j] * self.codebook_size ** j + ccodes = codes[:, :, 0] + for j in range(1, codes.shape[-1]): + ccodes += codes[:, :, j] * self.codebook_size ** j codes = ccodes codes = codes.flatten() l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + i = self.code_ind if ( + self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l self.codes[i:i+l] = codes.cpu() self.code_ind = self.code_ind + l if self.code_ind >= self.codes.shape[0]: @@ -259,7 +276,8 @@ def register_music_quantizer(opt_net, opt): if __name__ == '__main__': - model = MusicQuantizer(inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=8192, codebook_groups=0, use_vqvae_quantizer=True) + model = MusicQuantizer(inner_dim=[1024, 1024, 512], codevector_dim=1024, + codebook_size=8192, codebook_groups=0, use_vqvae_quantizer=True) print_network(model) - mel = torch.randn((2,256,782)) - model(mel) \ No newline at end of file + mel = torch.randn((2, 256, 782)) + model(mel) diff --git a/dlas/models/audio/music/music_quantizer2.py b/dlas/models/audio/music/music_quantizer2.py index d7df3658..2ad1c02a 100644 --- a/dlas/models/audio/music/music_quantizer2.py +++ b/dlas/models/audio/music/music_quantizer2.py @@ -1,14 +1,14 @@ import functools import torch -from torch import nn import torch.nn.functional as F -import torch_intermediary as ml +from torch import nn -from models.arch_util 
import zero_module -from models.vqvae.vqvae import Quantize -from trainer.networks import register_model -from utils.util import checkpoint, ceil_multiple, print_network +import dlas.torch_intermediary as ml +from dlas.models.arch_util import zero_module +from dlas.models.vqvae.vqvae import Quantize +from dlas.trainer.networks import register_model +from dlas.utils.util import ceil_multiple, checkpoint, print_network class Downsample(nn.Module): @@ -16,7 +16,8 @@ class Downsample(nn.Module): super().__init__() self.interpolate = not stride_down if stride_down: - self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1, stride=2) + self.conv = nn.Conv1d( + chan_in, chan_out, kernel_size=3, padding=1, stride=2) else: self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1) if norm: @@ -49,13 +50,13 @@ class ResBlock(nn.Module): def __init__(self, chan): super().__init__() self.net = nn.Sequential( - nn.Conv1d(chan, chan, 3, padding = 1), + nn.Conv1d(chan, chan, 3, padding=1), nn.GroupNorm(8, chan), nn.SiLU(), - nn.Conv1d(chan, chan, 3, padding = 1), + nn.Conv1d(chan, chan, 3, padding=1), nn.GroupNorm(8, chan), nn.SiLU(), - zero_module(nn.Conv1d(chan, chan, 3, padding = 1)), + zero_module(nn.Conv1d(chan, chan, 3, padding=1)), ) def forward(self, x): @@ -86,7 +87,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): # storage for codebook variables (codewords) self.codevectors = nn.Parameter( - torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) + torch.FloatTensor(1, self.num_groups * self.num_vars, + codevector_dim // self.num_groups) ) self.weight_proj = ml.Linear(proj_dim, self.num_groups * self.num_vars) @@ -107,7 +109,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): else: marginal_probs = probs.mean(dim=0) - perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + perplexity = torch.exp(-torch.sum(marginal_probs * + torch.log(marginal_probs + 1e-7), dim=-1)).sum() return perplexity def get_codes(self, hidden_states): @@ -115,9 +118,11 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): # project to codevector dim hidden_states = self.weight_proj(hidden_states) - hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + hidden_states = hidden_states.view( + batch_size * sequence_length * self.num_groups, -1) codevector_idx = hidden_states.argmax(dim=-1) - idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups) + idxs = codevector_idx.view( + batch_size, sequence_length, self.num_groups) return idxs def forward(self, hidden_states, mask_time_indices=None, return_probs=False): @@ -125,7 +130,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): # project to codevector dim hidden_states = self.weight_proj(hidden_states) - hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + hidden_states = hidden_states.view( + batch_size * sequence_length * self.num_groups, -1) if self.training: # sample code vector probs via gumbel in differentiable way @@ -137,7 +143,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): codevector_soft_dist = torch.softmax( hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 ) - perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) + perplexity = self._compute_perplexity( + codevector_soft_dist, mask_time_indices) else: # take argmax in non-differentiable way # compute hard codevector distribution (one hot) @@ -145,15 +152,20 @@ 
class Wav2Vec2GumbelVectorQuantizer(nn.Module): codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( -1, codevector_idx.view(-1, 1), 1.0 ) - codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + codevector_probs = codevector_probs.view( + batch_size * sequence_length, self.num_groups, -1) - perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) + perplexity = self._compute_perplexity( + codevector_probs, mask_time_indices) - codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + codevector_probs = codevector_probs.view( + batch_size * sequence_length, -1) # use probs to retrieve codevectors - codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors_per_group = codevector_probs.unsqueeze( + -1) * self.codevectors codevectors = ( - codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors_per_group.view( + batch_size * sequence_length, self.num_groups, self.num_vars, -1) .sum(-2) .view(batch_size, sequence_length, -1) ) @@ -183,12 +195,14 @@ class MusicQuantizer2(nn.Module): self.num_losses_record = [] if down_steps == 0: - self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1) - self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1) + self.down = nn.Conv1d( + inp_channels, inner_dim[0], kernel_size=3, padding=1) + self.up = nn.Conv1d( + inner_dim[0], inp_channels, kernel_size=3, padding=1) elif down_steps == 2: self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1), *[Downsample(inner_dim[-i], inner_dim[-i-1], norm=expressive_downsamples, act=expressive_downsamples, - stride_down=expressive_downsamples) for i in range(1,len(inner_dim))]) + stride_down=expressive_downsamples) for i in range(1, len(inner_dim))]) self.up = nn.Sequential(*[Upsample(inner_dim[i], inner_dim[i+1]) for i in range(len(inner_dim)-1)] + [nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1)]) @@ -209,22 +223,23 @@ class MusicQuantizer2(nn.Module): def get_codes(self, mel): h = self.down(mel) h = self.encoder(h) - h = self.enc_norm(h.permute(0,2,1)) + h = self.enc_norm(h.permute(0, 2, 1)) return self.quantizer.get_codes(h) def forward(self, mel, return_decoder_latent=False): orig_mel = mel cm = ceil_multiple(mel.shape[-1], 2 ** (len(self.down)-1)) if cm != 0: - mel = F.pad(mel, (0,cm-mel.shape[-1])) + mel = F.pad(mel, (0, cm-mel.shape[-1])) h = self.down(mel) h = self.encoder(h) - h = self.enc_norm(h.permute(0,2,1)) + h = self.enc_norm(h.permute(0, 2, 1)) codevectors, perplexity, codes = self.quantizer(h, return_probs=True) - diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors + diversity = (self.quantizer.num_codevectors - + perplexity) / self.quantizer.num_codevectors self.log_codes(codes) - h = self.decoder(codevectors.permute(0,2,1)) + h = self.decoder(codevectors.permute(0, 2, 1)) if return_decoder_latent: return h, diversity @@ -237,13 +252,14 @@ class MusicQuantizer2(nn.Module): def log_codes(self, codes): if self.internal_step % 5 == 0: codes = torch.argmax(codes, dim=-1) - ccodes = codes[:,:,0] - for j in range(1,codes.shape[-1]): - ccodes += codes[:,:,j] * self.codebook_size ** j + ccodes = codes[:, :, 0] + for j in range(1, codes.shape[-1]): + ccodes += codes[:, :, j] * self.codebook_size ** j codes = ccodes codes = codes.flatten() l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - 
self.code_ind) > l else self.codes.shape[0] - l + i = self.code_ind if ( + self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l self.codes[i:i+l] = codes.cpu() self.code_ind = self.code_ind + l if self.code_ind >= self.codes.shape[0]: @@ -258,9 +274,9 @@ class MusicQuantizer2(nn.Module): def update_for_step(self, step, *args): self.quantizer.temperature = max( - self.max_gumbel_temperature * self.gumbel_temperature_decay**step, - self.min_gumbel_temperature, - ) + self.max_gumbel_temperature * self.gumbel_temperature_decay**step, + self.min_gumbel_temperature, + ) @register_model @@ -269,7 +285,8 @@ def register_music_quantizer2(opt_net, opt): if __name__ == '__main__': - model = MusicQuantizer2(inner_dim=[1024], codevector_dim=1024, codebook_size=256, codebook_groups=2) + model = MusicQuantizer2( + inner_dim=[1024], codevector_dim=1024, codebook_size=256, codebook_groups=2) print_network(model) - mel = torch.randn((2,256,782)) - model(mel) \ No newline at end of file + mel = torch.randn((2, 256, 782)) + model(mel) diff --git a/dlas/models/audio/music/tfdpc_v5.py b/dlas/models/audio/music/tfdpc_v5.py index 998906df..08e88b02 100644 --- a/dlas/models/audio/music/tfdpc_v5.py +++ b/dlas/models/audio/music/tfdpc_v5.py @@ -8,14 +8,17 @@ import torch.nn as nn import torch.nn.functional as F import torchaudio import torchvision -import torch_intermediary as ml -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \ - FeedForward -from trainer.networks import register_model -from utils.util import checkpoint, print_network, load_audio +import dlas.torch_intermediary as ml +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import TimestepBlock +from dlas.models.lucidrains.x_transformers import (Attention, Encoder, + FeedForward, + RMSScaleShiftNorm, + RotaryEmbedding) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, load_audio, print_network class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): @@ -389,17 +392,20 @@ def inference_tfdpc5_with_cheater(): use_fp16=False, unconditioned_percentage=0).eval().cuda() model.load_state_dict(torch.load('x:/dlas/experiments/train_music_cheater_gen_v3/models/59000_generator_ema.pth')) - from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector + from trainer.injectors.audio_injectors import \ + TorchMelSpectrogramInjector spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000, 'true_normalization': True, 'normalize': True, 'in': 'in', 'out': 'out'}, {}).cuda() ref_mel = spec_fn({'in': sample.unsqueeze(0)})['out'] - from trainer.injectors.audio_injectors import MusicCheaterLatentInjector + from trainer.injectors.audio_injectors import \ + MusicCheaterLatentInjector cheater_encoder = MusicCheaterLatentInjector({'in': 'in', 'out': 'out'}, {}).cuda() ref_cheater = cheater_encoder({'in': ref_mel})['out'] - from models.diffusion.respace import SpacedDiffusion - from models.diffusion.respace import space_timesteps - from models.diffusion.gaussian_diffusion import get_named_beta_schedule + from models.diffusion.gaussian_diffusion import \ + get_named_beta_schedule + from models.diffusion.respace import (SpacedDiffusion, + space_timesteps) 
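(Illustrative aside.) The update_for_step() methods of MusicQuantizer2 above and InstrumentQuantizer earlier anneal the Gumbel sampling temperature the same way: exponential decay per training step, clamped at a floor. A tiny sketch of that schedule; the numeric hyperparameters here are placeholders, not values taken from any config:

def gumbel_temperature(step, max_temp=2.0, decay=0.9996, min_temp=0.5):
    # Exponential decay from max_temp toward min_temp, clamped at the floor.
    return max(max_temp * decay ** step, min_temp)

for step in (0, 1000, 5000, 20000):
    print(step, round(gumbel_temperature(step), 4))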
diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [128]), model_mean_type='epsilon', model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), conditioning_free=True, conditioning_free_k=1) @@ -415,7 +421,8 @@ def inference_tfdpc5_with_cheater(): # Just decode the ref. #gen_cheater = ref_cheater - from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent + from models.audio.music.transformer_diffusion12 import \ + TransformerDiffusionWithCheaterLatent diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [32]), model_mean_type='epsilon', model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), conditioning_free=True, conditioning_free_k=1) diff --git a/dlas/models/audio/music/transformer_diffusion12.py b/dlas/models/audio/music/transformer_diffusion12.py index ff3ccc90..445305e6 100644 --- a/dlas/models/audio/music/transformer_diffusion12.py +++ b/dlas/models/audio/music/transformer_diffusion12.py @@ -4,18 +4,21 @@ from time import time import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml -from models.arch_util import ResBlock -from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower -from models.audio.music.music_quantizer2 import MusicQuantizer2 -from models.audio.tts.lucidrains_dvae import DiscreteVAE -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \ - FeedForward -from trainer.networks import register_model -from utils.util import checkpoint, print_network +import dlas.torch_intermediary as ml +from dlas.models.arch_util import ResBlock +from dlas.models.audio.music.gpt_music2 import GptMusicLower, UpperEncoder +from dlas.models.audio.music.music_quantizer2 import MusicQuantizer2 +from dlas.models.audio.tts.lucidrains_dvae import DiscreteVAE +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import TimestepBlock +from dlas.models.lucidrains.x_transformers import (Attention, Encoder, + FeedForward, + RMSScaleShiftNorm, + RotaryEmbedding) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, print_network def is_latent(t): diff --git a/dlas/models/audio/music/transformer_diffusion13.py b/dlas/models/audio/music/transformer_diffusion13.py index c946af4e..3e7adc57 100644 --- a/dlas/models/audio/music/transformer_diffusion13.py +++ b/dlas/models/audio/music/transformer_diffusion13.py @@ -1,26 +1,29 @@ import itertools -import random from random import randrange import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml -from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask, cGLU, \ - RelativeQKBias -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepBlock -from trainer.networks import register_model -from utils.util import checkpoint +import dlas.torch_intermediary as ml +from dlas.models.arch_util import (AttentionBlock, RelativeQKBias, ResBlock, + TimestepEmbedSequential, + build_local_attention_mask, cGLU) +from dlas.models.diffusion.nn import (conv_nd, linear, 
normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import TimestepBlock +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint class SubBlock(nn.Module): def __init__(self, inp_dim, contraction_dim, heads, dropout): super().__init__() self.dropout = nn.Dropout(p=dropout) - self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads) - self.register_buffer('mask', build_local_attention_mask(n=6000, l=64), persistent=False) + self.attn = AttentionBlock( + inp_dim, out_channels=contraction_dim, num_heads=heads) + self.register_buffer('mask', build_local_attention_mask( + n=6000, l=64), persistent=False) self.pos_bias = RelativeQKBias(l=64, max_positions=6000) ff_contract = contraction_dim//2 self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1), @@ -31,7 +34,8 @@ class SubBlock(nn.Module): cGLU(ff_contract)) def forward(self, x): - ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1]))) + ah = self.dropout(self.attn(x, mask=self.mask, + qk_bias=self.pos_bias(x.shape[-1]))) h = torch.cat([ah, x], dim=1) hf = self.dropout(checkpoint(self.ff1, h)) h = torch.cat([h, hf], dim=1) @@ -44,17 +48,21 @@ class ConcatAttentionBlock(TimestepBlock): super().__init__() self.contraction_dim = contraction_dim self.prenorm = nn.GroupNorm(8, trunk_dim) - self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout) - self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout) - self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False) + self.block1 = SubBlock( + trunk_dim+blk_dim, contraction_dim, heads, dropout) + self.block2 = SubBlock( + trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout) + self.out = nn.Conv1d(contraction_dim*4, trunk_dim, + kernel_size=1, bias=False) self.out.weight.data.zero_() def forward(self, x, blk_emb): h = self.prenorm(x) - h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1) + h = torch.cat( + [h, blk_emb.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) h = self.block1(h) h = self.block2(h) - h = self.out(h[:,-self.contraction_dim*4:]) + h = self.out(h[:, -self.contraction_dim*4:]) return h + x @@ -72,10 +80,13 @@ class ConditioningEncoder(nn.Module): self.init = nn.Conv1d(spec_dim, hidden_dim, kernel_size=5, stride=2) # nn.Embedding self.resolution_embedding = ml.Embedding(num_resolutions, hidden_dim) - self.resolution_embedding.weight.data.mul(.1) # Reduces the relative influence of this embedding from the start. + # Reduces the relative influence of this embedding from the start. + self.resolution_embedding.weight.data.mul(.1) for a in range(attn_blocks): - attn.append(AttentionBlock(hidden_dim, num_attn_heads, do_checkpoint=do_checkpointing)) - attn.append(ResBlock(hidden_dim, dims=1, checkpointing_enabled=do_checkpointing)) + attn.append(AttentionBlock(hidden_dim, num_attn_heads, + do_checkpoint=do_checkpointing)) + attn.append(ResBlock(hidden_dim, dims=1, + checkpointing_enabled=do_checkpointing)) self.attn = nn.Sequential(*attn) self.out = ml.Linear(hidden_dim, out_dim, bias=False) self.dim = hidden_dim @@ -91,6 +102,7 @@ class TransformerDiffusion(nn.Module): """ A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? 
""" + def __init__( self, resolution_steps=8, @@ -108,7 +120,8 @@ class TransformerDiffusion(nn.Module): dropout=0, use_fp16=False, # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, ): super().__init__() @@ -135,17 +148,21 @@ class TransformerDiffusion(nn.Module): ) # nn.Embedding self.resolution_embed = ml.Embedding(resolution_steps, time_proj_dim) - self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim)) + self.conditioning_encoder = ConditioningEncoder( + in_channels, model_channels, cond_proj_dim, resolution_steps, num_attn_heads=model_channels//64) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, cond_proj_dim)) - self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1) + self.inp_block = conv_nd( + 1, in_channels+input_vec_dim, model_channels, 3, 1, 1) self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim, num_heads, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( normalization(model_channels), nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(1, model_channels, + out_channels, 3, padding=1)), ) self.debug_codes = {} @@ -163,17 +180,20 @@ class TransformerDiffusion(nn.Module): s_diff = s.shape[-1] - self.max_window if s_diff > 1: start = randrange(0, s_diff) - s = s[:,:,start:start+self.max_window] + s = s[:, :, start:start+self.max_window] s_prior = F.interpolate(s, scale_factor=.25, mode='nearest') - s_prior = F.interpolate(s_prior, size=(s.shape[-1],), mode='linear', align_corners=True) + s_prior = F.interpolate(s_prior, size=( + s.shape[-1],), mode='linear', align_corners=True) # Now diffuse the prior randomly between the x timestep and 0. adv = torch.rand_like(ts.float()) t_prior = (adv * ts).long() - 1 # The t_prior-1 below is an important detail: it forces s_prior to be unmodified for ts=0. It also means that t_prior is not on the same timescale as ts (instead it is shifted by 1). - s_prior_diffused = diffuser.q_sample(s_prior, t_prior-1, torch.randn_like(s_prior), allow_negatives=True) + s_prior_diffused = diffuser.q_sample( + s_prior, t_prior-1, torch.randn_like(s_prior), allow_negatives=True) - self.preprocessed = (s_prior_diffused, t_prior, torch.tensor([resolution] * x.shape[0], dtype=torch.long, device=x.device)) + self.preprocessed = (s_prior_diffused, t_prior, torch.tensor( + [resolution] * x.shape[0], dtype=torch.long, device=x.device)) return s def forward(self, x, timesteps, prior_timesteps=None, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False): @@ -202,12 +222,16 @@ class TransformerDiffusion(nn.Module): x_prior, prior_timesteps, resolution = self.preprocessed self.preprocessed = None else: - assert x.shape[-1] > x_prior.shape[-1] * 3.9, f'{x.shape} {x_prior.shape}' + assert x.shape[-1] > x_prior.shape[-1] * \ + 3.9, f'{x.shape} {x_prior.shape}' if prior_timesteps is None: # This is taken to mean a fully diffused prior was given. - prior_timesteps = torch.tensor([0], device=x.device) # Assuming batch_size=1 for inference. 
- x_prior = F.interpolate(x_prior, size=(x.shape[-1],), mode='linear', align_corners=True) - assert torch.all(timesteps - prior_timesteps >= 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}' + # Assuming batch_size=1 for inference. + prior_timesteps = torch.tensor([0], device=x.device) + x_prior = F.interpolate(x_prior, size=( + x.shape[-1],), mode='linear', align_corners=True) + assert torch.all(timesteps - prior_timesteps >= + 0), f'Prior timesteps should always be lower (more resolved) than input timesteps. {timesteps}, {prior_timesteps}' if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1) @@ -218,19 +242,26 @@ class TransformerDiffusion(nn.Module): clen = randrange(MIN_COND_LEN, MAX_COND_LEN) gap = conditioning_input.shape[-1] - clen cstart = randrange(0, gap) - conditioning_input = conditioning_input[:,:,cstart:cstart+clen] - code_emb = self.conditioning_encoder(conditioning_input, resolution) + conditioning_input = conditioning_input[:, + :, cstart:cstart+clen] + code_emb = self.conditioning_encoder( + conditioning_input, resolution) # Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((x.shape[0], 1), device=x.device) < self.unconditioned_percentage - code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1), code_emb) + unconditioned_batches = torch.rand( + (x.shape[0], 1), device=x.device) < self.unconditioned_percentage + code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat( + code_emb.shape[0], 1), code_emb) with torch.autocast(x.device.type, enabled=self.enable_fp16): - time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) - prior_time_emb = self.prior_time_embed(timestep_embedding(prior_timesteps, self.time_embed_dim)) + time_emb = self.time_embed( + timestep_embedding(timesteps, self.time_embed_dim)) + prior_time_emb = self.prior_time_embed( + timestep_embedding(prior_timesteps, self.time_embed_dim)) res_emb = self.resolution_embed(resolution) - blk_emb = torch.cat([time_emb, prior_time_emb, res_emb, code_emb], dim=1) + blk_emb = torch.cat( + [time_emb, prior_time_emb, res_emb, code_emb], dim=1) h = torch.cat([x, x_prior], dim=1) h = self.inp_block(h) @@ -250,13 +281,16 @@ class TransformerDiffusion(nn.Module): return out def get_grad_norm_parameter_groups(self): - attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers])) - attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers])) + attn1 = list(itertools.chain.from_iterable( + [lyr.block1.attn.parameters() for lyr in self.layers])) + attn2 = list(itertools.chain.from_iterable( + [lyr.block2.attn.parameters() for lyr in self.layers])) ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] + [lyr.block1.ff2.parameters() for lyr in self.layers])) ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] + [lyr.block2.ff2.parameters() for lyr in self.layers])) - blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) + blkout_layers = list(itertools.chain.from_iterable( + [lyr.out.parameters() for lyr in self.layers])) groups = { 'prenorms': 
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])), 'blk1_attention_layers': attn1, @@ -276,7 +310,8 @@ class TransformerDiffusion(nn.Module): return groups def before_step(self, step): - scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) + scaled_grad_parameters = list(itertools.chain.from_iterable( + [lyr.out.parameters() for lyr in self.layers])) # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than # directly fiddling with the gradients. @@ -291,14 +326,13 @@ def register_transformer_diffusion13(opt_net, opt): def test_tfd(): - from models.diffusion.respace import SpacedDiffusion - from models.diffusion.respace import space_timesteps from models.diffusion.gaussian_diffusion import get_named_beta_schedule + from models.diffusion.respace import SpacedDiffusion, space_timesteps diffuser = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon', - model_var_type='learned_range', loss_type='mse', - betas=get_named_beta_schedule('linear', 4000)) - clip = torch.randn(2,256,10336) - cond = torch.randn(2,256,10336) + model_var_type='learned_range', loss_type='mse', + betas=get_named_beta_schedule('linear', 4000)) + clip = torch.randn(2, 256, 10336) + cond = torch.randn(2, 256, 10336) ts = torch.LongTensor([0, 0]) model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512, num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1, @@ -316,5 +350,5 @@ def remove_conditioning(sd_path): if __name__ == '__main__': - #remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth') + # remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth') test_tfd() diff --git a/dlas/models/audio/music/transformer_diffusion14.py b/dlas/models/audio/music/transformer_diffusion14.py index c318a800..23e89a43 100644 --- a/dlas/models/audio/music/transformer_diffusion14.py +++ b/dlas/models/audio/music/transformer_diffusion14.py @@ -4,18 +4,21 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.arch_util import TimestepEmbedSequential -from models.audio.music.encoders import ResEncoder16x -from models.audio.music.transformer_diffusion13 import ConcatAttentionBlock -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from trainer.networks import register_model -from utils.util import checkpoint, print_network +from dlas.models.arch_util import TimestepEmbedSequential +from dlas.models.audio.music.encoders import ResEncoder16x +from dlas.models.audio.music.transformer_diffusion13 import \ + ConcatAttentionBlock +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, print_network class TransformerDiffusion(nn.Module): """ A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? 
""" + def __init__( self, time_embed_dim=256, @@ -27,11 +30,13 @@ class TransformerDiffusion(nn.Module): out_channels=512, # mean and variance num_heads=4, dropout=0, - use_corner_alignment=False, # This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE. + # This is an interpolation parameter only provided for backwards compatibility. ALL NEW TRAINS SHOULD SET THIS TO TRUE. + use_corner_alignment=False, use_fp16=False, new_code_expansion=False, # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, # Parameters for re-training head freeze_except_code_converters=False, ): @@ -55,7 +60,8 @@ class TransformerDiffusion(nn.Module): ) self.input_converter = nn.Conv1d(input_vec_dim, model_channels, 1) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1)) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, model_channels, 1)) self.intg = nn.Conv1d(model_channels*2, model_channels, 1) self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_embed_dim//4, num_heads, dropout) for _ in range(num_layers)]) @@ -63,7 +69,8 @@ class TransformerDiffusion(nn.Module): self.out = nn.Sequential( normalization(model_channels), nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(1, model_channels, + out_channels, 3, padding=1)), ) if freeze_except_code_converters: @@ -76,13 +83,16 @@ class TransformerDiffusion(nn.Module): p.requires_grad = True def get_grad_norm_parameter_groups(self): - attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers])) - attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers])) + attn1 = list(itertools.chain.from_iterable( + [lyr.block1.attn.parameters() for lyr in self.layers])) + attn2 = list(itertools.chain.from_iterable( + [lyr.block2.attn.parameters() for lyr in self.layers])) ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] + [lyr.block1.ff2.parameters() for lyr in self.layers])) ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] + [lyr.block2.ff2.parameters() for lyr in self.layers])) - blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) + blkout_layers = list(itertools.chain.from_iterable( + [lyr.out.parameters() for lyr in self.layers])) groups = { 'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])), 'blk1_attention_layers': attn1, @@ -101,7 +111,8 @@ class TransformerDiffusion(nn.Module): def forward(self, x, timesteps, prior=None, conditioning_free=False): if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) + code_emb = self.unconditioned_embedding.repeat( + x.shape[0], 1, x.shape[-1]) else: code_emb = self.input_converter(prior) @@ -112,10 +123,12 @@ class TransformerDiffusion(nn.Module): code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1), code_emb) - code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest') + code_emb = F.interpolate( + code_emb, size=x.shape[-1], mode='nearest') with torch.autocast(x.device.type, 
enabled=self.enable_fp16): - blk_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim)) + blk_emb = self.time_embed( + timestep_embedding(timesteps, self.time_embed_dim)) x = self.inp_block(x) x = self.intg(torch.cat([x, code_emb], dim=1)) @@ -141,7 +154,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module): self.internal_step = 0 self.freeze_encoder_until = freeze_encoder_until self.diff = TransformerDiffusion(**kwargs) - self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder) + self.encoder = ResEncoder16x( + 256, 1024, 256, checkpointing_enabled=checkpoint_encoder) def forward(self, x, timesteps, truth_mel, conditioning_free=False, cheater=None): unused_parameters = [] @@ -158,7 +172,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module): for p in unused_parameters: proj = proj + p.mean() * 0 - diff = self.diff(x, timesteps, prior=proj, conditioning_free=conditioning_free) + diff = self.diff(x, timesteps, prior=proj, + conditioning_free=conditioning_free) return diff def get_debug_values(self, step, __): @@ -172,7 +187,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module): def before_step(self, step): scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \ - list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])) + list(itertools.chain.from_iterable( + [lyr.prenorm.parameters() for lyr in self.diff.layers])) # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than # directly fiddling with the gradients. @@ -196,10 +212,10 @@ def register_transformer_diffusion_14_with_cheater_latent(opt_net, opt): def test_tfd(): - clip = torch.randn(2,256,400) + clip = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512, - num_heads=3, input_vec_dim=256, num_layers=12, dropout=.1) + num_heads=3, input_vec_dim=256, num_layers=12, dropout=.1) model(clip, ts, clip) @@ -209,14 +225,14 @@ def test_cheater_model(): # For music: model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512, - model_channels=1024, contraction_dim=512, num_heads=8, - input_vec_dim=256, num_layers=16, - dropout=.1, new_code_expansion=True, - ) - #diff_weights = torch.load('extracted_diff.pth') - #model.diff.load_state_dict(diff_weights, strict=False) - #model.encoder.load_state_dict(torch.load('../experiments/music_cheater_encoder_256.pth', map_location=torch.device('cpu')), strict=True) - #torch.save(model.state_dict(), 'sample.pth') + model_channels=1024, contraction_dim=512, num_heads=8, + input_vec_dim=256, num_layers=16, + dropout=.1, new_code_expansion=True, + ) + # diff_weights = torch.load('extracted_diff.pth') + # model.diff.load_state_dict(diff_weights, strict=False) + # model.encoder.load_state_dict(torch.load('../experiments/music_cheater_encoder_256.pth', map_location=torch.device('cpu')), strict=True) + # torch.save(model.state_dict(), 'sample.pth') print_network(model) o = model(clip, ts, clip) @@ -234,7 +250,8 @@ def extract_cheater_encoder(in_f, out_f): if __name__ == '__main__': - #test_local_attention_mask() - extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth', 'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth') - #test_cheater_model() - 
#extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True) + # test_local_attention_mask() + extract_cheater_encoder('X:\\dlas\\experiments\\tfd14_and_cheater.pth', + 'X:\\dlas\\experiments\\tfd14_cheater_encoder.pth') + # test_cheater_model() + # extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True) diff --git a/dlas/models/audio/music/unet_diffusion_music_codes.py b/dlas/models/audio/music/unet_diffusion_music_codes.py index 5a135502..eee0954c 100644 --- a/dlas/models/audio/music/unet_diffusion_music_codes.py +++ b/dlas/models/audio/music/unet_diffusion_music_codes.py @@ -1,29 +1,23 @@ -from abc import abstractmethod - import math +from abc import abstractmethod import numpy as np import torch import torch as th import torch.nn as nn import torch.nn.functional as F -import torchvision # For debugging, not actually used. -import torch_intermediary as ml -from models.audio.music.gpt_music import GptMusicLower -from models.audio.music.music_quantizer import MusicQuantizer -from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32 -from models.diffusion.nn import ( - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from models.lucidrains.x_transformers import Encoder -from trainer.networks import register_model -from utils.util import checkpoint, print_network, ceil_multiple +import dlas.torch_intermediary as ml +from dlas.models.audio.music.gpt_music import GptMusicLower +from dlas.models.audio.music.music_quantizer import MusicQuantizer +from dlas.models.diffusion.fp16_util import (convert_module_to_f16, + convert_module_to_f32) +from dlas.models.diffusion.nn import (avg_pool_nd, conv_nd, linear, + normalization, timestep_embedding, + zero_module) +from dlas.models.lucidrains.x_transformers import Encoder +from dlas.trainer.networks import register_model +from dlas.utils.util import ceil_multiple, checkpoint, print_network class TimestepBlock(nn.Module): @@ -80,7 +74,8 @@ class Upsample(nn.Module): if dims == 1: ksize = 5 pad = 2 - self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad) + self.conv = conv_nd(dims, self.channels, + self.out_channels, ksize, padding=pad) def forward(self, x): assert x.shape[1] == self.channels @@ -123,7 +118,7 @@ class Downsample(nn.Module): elif dims == 2: stride = 2 else: - stride = (1,2,2) + stride = (1, 2, 2) if factor is not None: stride = factor if use_conv: @@ -180,7 +175,8 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + conv_nd(dims, channels, self.out_channels, + kernel_size, padding=padding), ) self.updown = up or down @@ -206,7 +202,8 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) @@ -217,7 +214,8 @@ class ResBlock(TimestepBlock): dims, channels, self.out_channels, kernel_size, padding=padding ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 1) def forward(self, x, emb): """ @@ -346,13 +344,15 @@ class QKVAttentionLegacy(nn.Module): 
bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, + length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum( "bct,bcs->bts", q * scale, k * scale ) # More stable with f16 than dividing afterwards if rel_pos is not None: - weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1]) + weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape( + bs * self.n_heads, weight.shape[-2], weight.shape[-1]) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) if mask is not None: # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. @@ -400,7 +400,8 @@ class QKVAttention(nn.Module): mask = mask.repeat(self.n_heads, 1).unsqueeze(1) weight = weight * mask weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + a = th.einsum("bts,bcs->bct", weight, + v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) @staticmethod @@ -460,7 +461,8 @@ class UNetMusicModel(nn.Module): resblock_updown=False, use_new_attention_order=False, use_raw_y_as_embedding=False, - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, ): super().__init__() @@ -493,44 +495,48 @@ class UNetMusicModel(nn.Module): if self.ar_prior: self.ar_input = ml.Linear(input_vec_dim, model_channels) self.ar_prior_intg = Encoder( - dim=model_channels, - depth=4, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) + dim=model_channels, + depth=4, + heads=num_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + zero_init_branch_output=True, + ff_mult=1, + ) else: self.input_converter = ml.Linear(input_vec_dim, model_channels) self.code_converter = Encoder( - dim=model_channels, - depth=4, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ff_mult=1, - ) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels)) - self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1) + dim=model_channels, + depth=4, + heads=num_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + zero_init_branch_output=True, + ff_mult=1, + ) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, 1, model_channels)) + self.x_processor = conv_nd( + dims, in_channels, model_channels, 3, padding=1) if self.num_classes is not None: # nn.Embedding self.label_emb = ml.Embedding(num_classes, time_embed_dim) self.use_raw_y_as_embedding = use_raw_y_as_embedding - assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive. + # These are mutually-exclusive. 
+ assert not ((self.num_classes is not None) and use_raw_y_as_embedding) self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, model_channels*2, model_channels, 1, padding=0) + conv_nd(dims, model_channels*2, + model_channels, 1, padding=0) ) ] ) @@ -657,7 +663,8 @@ class UNetMusicModel(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(dims, model_channels, + out_channels, 3, padding=1)), ) def forward(self, x, timesteps, y, conditioning_free=False): @@ -666,27 +673,35 @@ class UNetMusicModel(nn.Module): if cm != 0: pc = (cm - x.shape[-1]) / x.shape[-1] x = F.pad(x, (0, cm - x.shape[-1])) - y = F.pad(y.permute(0,2,1), (0, int(pc * y.shape[-1]))).permute(0,2,1) + y = F.pad(y.permute(0, 2, 1), (0, int( + pc * y.shape[-1]))).permute(0, 2, 1) unused_params = [] hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(timestep_embedding( + timesteps, self.model_channels)) if conditioning_free: - expanded_code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1).permute(0,2,1) + expanded_code_emb = self.unconditioned_embedding.repeat( + x.shape[0], x.shape[-1], 1).permute(0, 2, 1) if self.ar_prior: - unused_params.extend(list(self.ar_input.parameters()) + list(self.ar_prior_intg.parameters())) + unused_params.extend( + list(self.ar_input.parameters()) + list(self.ar_prior_intg.parameters())) else: - unused_params.extend(list(self.input_converter.parameters()) + list(self.code_converter.parameters())) + unused_params.extend( + list(self.input_converter.parameters()) + list(self.code_converter.parameters())) else: - code_emb = self.ar_input(y) if self.ar_prior else self.input_converter(y) + code_emb = self.ar_input( + y) if self.ar_prior else self.input_converter(y) if self.training and self.unconditioned_percentage > 0: unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(y.shape[0], 1, 1), code_emb) - code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=x.shape[-1], mode='nearest') + code_emb = self.ar_prior_intg( + code_emb) if self.ar_prior else self.code_converter(code_emb) + expanded_code_emb = F.interpolate(code_emb.permute( + 0, 2, 1), size=x.shape[-1], mode='nearest') h = x.type(self.dtype) expanded_code_emb = expanded_code_emb.type(self.dtype) @@ -720,7 +735,8 @@ class UNetMusicModelWithQuantizer(nn.Module): self.internal_step = 0 self.freeze_quantizer_until = freeze_quantizer_until self.diff = UNetMusicModel(**kwargs) - self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=512, codebook_groups=2) + self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[ + 1024, 1024, 512], codevector_dim=1024, codebook_size=512, codebook_groups=2) self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature del self.m2v.up @@ -728,15 +744,16 @@ class UNetMusicModelWithQuantizer(nn.Module): self.internal_step = step qstep = max(0, self.internal_step - self.freeze_quantizer_until) self.m2v.quantizer.temperature = max( - self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**qstep, - self.m2v.min_gumbel_temperature, - ) + self.m2v.max_gumbel_temperature * 
self.m2v.gumbel_temperature_decay**qstep, + self.m2v.min_gumbel_temperature, + ) def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): quant_grad_enabled = self.internal_step > self.freeze_quantizer_until with torch.set_grad_enabled(quant_grad_enabled): - proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True) - proj = proj.permute(0,2,1) + proj, diversity_loss = self.m2v( + truth_mel, return_decoder_latent=True) + proj = proj.permute(0, 2, 1) # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. if not quant_grad_enabled: @@ -746,7 +763,8 @@ class UNetMusicModelWithQuantizer(nn.Module): proj = proj + unused diversity_loss = diversity_loss * 0 - diff = self.diff(x, timesteps, proj, conditioning_free=conditioning_free) + diff = self.diff(x, timesteps, proj, + conditioning_free=conditioning_free) if disable_diversity: return diff return diff, diversity_loss @@ -790,7 +808,8 @@ class UNetMusicModelARPrior(nn.Module): with torch.no_grad(): prior = self.ar(truth_mel, conditioning_input, return_latent=True) - diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free) + diff = self.diff(x, timesteps, prior, + conditioning_free=conditioning_free) return diff @@ -798,6 +817,7 @@ class UNetMusicModelARPrior(nn.Module): def register_unet_diffusion_music_codes(opt_net, opt): return UNetMusicModelWithQuantizer(**opt_net['args']) + @register_model def register_unet_diffusion_music_ar_prior(opt_net, opt): return UNetMusicModelARPrior(**opt_net['args']) @@ -808,20 +828,21 @@ if __name__ == '__main__': cond = torch.randn(2, 256, 300) ts = torch.LongTensor([600, 600]) model = UNetMusicModelARPrior(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=512, - attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1, + attention_resolutions=(2, 4), channel_mult=(1, 2, 3), dims=1, use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4, freeze_unet=True) print_network(model) model.get_grad_norm_parameter_groups() - ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') + ar_weights = torch.load( + 'D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') model.ar.load_state_dict(ar_weights, strict=True) - diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\55500_generator_ema.pth') + diff_weights = torch.load( + 'X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\55500_generator_ema.pth') pruned_diff_weights = {} - for k,v in diff_weights.items(): + for k, v in diff_weights.items(): if k.startswith('diff.'): pruned_diff_weights[k.replace('diff.', '')] = v model.diff.load_state_dict(pruned_diff_weights, strict=False) torch.save(model.state_dict(), 'sample.pth') model(clip, ts, cond, cond) - diff --git a/dlas/models/audio/music/unet_diffusion_waveform_gen.py b/dlas/models/audio/music/unet_diffusion_waveform_gen.py index f0ef775f..c0c2cbc1 100644 --- a/dlas/models/audio/music/unet_diffusion_waveform_gen.py +++ b/dlas/models/audio/music/unet_diffusion_waveform_gen.py @@ -6,14 +6,19 @@ import torch.nn.functional as F from torch import autocast from x_transformers import Encoder -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ - Downsample, Upsample, TimestepBlock -from 
models.audio.tts.mini_encoder import AudioMiniEncoder -from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder -from scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from utils.util import checkpoint +from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder +from dlas.models.audio.tts.unet_diffusion_tts7 import \ + CheckpointedXTransformerEncoder +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample, + TimestepBlock, + TimestepEmbedSequential, + Upsample) +from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint + def is_sequence(t): return t.dtype == torch.long @@ -44,7 +49,8 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding), + conv_nd(dims, channels, self.out_channels, + eff_kernel, padding=eff_padding), ) self.emb_layers = nn.Sequential( @@ -59,14 +65,16 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, eff_kernel, padding=eff_padding) def forward(self, x, emb): """ @@ -95,6 +103,7 @@ class ResBlock(TimestepBlock): h = self.out_layers(h) return self.skip_connection(x) + h + class DiffusionWaveformGen(nn.Module): """ The full UNet model with attention and timestep embedding. @@ -138,12 +147,12 @@ class DiffusionWaveformGen(nn.Module): out_channels=2, # mean and variance dropout=0, # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), + channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48), num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - token_conditioning_resolutions=(1,16,), - attention_resolutions=(512,1024,2048), + token_conditioning_resolutions=(1, 16,), + attention_resolutions=(512, 1024, 2048), conv_resample=True, dims=1, use_fp16=False, @@ -154,10 +163,12 @@ class DiffusionWaveformGen(nn.Module): scale_factor=2, time_embed_dim_multiplier=4, freeze_main_net=False, - efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3. + # Uses kernels with width of 1 in several places rather than 3. + efficient_convs=True, use_scale_shift_norm=True, # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, # Parameters for super-sampling. super_sampling=False, super_sampling_max_noising_factor=.1, @@ -168,7 +179,8 @@ class DiffusionWaveformGen(nn.Module): num_heads_upsample = num_heads if super_sampling: - in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input. 
+ # In super-sampling mode, the LR input is concatenated directly onto the input. + in_channels *= 2 self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels @@ -218,22 +230,31 @@ class DiffusionWaveformGen(nn.Module): rotary_pos_emb=True, ) )) - self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1) - self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1)) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1)) + self.latent_converter = nn.Conv1d( + in_latent_channels, conditioning_dim, 1) + self.aligned_latent_padding_embedding = nn.Parameter( + torch.randn(1, in_latent_channels, 1)) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, conditioning_dim, 1)) self.conditioning_timestep_integrator = TimestepEmbedSequential( - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), - AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels), - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), - AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels), - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + AttentionBlock(conditioning_dim, num_heads=num_heads, + num_head_channels=num_head_channels), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + AttentionBlock(conditioning_dim, num_heads=num_heads, + num_head_channels=num_head_channels), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), ) self.conditioning_expansion = conditioning_expansion self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) + conv_nd(dims, in_channels, model_channels, + kernel_size, padding=padding) ) ] ) @@ -344,7 +365,8 @@ class DiffusionWaveformGen(nn.Module): if level and i == num_blocks: out_ch = ch layers.append( - Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor) + Upsample(ch, conv_resample, dims=dims, + out_channels=out_ch, factor=scale_factor) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) @@ -353,7 +375,8 @@ class DiffusionWaveformGen(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), + zero_module(conv_nd(dims, model_channels, out_channels, + kernel_size, padding=padding)), ) if self.freeze_main_net: @@ -385,13 +408,14 @@ class DiffusionWaveformGen(nn.Module): cm = ceil_multiple(x.shape[-1], self.alignment_size) if cm != 0: pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) + x = F.pad(x, (0, cm-x.shape[-1])) # Also fix aligned_latent, which is aligned to x. 
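# A short worked example of the padding fix-up above (illustrative numbers,
# not from the source, and assuming ceil_multiple returns the next multiple of
# the alignment size): with x of length 1000 and an alignment size of 256,
# cm = 1024, so x is padded by 24 samples and pc = 24 / 1000 = 0.024. An
# aligned conditioning sequence of length 500 is then padded by
# int(0.024 * 500) = 12 frames in the lines that follow, keeping it
# proportionally aligned with the padded x.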
if self.is_latent(aligned_conditioning): aligned_conditioning = torch.cat([aligned_conditioning, self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1) else: - aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) + aligned_conditioning = F.pad( + aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1]))) return x, aligned_conditioning def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False): @@ -416,11 +440,13 @@ class DiffusionWaveformGen(nn.Module): with autocast(x.device.type, enabled=self.enable_fp16): hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + time_emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels)) # Note: this block does not need to repeated on inference, since it is not timestep-dependent. if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) + code_emb = self.unconditioned_embedding.repeat( + x.shape[0], 1, 1) else: if self.is_latent(aligned_conditioning): code_emb = self.latent_converter(aligned_conditioning) @@ -428,15 +454,18 @@ class DiffusionWaveformGen(nn.Module): code_emb = self.mel_converter(aligned_conditioning) # Everything after this comment is timestep dependent. - code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1) - code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + code_emb = torch.repeat_interleave( + code_emb, self.conditioning_expansion, dim=-1) + code_emb = self.conditioning_timestep_integrator( + code_emb, time_emb) first = True time_emb = time_emb.float() h = x for k, module in enumerate(self.input_blocks): if isinstance(module, nn.Conv1d): - h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') + h_tok = F.interpolate(module(code_emb), size=( + h.shape[-1]), mode='nearest') h = h + h_tok else: with autocast(x.device.type, enabled=self.enable_fp16 and not first): @@ -455,7 +484,8 @@ class DiffusionWaveformGen(nn.Module): # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. 
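# A minimal sketch of the DDP workaround referenced in the comment above (an
# illustration of the recurring pattern in these models, not code from the
# original file): by default DistributedDataParallel expects every registered
# parameter to receive a gradient on each step, so parameters skipped by a
# given forward path (e.g. the conditioning-free branch, or the latent
# converter when mel conditioning is used) are folded into the output with a
# zero coefficient. That keeps them in the autograd graph without changing
# the result.
import torch


def involve_unused_params(out: torch.Tensor, unused_params) -> torch.Tensor:
    extraneous = 0
    for p in unused_params:
        extraneous = extraneous + p.mean()
    return out + extraneous * 0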
extraneous_addition = 0 - params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters()) + params = [self.aligned_latent_padding_embedding, + self.unconditioned_embedding] + list(self.latent_converter.parameters()) for p in params: extraneous_addition = extraneous_addition + p.mean() out = out + extraneous_addition * 0 @@ -470,13 +500,13 @@ def register_unet_diffusion_waveform_gen(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 1, 32868) - aligned_latent = torch.randn(2,388,1024) - aligned_sequence = torch.randn(2,120,220) + aligned_latent = torch.randn(2, 388, 1024) + aligned_sequence = torch.randn(2, 120, 220) ts = torch.LongTensor([600, 600]) model = DiffusionWaveformGen(128, - channel_mult=[1,1.5,2, 3, 4, 6, 8], + channel_mult=[1, 1.5, 2, 3, 4, 6, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 1], - token_conditioning_resolutions=[1,4,16,64], + token_conditioning_resolutions=[1, 4, 16, 64], attention_resolutions=[], num_heads=8, kernel_size=3, @@ -488,4 +518,3 @@ if __name__ == '__main__': o = model(clip, ts, aligned_latent) # Test with sequence aligned conditioning o = model(clip, ts, aligned_sequence) - diff --git a/dlas/models/audio/music/unet_diffusion_waveform_gen3.py b/dlas/models/audio/music/unet_diffusion_waveform_gen3.py index e84ab75a..4230ef21 100644 --- a/dlas/models/audio/music/unet_diffusion_waveform_gen3.py +++ b/dlas/models/audio/music/unet_diffusion_waveform_gen3.py @@ -4,12 +4,14 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepEmbedSequential, \ - Downsample, Upsample, TimestepBlock -from scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from utils.util import checkpoint, print_network +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (Downsample, TimestepBlock, + TimestepEmbedSequential, + Upsample) +from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, print_network def is_sequence(t): @@ -368,4 +370,3 @@ if __name__ == '__main__': # Test with sequence aligned conditioning o = model(clip, ts, aligned_sequence) print_network(model) - diff --git a/dlas/models/audio/music/unet_diffusion_waveform_gen_simple.py b/dlas/models/audio/music/unet_diffusion_waveform_gen_simple.py index d1e0574c..94c07695 100644 --- a/dlas/models/audio/music/unet_diffusion_waveform_gen_simple.py +++ b/dlas/models/audio/music/unet_diffusion_waveform_gen_simple.py @@ -3,12 +3,14 @@ import torch.nn as nn import torch.nn.functional as F from torch import autocast -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepEmbedSequential, \ - Downsample, Upsample, TimestepBlock -from scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from utils.util import checkpoint +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (Downsample, TimestepBlock, + TimestepEmbedSequential, + Upsample) +from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple +from 
dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint def is_sequence(t): @@ -40,7 +42,8 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding), + conv_nd(dims, channels, self.out_channels, + eff_kernel, padding=eff_padding), ) self.emb_layers = nn.Sequential( @@ -55,14 +58,16 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, eff_kernel, padding=eff_padding) def forward(self, x, emb): """ @@ -91,6 +96,7 @@ class ResBlock(TimestepBlock): h = self.out_layers(h) return self.skip_connection(x) + h + class DiffusionWaveformGen(nn.Module): """ The full UNet model with residual blocks and timestep embedding. @@ -122,11 +128,11 @@ class DiffusionWaveformGen(nn.Module): out_channels=2, # mean and variance dropout=0, # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), + channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48), num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - token_conditioning_resolutions=(1,16,), + token_conditioning_resolutions=(1, 16,), conv_resample=True, dims=1, use_fp16=False, @@ -134,10 +140,12 @@ class DiffusionWaveformGen(nn.Module): scale_factor=2, time_embed_dim_multiplier=4, freeze_main_net=False, - efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3. + # Uses kernels with width of 1 in several places rather than 3. + efficient_convs=True, use_scale_shift_norm=True, # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, # Parameters for super-sampling. super_sampling=False, super_sampling_max_noising_factor=.1, @@ -145,7 +153,8 @@ class DiffusionWaveformGen(nn.Module): super().__init__() if super_sampling: - in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input. + # In super-sampling mode, the LR input is concatenated directly onto the input. + in_channels *= 2 self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels @@ -175,19 +184,25 @@ class DiffusionWaveformGen(nn.Module): # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. 
- self.mel_converter = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1)) + self.mel_converter = nn.Conv1d( + in_mel_channels, conditioning_dim, 3, padding=1) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, conditioning_dim, 1)) self.conditioning_timestep_integrator = TimestepEmbedSequential( - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), ) self.conditioning_expansion = conditioning_expansion self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) + conv_nd(dims, in_channels, model_channels, + kernel_size, padding=padding) ) ] ) @@ -268,7 +283,8 @@ class DiffusionWaveformGen(nn.Module): if level and i == num_blocks: out_ch = ch layers.append( - Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor) + Upsample(ch, conv_resample, dims=dims, + out_channels=out_ch, factor=scale_factor) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) @@ -277,7 +293,8 @@ class DiffusionWaveformGen(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), + zero_module(conv_nd(dims, model_channels, out_channels, + kernel_size, padding=padding)), ) if self.freeze_main_net: @@ -306,8 +323,9 @@ class DiffusionWaveformGen(nn.Module): cm = ceil_multiple(x.shape[-1], self.alignment_size) if cm != 0: pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) - aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) + x = F.pad(x, (0, cm-x.shape[-1])) + aligned_conditioning = F.pad( + aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1]))) return x, aligned_conditioning def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False): @@ -327,24 +345,29 @@ class DiffusionWaveformGen(nn.Module): with autocast(x.device.type, enabled=self.enable_fp16): hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + time_emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels)) # Note: this block does not need to repeated on inference, since it is not timestep-dependent. if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) + code_emb = self.unconditioned_embedding.repeat( + x.shape[0], 1, 1) else: code_emb = self.mel_converter(aligned_conditioning) # Everything after this comment is timestep dependent. 
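# What the two inference notes above suggest (a hypothetical sketch, not part
# of this file): the mel_converter output depends only on the conditioning, so
# a sampler could compute it once and reuse it at every diffusion timestep,
# while everything from repeat_interleave onward must be recomputed because it
# consumes the timestep embedding. Something along these lines:
class CachedConditioning:
    def __init__(self, model, aligned_conditioning):
        self.model = model
        self.aligned_conditioning = aligned_conditioning
        self._code_emb = None  # computed lazily, then reused across timesteps

    def get(self):
        if self._code_emb is None:
            self._code_emb = self.model.mel_converter(self.aligned_conditioning)
        return self._code_emb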
- code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1) - code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + code_emb = torch.repeat_interleave( + code_emb, self.conditioning_expansion, dim=-1) + code_emb = self.conditioning_timestep_integrator( + code_emb, time_emb) first = True time_emb = time_emb.float() h = x for k, module in enumerate(self.input_blocks): if isinstance(module, nn.Conv1d): - h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') + h_tok = F.interpolate(module(code_emb), size=( + h.shape[-1]), mode='nearest') h = h + h_tok else: with autocast(x.device.type, enabled=self.enable_fp16 and not first): @@ -378,12 +401,12 @@ def register_unet_diffusion_waveform_gen2(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 1, 32868) - aligned_sequence = torch.randn(2,120,220) + aligned_sequence = torch.randn(2, 120, 220) ts = torch.LongTensor([600, 600]) model = DiffusionWaveformGen(128, - channel_mult=[1,1.5,2, 3, 4, 6, 8], + channel_mult=[1, 1.5, 2, 3, 4, 6, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 1], - token_conditioning_resolutions=[1,4,16,64], + token_conditioning_resolutions=[1, 4, 16, 64], kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, @@ -391,4 +414,3 @@ if __name__ == '__main__': efficient_convs=False) # Test with sequence aligned conditioning o = model(clip, ts, aligned_sequence) - diff --git a/dlas/models/audio/tts/autoregressive_codegen.py b/dlas/models/audio/tts/autoregressive_codegen.py index ce1113b7..6a9e55e7 100644 --- a/dlas/models/audio/tts/autoregressive_codegen.py +++ b/dlas/models/audio/tts/autoregressive_codegen.py @@ -1,12 +1,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import GPT2PreTrainedModel, GPT2Config +from transformers import GPT2Config, GPT2PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from models.arch_util import AttentionBlock -from models.lucidrains.x_transformers import TransformerWrapper, Decoder, Encoder -from trainer.networks import register_model +from dlas.models.arch_util import AttentionBlock +from dlas.models.lucidrains.x_transformers import (Decoder, Encoder, + TransformerWrapper) +from dlas.trainer.networks import register_model class InferenceModel(GPT2PreTrainedModel): @@ -14,6 +15,7 @@ class InferenceModel(GPT2PreTrainedModel): Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with this transformer. """ + def __init__(self, model): super().__init__(GPT2Config()) self.transformer = model @@ -83,7 +85,8 @@ class InferenceModel(GPT2PreTrainedModel): ): assert self.context is not None assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. + # Training not supported by this inference model. + assert labels is None return_dict = return_dict if return_dict is not None else self.config.use_return_dict out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values, @@ -115,7 +118,8 @@ class InferenceModel(GPT2PreTrainedModel): called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
""" return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past) for layer_past in past ) @@ -124,6 +128,7 @@ class ResBlock(nn.Module): """ Basic residual convolutional block that uses GroupNorm. """ + def __init__(self, chan): super().__init__() self.net = nn.Sequential( @@ -148,11 +153,13 @@ class ConditioningEncoder(nn.Module): super().__init__() attn = [] self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2), - nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2), + nn.Conv1d(embedding_dim//4, embedding_dim // + 2, kernel_size=3, padding=1, stride=2), ResBlock(embedding_dim//2), nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2)) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) + attn.append(AttentionBlock(embedding_dim, + num_attn_heads, do_checkpoint=do_checkpointing)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim @@ -165,50 +172,53 @@ class ConditioningEncoder(nn.Module): class AutoregressiveCodegen(nn.Module): def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1): super().__init__() - assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later. + # This is the minimum bound to support the context interleaving that happens later. + assert depth >= 8 - self.START_TOKEN=8192 - self.STOP_TOKEN=8193 + self.START_TOKEN = 8192 + self.STOP_TOKEN = 8193 self.START_TEXT_TOKEN = 255 self.STOP_TEXT_TOKEN = 0 self.max_text_token_id = num_text_tokens self.max_mel_token_id = num_mel_tokens - self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False) + self.mel_embedding = ConditioningEncoder( + 80, model_dim, do_checkpointing=False) self.encoder = TransformerWrapper( - num_tokens=num_text_tokens, - use_pos_emb=False, - max_seq_len=-1, - attn_layers = Encoder( - depth=depth, - heads=model_dim//64, - dim=model_dim, - attn_dropout=dropout, - ff_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - ff_mult=1, - rotary_pos_emb=True, - attn_rel_pos_bias=True, - )) - self.encoder.norm = nn.Identity() # This layer and the next are unused. + num_tokens=num_text_tokens, + use_pos_emb=False, + max_seq_len=-1, + attn_layers=Encoder( + depth=depth, + heads=model_dim//64, + dim=model_dim, + attn_dropout=dropout, + ff_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + ff_mult=1, + rotary_pos_emb=True, + attn_rel_pos_bias=True, + )) + # This layer and the next are unused. 
+ self.encoder.norm = nn.Identity() self.encoder.to_logits = nn.Identity() self.decoder = TransformerWrapper( - num_tokens=num_mel_tokens, - use_pos_emb=False, - max_seq_len=-1, - attn_layers=Decoder( - depth=depth, - heads=model_dim//64, - dim=model_dim, - attn_dropout=dropout, - ff_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - ff_mult=1, - rotary_pos_emb=True, - cross_attend=True, - attn_rel_pos_bias=True, - )) + num_tokens=num_mel_tokens, + use_pos_emb=False, + max_seq_len=-1, + attn_layers=Decoder( + depth=depth, + heads=model_dim//64, + dim=model_dim, + attn_dropout=dropout, + ff_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + ff_mult=1, + rotary_pos_emb=True, + cross_attend=True, + attn_rel_pos_bias=True, + )) def get_grad_norm_parameter_groups(self): return { @@ -218,8 +228,10 @@ class AutoregressiveCodegen(nn.Module): } def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True): - assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}' - assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}' + assert text_codes.max() < self.max_text_token_id and text_codes.min( + ) >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}' + assert mel_codes.max() < self.max_mel_token_id and mel_codes.min( + ) >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}' # Format mel_codes with a stop token on the end. mel_lengths = wav_lengths // 1024 + 1 @@ -235,8 +247,8 @@ class AutoregressiveCodegen(nn.Module): cond_embs.append(self.mel_embedding(conditioning_signal[:, i])) cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True) # Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings. - text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN) - text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN) + text_codes = F.pad(text_codes, (1, 0), value=self.START_TEXT_TOKEN) + text_codes = F.pad(text_codes, (0, 1), value=self.STOP_TEXT_TOKEN) _, enc_text = self.encoder(text_codes, return_hiddens=True) # Interleave cond_emb into the first few contexts. 
full_context = enc_text @@ -245,11 +257,11 @@ class AutoregressiveCodegen(nn.Module): full_context[6] = cond_emb # Execute the decoder - dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1] + dec_inputs = F.pad(mel_codes, (1, 0), value=self.START_TOKEN)[:, :-1] dec = self.decoder(dec_inputs, full_context=full_context) if not return_loss: return dec - loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes) + loss_mel = F.cross_entropy(dec.permute(0, 2, 1), mel_codes) return loss_mel def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs): @@ -261,8 +273,8 @@ class AutoregressiveCodegen(nn.Module): for i in range(conditioning_signal.shape[1]): cond_embs.append(self.mel_embedding(conditioning_signal[:, i])) cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True) - text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN) - text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN) + text_codes = F.pad(text_codes, (1, 0), value=self.START_TEXT_TOKEN) + text_codes = F.pad(text_codes, (0, 1), value=self.STOP_TEXT_TOKEN) _, enc_text = self.encoder(text_codes, return_hiddens=True) # Interleave cond_emb into the first few contexts. full_context = enc_text @@ -273,7 +285,7 @@ class AutoregressiveCodegen(nn.Module): gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN, max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=True, - **hf_generate_kwargs) + **hf_generate_kwargs) return gen.sequences @@ -285,8 +297,8 @@ def register_autoregressive_codegen(opt_net, opt): if __name__ == '__main__': codegen = AutoregressiveCodegen(256, 10) torch.save(codegen.state_dict(), 'sample.pth') - #codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200))) - codegen(torch.randint(0,256, (2,200)), - torch.randn(2,80,120), - torch.randint(0,8192, (2,350)), - torch.tensor([192,350])) + # codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200))) + codegen(torch.randint(0, 256, (2, 200)), + torch.randn(2, 80, 120), + torch.randint(0, 8192, (2, 350)), + torch.tensor([192, 350])) diff --git a/dlas/models/audio/tts/autoregressive_codegen2.py b/dlas/models/audio/tts/autoregressive_codegen2.py index 2d0f8c11..dbc8cf56 100644 --- a/dlas/models/audio/tts/autoregressive_codegen2.py +++ b/dlas/models/audio/tts/autoregressive_codegen2.py @@ -1,12 +1,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import GPT2PreTrainedModel, GPT2Config +from transformers import GPT2Config, GPT2PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from models.arch_util import AttentionBlock -from models.lucidrains.x_transformers import TransformerWrapper, Encoder, Decoder -from trainer.networks import register_model +from dlas.models.arch_util import AttentionBlock +from dlas.models.lucidrains.x_transformers import (Decoder, Encoder, + TransformerWrapper) +from dlas.trainer.networks import register_model class InferenceModel(GPT2PreTrainedModel): @@ -14,6 +15,7 @@ class InferenceModel(GPT2PreTrainedModel): Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with this transformer. 
""" + def __init__(self, model): super().__init__(GPT2Config()) self.transformer = model @@ -83,10 +85,12 @@ class InferenceModel(GPT2PreTrainedModel): ): assert self.context is not None assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. + # Training not supported by this inference model. + assert labels is None return_dict = return_dict if return_dict is not None else self.config.use_return_dict - hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True) + hidden_states = self.transformer.decoder( + input_ids, context=self.context, return_embeddings=True) logits = self.transformer.decoder.to_logits(hidden_states) if not return_dict: @@ -109,7 +113,8 @@ class InferenceModel(GPT2PreTrainedModel): called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. """ return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past) for layer_past in past ) @@ -118,6 +123,7 @@ class ResBlock(nn.Module): """ Basic residual convolutional block that uses GroupNorm. """ + def __init__(self, chan): super().__init__() self.net = nn.Sequential( @@ -142,11 +148,13 @@ class ConditioningEncoder(nn.Module): super().__init__() attn = [] self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2), - nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2), + nn.Conv1d(embedding_dim//4, embedding_dim // + 2, kernel_size=3, padding=1, stride=2), ResBlock(embedding_dim//2), nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2)) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) + attn.append(AttentionBlock(embedding_dim, + num_attn_heads, do_checkpoint=do_checkpointing)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim @@ -160,45 +168,46 @@ class AutoregressiveCodegen(nn.Module): def __init__(self, model_dim, encoder_depth, decoder_depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1, ff_mult=1): super().__init__() - self.START_TOKEN=8192 - self.STOP_TOKEN=8193 + self.START_TOKEN = 8192 + self.STOP_TOKEN = 8193 self.max_text_token_id = num_text_tokens self.max_mel_token_id = num_mel_tokens - self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False) + self.mel_embedding = ConditioningEncoder( + 80, model_dim, do_checkpointing=False) self.encoder = TransformerWrapper( - num_tokens=num_text_tokens, - use_pos_emb=False, - max_seq_len=-1, - attn_layers = Encoder( - depth=encoder_depth, - heads=model_dim//64, - dim=model_dim, - attn_dropout=dropout, - ff_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - ff_mult=ff_mult, - rotary_pos_emb=True, - attn_rel_pos_bias=True, - )) + num_tokens=num_text_tokens, + use_pos_emb=False, + max_seq_len=-1, + attn_layers=Encoder( + depth=encoder_depth, + heads=model_dim//64, + dim=model_dim, + attn_dropout=dropout, + ff_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + ff_mult=ff_mult, + rotary_pos_emb=True, + attn_rel_pos_bias=True, + )) self.encoder.to_logits = nn.Identity() # This is unused. 
self.decoder = TransformerWrapper( - num_tokens=num_mel_tokens, - use_pos_emb=False, - max_seq_len=-1, - attn_layers=Decoder( - depth=decoder_depth, - heads=model_dim//64, - dim=model_dim, - attn_dropout=dropout, - ff_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - ff_mult=ff_mult, - rotary_pos_emb=True, - cross_attend=True, - attn_rel_pos_bias=True, - )) + num_tokens=num_mel_tokens, + use_pos_emb=False, + max_seq_len=-1, + attn_layers=Decoder( + depth=decoder_depth, + heads=model_dim//64, + dim=model_dim, + attn_dropout=dropout, + ff_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + ff_mult=ff_mult, + rotary_pos_emb=True, + cross_attend=True, + attn_rel_pos_bias=True, + )) def get_grad_norm_parameter_groups(self): return { @@ -208,8 +217,10 @@ class AutoregressiveCodegen(nn.Module): } def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True): - assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}' - assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}' + assert text_codes.max() < self.max_text_token_id and text_codes.min( + ) >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}' + assert mel_codes.max() < self.max_mel_token_id and mel_codes.min( + ) >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}' # Format mel_codes with a stop token on the end. mel_lengths = wav_lengths // 1024 + 1 @@ -228,11 +239,11 @@ class AutoregressiveCodegen(nn.Module): context = torch.cat([cond_emb, enc_text], dim=1) # Execute the decoder - dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1] + dec_inputs = F.pad(mel_codes, (1, 0), value=self.START_TOKEN)[:, :-1] dec = self.decoder(dec_inputs, context=context) if not return_loss: return dec - loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes) + loss_mel = F.cross_entropy(dec.permute(0, 2, 1), mel_codes) return loss_mel def generate(self, conditioning_signal, text_codes, max_tokens=1024, **hf_generate_kwargs): @@ -263,8 +274,9 @@ def register_autoregressive_codegen2(opt_net, opt): if __name__ == '__main__': codegen = AutoregressiveCodegen(512, 20) torch.save(codegen.state_dict(), 'sample.pth') - codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200))) - codegen(torch.randint(0,256, (2,200)), - torch.randn(2,80,120), - torch.randint(0,8192, (2,350)), - torch.tensor([192,350])) \ No newline at end of file + codegen.generate(torch.randn((1, 80, 120)), + torch.randint(0, 256, (1, 200))) + codegen(torch.randint(0, 256, (2, 200)), + torch.randn(2, 80, 120), + torch.randint(0, 8192, (2, 350)), + torch.tensor([192, 350])) diff --git a/dlas/models/audio/tts/ctc_code_generator.py b/dlas/models/audio/tts/ctc_code_generator.py index 801c3fa2..55631670 100644 --- a/dlas/models/audio/tts/ctc_code_generator.py +++ b/dlas/models/audio/tts/ctc_code_generator.py @@ -3,12 +3,12 @@ from random import random import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml -from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer -from models.lucidrains.x_transformers import Encoder -from trainer.networks import register_model -from utils.util import opt_get +import dlas.torch_intermediary as ml +from dlas.models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer +from dlas.models.lucidrains.x_transformers import Encoder 
+from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get class CheckpointedXTransformerEncoder(nn.Module): @@ -16,6 +16,7 @@ class CheckpointedXTransformerEncoder(nn.Module): Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid to channels-last that XTransformer expects. """ + def __init__(self, **xtransformer_kwargs): super().__init__() self.transformer = XTransformer(**xtransformer_kwargs) @@ -23,7 +24,8 @@ class CheckpointedXTransformerEncoder(nn.Module): for xform in [self.transformer.encoder, self.transformer.decoder.net]: for i in range(len(xform.attn_layers.layers)): n, b, r = xform.attn_layers.layers[i] - xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) + xform.attn_layers.layers[i] = nn.ModuleList( + [n, CheckpointedLayer(b), r]) def forward(self, *args, **kwargs): return self.transformer(*args, **kwargs) @@ -45,20 +47,21 @@ class CtcCodeGenerator(nn.Module): self.recursive_embedding = ml.Embedding(pred_codes, model_dim) self.mask_embedding = nn.Parameter(torch.randn(model_dim)) self.encoder = Encoder( - dim=model_dim, - depth=layers, - heads=model_dim//64, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) + dim=model_dim, + depth=layers, + heads=model_dim//64, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ) self.pred_head = ml.Linear(model_dim, pred_codes) self.confidence_head = ml.Linear(model_dim, 1) def inference(self, codes, pads, repeats): - position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device)) + position_h = self.position_embedding( + torch.arange(0, codes.shape[-1], device=codes.device)) codes_h = self.codes_embedding(codes) labels = pads + repeats * self.max_pad @@ -83,13 +86,15 @@ class CtcCodeGenerator(nn.Module): print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}") pads = torch.clip(pads, 0, self.max_pad) if repeats.max() > self.max_repeat: - print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}") + print( + f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}") repeats = torch.clip(repeats, 0, self.max_repeat) assert codes.max() < self.ctc_codes, codes.max() labels = pads + repeats * self.max_pad - position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device)) + position_h = self.position_embedding( + torch.arange(0, codes.shape[-1], device=codes.device)) codes_h = self.codes_embedding(codes) recursive_h = self.recursive_embedding(labels) @@ -101,12 +106,14 @@ class CtcCodeGenerator(nn.Module): h = self.encoder(position_h + codes_h + recursive_h) pred_logits = self.pred_head(h) - loss = F.cross_entropy(pred_logits.permute(0,2,1), labels, reduce=False) + loss = F.cross_entropy(pred_logits.permute( + 0, 2, 1), labels, reduce=False) confidences = self.confidence_head(h).squeeze(-1) confidences = F.softmax(confidences * mask, dim=-1) confidence_loss = loss * confidences - loss = loss / loss.shape[-1] # This balances the confidence_loss and loss. + # This balances the confidence_loss and loss. 
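# Rough intuition for the balancing comment above (an approximate reading, not
# from the original source): confidences is a softmax over the sequence, so
# its entries average about 1/T for a length-T sequence, which makes
# confidence_loss.mean() roughly loss.mean() / T. Dividing the per-position
# loss by loss.shape[-1] (i.e. T) puts the plain loss term on the same scale
# before both are returned as means.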
+ loss = loss / loss.shape[-1] return loss.mean(), confidence_loss.mean() @@ -118,8 +125,8 @@ def register_ctc_code_generator(opt_net, opt): if __name__ == '__main__': model = CtcCodeGenerator() - inps = torch.randint(0,36, (4, 300)) - pads = torch.randint(0,100, (4,300)) - repeats = torch.randint(0,20, (4,300)) + inps = torch.randint(0, 36, (4, 300)) + pads = torch.randint(0, 100, (4, 300)) + repeats = torch.randint(0, 20, (4, 300)) loss = model(inps, pads, repeats) - print(loss.shape) \ No newline at end of file + print(loss.shape) diff --git a/dlas/models/audio/tts/diffusion_encoder.py b/dlas/models/audio/tts/diffusion_encoder.py index 15c3a2f2..8ece1816 100644 --- a/dlas/models/audio/tts/diffusion_encoder.py +++ b/dlas/models/audio/tts/diffusion_encoder.py @@ -1,16 +1,22 @@ -import functools import math import random from functools import partial import torch import torch.nn as nn -import torch_intermediary as ml +from x_transformers.x_transformers import (DEFAULT_DIM_HEAD, + AlibiPositionalBias, Attention, + AttentionLayers, FeedForward, + FixedPositionalEmbedding, GRUGating, + LayerIntermediates, + LearnedAlibiPositionalBias, + RelativePositionBias, Residual, + RMSNorm, RotaryEmbedding, Scale, + ScaleNorm, ShiftTokens, cast_tuple, + default, equals, exists, + groupby_prefix_and_trim, not_equals) -from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \ - DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \ - exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \ - AttentionLayers, not_equals +import dlas.torch_intermediary as ml class TimeIntegrationBlock(nn.Module): @@ -36,40 +42,41 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): """ Modification of x-transformers.AttentionLayers that performs timestep embeddings and layerdrop. 
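A side note on CtcCodeGenerator above: pads and repeats are clipped to max_pad and max_repeat, then packed into a single class index with labels = pads + repeats * max_pad, which is what both forward() and inference() predict over. A minimal sketch of the packing and its inverse, assuming pads stays strictly below max_pad; the limit below is illustrative, not the module default:

import torch

max_pad = 50                                 # assumed limit for the example only
pads = torch.tensor([3, 0, 49])
repeats = torch.tensor([1, 4, 0])
labels = pads + repeats * max_pad            # same packing as CtcCodeGenerator
recovered_pads = labels % max_pad            # inverse, valid while pads < max_pad
recovered_repeats = labels // max_pad
assert torch.equal(recovered_pads, pads) and torch.equal(recovered_repeats, repeats)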
""" + def __init__( self, dim, timestep_dim, depth, - heads = 8, - causal = False, - cross_attend = False, - only_cross = False, - use_scalenorm = False, - use_rmsnorm = False, - use_rezero = False, - alibi_pos_bias = False, - alibi_num_heads = None, - alibi_learned = False, - rel_pos_bias = False, - rel_pos_num_buckets = 32, - rel_pos_max_distance = 128, - position_infused_attn = False, - rotary_pos_emb = False, - rotary_emb_dim = None, - custom_layers = None, - sandwich_coef = None, - par_ratio = None, - residual_attn = False, - cross_residual_attn = False, - macaron = False, - gate_residual = False, - scale_residual = False, - shift_tokens = 0, - use_qk_norm_attn = False, - qk_norm_attn_seq_len = None, - zero_init_branch_output = False, - layerdrop_percent = .1, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + alibi_pos_bias=False, + alibi_num_heads=None, + alibi_learned=False, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + gate_residual=False, + scale_residual=False, + shift_tokens=0, + use_qk_norm_attn=False, + qk_norm_attn_seq_len=None, + zero_init_branch_output=False, + layerdrop_percent=.1, **kwargs ): super().__init__(dim, depth) @@ -84,21 +91,26 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): self.layerdrop_percent = layerdrop_percent self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb - self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.pia_pos_emb = FixedPositionalEmbedding( + dim) if position_infused_attn else None rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) - self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None + self.rotary_pos_emb = RotaryEmbedding( + rotary_emb_dim) if rotary_pos_emb else None - assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' + assert not ( + alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' if rel_pos_bias: - self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance) + self.rel_pos = RelativePositionBias( + scale=dim_head ** 0.5, causal=causal, heads=heads, num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance) elif alibi_pos_bias: alibi_num_heads = default(alibi_num_heads, heads) assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias - self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, bidirectional = not causal) + self.rel_pos = alibi_pos_klass( + heads=alibi_num_heads, bidirectional=not causal) else: self.rel_pos = None @@ -125,8 +137,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): # qk normalization if use_qk_norm_attn: - attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if 
exists(qk_norm_attn_seq_len) else None - attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value} + attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** + 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None + attn_kwargs = {**attn_kwargs, 'qk_norm': True, + 'scale_init_value': attn_scale_init_value} # zero init if zero_init_branch_output: @@ -140,16 +154,20 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): par_depth = depth * len(default_block) assert 1 < par_ratio <= par_depth, 'par ratio out of range' default_block = tuple(filter(not_equals('f'), default_block)) - par_attn = par_depth // par_ratio - depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_attn = par_depth // par_ratio + # 2 / 3 attention layer cutoff suggested by PAR paper + depth_cut = par_depth * 2 // 3 par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * (par_width - len(default_block)) + assert len( + default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + \ + ('f',) * (par_width - len(default_block)) par_head = par_block * par_attn layer_types = par_head + ('f',) * (par_depth - len(par_head)) elif exists(sandwich_coef): assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + layer_types = ('a',) * sandwich_coef + default_block * \ + (depth - sandwich_coef) + ('f',) * sandwich_coef else: layer_types = default_block * depth @@ -163,9 +181,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): # iterate and construct layers for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): if layer_type == 'a': - layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs) + layer = Attention(dim, heads=heads, + causal=causal, **attn_kwargs) elif layer_type == 'c': - layer = Attention(dim, heads = heads, **attn_kwargs) + layer = Attention(dim, heads=heads, **attn_kwargs) elif layer_type == 'f': layer = FeedForward(dim, **ff_kwargs) layer = layer if not macaron else Scale(0.5, layer) @@ -175,19 +194,22 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): if layer_shift_tokens > 0: shift_range_upper = layer_shift_tokens + 1 shift_range_lower = -layer_shift_tokens if not causal else 0 - layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) + layer = ShiftTokens( + range(shift_range_lower, shift_range_upper), layer) if exists(branch_fn): layer = branch_fn(layer) residual_fn = GRUGating if gate_residual else Residual - residual = residual_fn(dim, scale_residual = scale_residual) + residual = residual_fn(dim, scale_residual=scale_residual) layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c') - pre_branch_norm = TimeIntegrationBlock(timestep_dim, dim, norm_fn()) + pre_branch_norm = TimeIntegrationBlock( + timestep_dim, dim, norm_fn()) post_branch_norm = norm_fn() if layer_uses_qk_norm else None - post_main_norm = None # Always do prenorm for timestep integration. + # Always do prenorm for timestep integration. 
+ post_main_norm = None norms = nn.ModuleList([ pre_branch_norm, @@ -204,15 +226,16 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): def forward( self, x, - time_emb = None, - context = None, - mask = None, - context_mask = None, - attn_mask = None, - mems = None, - return_hiddens = False + time_emb=None, + context=None, + mask=None, + context_mask=None, + attn_mask=None, + mems=None, + return_hiddens=False ): - assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True' + assert not (self.cross_attend ^ exists( + context)), 'context must be passed in if cross_attend is set to True' assert time_emb is not None, 'must specify a timestep embedding.' hiddens = [] @@ -224,8 +247,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): rotary_pos_emb = None if exists(self.rotary_pos_emb): - max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems))) - rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) + max_rotary_emb_length = max( + list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems))) + rotary_pos_emb = self.rotary_pos_emb( + max_rotary_emb_length, x.device) unused_params = [] to_drop = 0 @@ -253,9 +278,11 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): x = pre_branch_norm(x, time_emb) if layer_type == 'a': - out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem) + out, inter = block(x, mask=mask, attn_mask=attn_mask, sinusoidal_emb=self.pia_pos_emb, + rel_pos=self.rel_pos, rotary_pos_emb=rotary_pos_emb, prev_attn=prev_attn, mem=layer_mem) elif layer_type == 'c': - out, inter = block(x, context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn) + out, inter = block( + x, context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) elif layer_type == 'f': out = block(x) @@ -283,10 +310,10 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers): if return_hiddens: intermediates = LayerIntermediates( - hiddens = hiddens, - attn_intermediates = intermediates + hiddens=hiddens, + attn_intermediates=intermediates ) return x, intermediates - return x \ No newline at end of file + return x diff --git a/dlas/models/audio/tts/lucidrains_dvae.py b/dlas/models/audio/tts/lucidrains_dvae.py index 6fc23480..224ab6bb 100644 --- a/dlas/models/audio/tts/lucidrains_dvae.py +++ b/dlas/models/audio/tts/lucidrains_dvae.py @@ -1,17 +1,15 @@ import functools -import math from math import sqrt import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from torch import einsum from vector_quantize_pytorch import VectorQuantize -from models.vqvae.vqvae import Quantize -from trainer.networks import register_model -from utils.util import opt_get +from dlas.models.vqvae.vqvae import Quantize +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def default(val, d): @@ -32,9 +30,9 @@ class ResBlock(nn.Module): def __init__(self, chan, conv, activation): super().__init__() self.net = nn.Sequential( - conv(chan, chan, 3, padding = 1), + conv(chan, chan, 3, padding=1), activation(), - conv(chan, chan, 3, padding = 1), + conv(chan, chan, 3, padding=1), activation(), conv(chan, chan, 1) ) @@ -52,7 +50,8 @@ class UpsampledConv(nn.Module): self.conv = conv(*args, **kwargs) def forward(self, x): - up = nn.functional.interpolate(x, 
scale_factor=self.stride, mode='nearest') + up = nn.functional.interpolate( + x, scale_factor=self.stride, mode='nearest') return self.conv(up) @@ -60,23 +59,23 @@ class DiscreteVAE(nn.Module): def __init__( self, positional_dims=2, - num_tokens = 512, - codebook_dim = 512, - num_layers = 3, - num_resnet_blocks = 0, - hidden_dim = 64, - channels = 3, - stride = 2, - kernel_size = 4, - use_transposed_convs = True, - encoder_norm = False, - activation = 'relu', - smooth_l1_loss = False, - straight_through = False, - normalization = None, # ((0.5,) * 3, (0.5,) * 3), - record_codes = False, - use_lr_quantizer = False, - lr_quantizer_args = {}, + num_tokens=512, + codebook_dim=512, + num_layers=3, + num_resnet_blocks=0, + hidden_dim=64, + channels=3, + stride=2, + kernel_size=4, + use_transposed_convs=True, + encoder_norm=False, + activation='relu', + smooth_l1_loss=False, + straight_through=False, + normalization=None, # ((0.5,) * 3, (0.5,) * 3), + record_codes=False, + use_lr_quantizer=False, + lr_quantizer_args={}, ): super().__init__() has_resblocks = num_resnet_blocks > 0 @@ -86,7 +85,8 @@ class DiscreteVAE(nn.Module): self.straight_through = straight_through self.positional_dims = positional_dims - assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now. + # This VAE only supports 1d and 2d inputs for now. + assert positional_dims > 0 and positional_dims < 3 if positional_dims == 2: conv = nn.Conv2d conv_transpose = nn.ConvTranspose2d @@ -103,7 +103,6 @@ class DiscreteVAE(nn.Module): else: assert NotImplementedError() - enc_layers = [] dec_layers = [] @@ -116,18 +115,22 @@ class DiscreteVAE(nn.Module): dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] dec_chans = [dec_init_chan, *dec_chans] - enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) + enc_chans_io, dec_chans_io = map(lambda t: list( + zip(t[:-1], t[1:])), (enc_chans, dec_chans)) pad = (kernel_size - 1) // 2 for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): - enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act())) + enc_layers.append(nn.Sequential( + conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act())) if encoder_norm: enc_layers.append(nn.GroupNorm(8, enc_out)) - dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act())) + dec_layers.append(nn.Sequential(conv_transpose( + dec_in, dec_out, kernel_size, stride=stride, padding=pad), act())) dec_out_chans = dec_chans[-1] innermost_dim = dec_chans[0] else: - enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act())) + enc_layers.append(nn.Sequential( + conv(channels, hidden_dim, 1), act())) dec_out_chans = hidden_dim innermost_dim = hidden_dim @@ -138,7 +141,6 @@ class DiscreteVAE(nn.Module): if num_resnet_blocks > 0: dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1)) - enc_layers.append(conv(innermost_dim, codebook_dim, 1)) dec_layers.append(conv(dec_out_chans, channels, 1)) @@ -148,9 +150,11 @@ class DiscreteVAE(nn.Module): self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss if use_lr_quantizer: - self.codebook = VectorQuantize(dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args) + self.codebook = VectorQuantize( + dim=codebook_dim, codebook_size=num_tokens, **lr_quantizer_args) else: - self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True) + self.codebook = 
Quantize( + codebook_dim, num_tokens, new_return_order=True) # take care of normalization within class self.normalization = normalization @@ -165,7 +169,8 @@ class DiscreteVAE(nn.Module): if not self.normalization is not None: return images - means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization) + means, stds = map(lambda t: torch.as_tensor( + t).to(images), self.normalization) arrange = 'c -> () c () ()' if self.positional_dims == 2 else 'c -> () c ()' means, stds = map(lambda t: rearrange(t, arrange), (means, stds)) images = images.clone() @@ -183,7 +188,8 @@ class DiscreteVAE(nn.Module): @eval_decorator def get_codebook_indices(self, images): img = self.norm(images) - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + logits = self.encoder(img).permute( + (0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) sampled, codes, _ = self.codebook(logits) self.log_codes(codes) return codes @@ -214,7 +220,8 @@ class DiscreteVAE(nn.Module): def infer(self, img): img = self.norm(img) - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + logits = self.encoder(img).permute( + (0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) sampled, codes, commitment_loss = self.codebook(logits) return self.decode(codes) @@ -226,9 +233,11 @@ class DiscreteVAE(nn.Module): img ): img = self.norm(img) - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + logits = self.encoder(img).permute( + (0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) sampled, codes, commitment_loss = self.codebook(logits) - sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1)) + sampled = sampled.permute( + (0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1)) if self.training: out = sampled @@ -249,7 +258,8 @@ class DiscreteVAE(nn.Module): if self.record_codes and self.internal_step % 10 == 0: codes = codes.flatten() l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + i = self.code_ind if ( + self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l self.codes[i:i+l] = codes.cpu() self.code_ind = self.code_ind + l if self.code_ind >= self.codes.shape[0]: @@ -264,14 +274,14 @@ def register_lucidrains_dvae(opt_net, opt): if __name__ == '__main__': - #v = DiscreteVAE() - #o=v(torch.randn(1,3,256,256)) - #print(o.shape) + # v = DiscreteVAE() + # o=v(torch.randn(1,3,256,256)) + # print(o.shape) v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048, hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False, use_lr_quantizer=True) - #v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) - #v.eval() - r,l,o=v(torch.randn(1,80,256)) - v.decode(torch.randint(0,8192,(1,256))) + # v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) + # v.eval() + r, l, o = v(torch.randn(1, 80, 256)) + v.decode(torch.randint(0, 8192, (1, 256))) print(o.shape, l.shape) diff --git a/dlas/models/audio/tts/mini_encoder.py b/dlas/models/audio/tts/mini_encoder.py index 4a61199b..681ecc79 100644 --- a/dlas/models/audio/tts/mini_encoder.py +++ b/dlas/models/audio/tts/mini_encoder.py @@ -1,14 +1,14 @@ import torch import torch.nn as nn -import torch_intermediary as ml - - -from models.diffusion.nn import normalization, conv_nd, zero_module -from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy, 
Upsample +import dlas.torch_intermediary as ml +from dlas.models.diffusion.nn import conv_nd, normalization, zero_module +from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample, + QKVAttention, + QKVAttentionLegacy, Upsample) # Combined resnet & full-attention encoder for converting an audio clip into an embedding. -from trainer.networks import register_model -from utils.util import checkpoint, opt_get, sequential_checkpoint +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, opt_get, sequential_checkpoint class ResBlock(nn.Module): @@ -37,7 +37,8 @@ class ResBlock(nn.Module): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + conv_nd(dims, channels, self.out_channels, + kernel_size, padding=padding), ) self.updown = up or down @@ -56,7 +57,8 @@ class ResBlock(nn.Module): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) @@ -67,7 +69,8 @@ class ResBlock(nn.Module): dims, channels, self.out_channels, kernel_size, padding=padding ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 1) def forward(self, x): if self.do_checkpoint: @@ -111,8 +114,10 @@ class AudioMiniEncoder(nn.Module): self.layers = depth for l in range(depth): for r in range(resnet_blocks): - res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size)) - res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor)) + res.append(ResBlock(ch, dropout, dims=1, + do_checkpoint=False, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, dims=1, + out_channels=ch*2, factor=downsample_factor)) ch *= 2 self.res = nn.Sequential(*res) self.final = nn.Sequential( @@ -122,7 +127,8 @@ class AudioMiniEncoder(nn.Module): ) attn = [] for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)) + attn.append(AttentionBlock(embedding_dim, + num_attn_heads, do_checkpoint=False)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim @@ -150,10 +156,12 @@ class AudioMiniEncoderWithClassifierHead(nn.Module): return logits else: if self.distribute_zero_label: - oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes) + oh_labels = nn.functional.one_hot( + labels, num_classes=self.num_classes) zeros_indices = (labels == 0).unsqueeze(-1) # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise. - zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1)) + zero_extra_mass = torch.full_like( + oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1)) zero_extra_mass[:, 0] = -.2 zero_extra_mass = zero_extra_mass * zeros_indices oh_labels = oh_labels + zero_extra_mass @@ -167,6 +175,7 @@ class QueryProvidedAttentionBlock(nn.Module): """ An attention block that provides a separate signal for the query vs the keys/parameters. 
""" + def __init__( self, channels, @@ -200,12 +209,13 @@ class QueryProvidedAttentionBlock(nn.Module): return checkpoint(self._forward, qx, kvx, mask) def _forward(self, qx, kvx, mask=None): - q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(1, kvx.shape[1], 1).permute(0,2,1) - kv = self.kv(self.norm(kvx.permute(0,2,1))) + q = self.q(self.qnorm(qx)).unsqueeze(1).repeat( + 1, kvx.shape[1], 1).permute(0, 2, 1) + kv = self.kv(self.norm(kvx.permute(0, 2, 1))) qkv = torch.cat([q, kv], dim=1) h = self.attention(qkv, mask) h = self.proj_out(h) - return kvx + h.permute(0,2,1) + return kvx + h.permute(0, 2, 1) # Next up: combine multiple embeddings given a conditioning signal into a single embedding. @@ -213,7 +223,8 @@ class EmbeddingCombiner(nn.Module): def __init__(self, embedding_dim, attn_blocks=3, num_attn_heads=2, cond_provided=True): super().__init__() block = QueryProvidedAttentionBlock if cond_provided else AttentionBlock - self.attn = nn.ModuleList([block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)]) + self.attn = nn.ModuleList( + [block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)]) self.cond_provided = cond_provided # x_s: (b,n,d); b=batch_sz, n=number of embeddings, d=embedding_dim diff --git a/dlas/models/audio/tts/random_latent_converter.py b/dlas/models/audio/tts/random_latent_converter.py index 9f14ebb3..d309778f 100644 --- a/dlas/models/audio/tts/random_latent_converter.py +++ b/dlas/models/audio/tts/random_latent_converter.py @@ -3,10 +3,10 @@ import math import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml -from trainer.networks import register_model -from utils.util import opt_get +import dlas.torch_intermediary as ml +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): @@ -61,4 +61,4 @@ def register_random_latent_converter(opt_net, opt): if __name__ == '__main__': model = RandomLatentConverter(512) - model(torch.randn(5,512)) \ No newline at end of file + model(torch.randn(5, 512)) diff --git a/dlas/models/audio/tts/tacotron2/__init__.py b/dlas/models/audio/tts/tacotron2/__init__.py index feca08a9..ce46f0cb 100644 --- a/dlas/models/audio/tts/tacotron2/__init__.py +++ b/dlas/models/audio/tts/tacotron2/__init__.py @@ -1,6 +1,6 @@ -from models.audio.tts.tacotron2.taco_utils import * -from models.audio.tts.tacotron2.text import * -from models.audio.tts.tacotron2.tacotron2 import * -from models.audio.tts.tacotron2.stft import * -from models.audio.tts.tacotron2.layers import * -from models.audio.tts.tacotron2.loss import * \ No newline at end of file +from dlas.models.audio.tts.tacotron2.layers import * +from dlas.models.audio.tts.tacotron2.loss import * +from dlas.models.audio.tts.tacotron2.stft import * +from dlas.models.audio.tts.tacotron2.taco_utils import * +from dlas.models.audio.tts.tacotron2.tacotron2 import * +from dlas.models.audio.tts.tacotron2.text import * diff --git a/dlas/models/audio/tts/tacotron2/audio_processing.py b/dlas/models/audio/tts/tacotron2/audio_processing.py index 1d448bae..185118bb 100644 --- a/dlas/models/audio/tts/tacotron2/audio_processing.py +++ b/dlas/models/audio/tts/tacotron2/audio_processing.py @@ -1,7 +1,7 @@ -import torch -import numpy as np -from scipy.signal import get_window import librosa.util as librosa_util +import numpy as np +import torch +from scipy.signal import get_window def window_sumsquare(window, n_frames, hop_length=200, win_length=800, @@ -52,7 
+52,8 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800, # Fill the envelope for i in range(n_frames): sample = i * hop_length - x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + x[sample:min(n, sample + n_fft) + ] += win_sq[:max(0, min(n_fft, n - sample))] return x @@ -90,4 +91,4 @@ def dynamic_range_decompression(x, C=1): ------ C: compression factor used to compress """ - return torch.exp(x) / C \ No newline at end of file + return torch.exp(x) / C diff --git a/dlas/models/audio/tts/tacotron2/hparams.py b/dlas/models/audio/tts/tacotron2/hparams.py index b25df857..4d2e6ee8 100644 --- a/dlas/models/audio/tts/tacotron2/hparams.py +++ b/dlas/models/audio/tts/tacotron2/hparams.py @@ -1,5 +1,5 @@ -#import tensorflow as tf -from models.audio.tts.tacotron2.text import symbols +# import tensorflow as tf +from dlas.models.audio.tts.tacotron2.text import symbols def create_hparams(hparams_string=None, verbose=False): @@ -33,7 +33,8 @@ def create_hparams(hparams_string=None, verbose=False): # Audio Parameters # ################################ max_wav_value=32768.0, - input_sample_rate=22050, # When different from sampling_rate, dataset automatically interpolates to sampling_rate + # When different from sampling_rate, dataset automatically interpolates to sampling_rate + input_sample_rate=22050, sampling_rate=22050, filter_length=1024, hop_length=256, # This means a MEL is 1/256th the equivalent audio. @@ -86,4 +87,4 @@ def create_hparams(hparams_string=None, verbose=False): mask_padding=True # set model's padded outputs to padded values ) - return hparams \ No newline at end of file + return hparams diff --git a/dlas/models/audio/tts/tacotron2/layers.py b/dlas/models/audio/tts/tacotron2/layers.py index 11022c02..39353a6a 100644 --- a/dlas/models/audio/tts/tacotron2/layers.py +++ b/dlas/models/audio/tts/tacotron2/layers.py @@ -1,9 +1,10 @@ import torch from librosa.filters import mel as librosa_mel_fn -from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression -from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression -from models.audio.tts.tacotron2.stft import STFT -import torch_intermediary as ml + +import dlas.torch_intermediary as ml +from dlas.models.audio.tts.tacotron2.audio_processing import ( + dynamic_range_compression, dynamic_range_decompression) +from dlas.models.audio.tts.tacotron2.stft import STFT class LinearNorm(torch.nn.Module): @@ -24,7 +25,7 @@ class ConvNorm(torch.nn.Module): padding=None, dilation=1, bias=True, w_init_gain='linear'): super(ConvNorm, self).__init__() if padding is None: - assert(kernel_size % 2 == 1) + assert (kernel_size % 2 == 1) padding = int(dilation * (kernel_size - 1) / 2) self.conv = torch.nn.Conv1d(in_channels, out_channels, @@ -71,8 +72,8 @@ class TacotronSTFT(torch.nn.Module): ------- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) """ - assert(torch.min(y.data) >= -10) - assert(torch.max(y.data) <= 10) + assert (torch.min(y.data) >= -10) + assert (torch.max(y.data) <= 10) y = torch.clip(y, min=-1, max=1) magnitudes, phases = self.stft_fn.transform(y) diff --git a/dlas/models/audio/tts/tacotron2/loss.py b/dlas/models/audio/tts/tacotron2/loss.py index 77d6457b..b5155b37 100644 --- a/dlas/models/audio/tts/tacotron2/loss.py +++ b/dlas/models/audio/tts/tacotron2/loss.py @@ -1,6 +1,6 @@ from torch import nn -from trainer.losses import ConfigurableLoss +from dlas.trainer.losses import ConfigurableLoss class Tacotron2Loss(ConfigurableLoss): @@ 
-20,7 +20,8 @@ class Tacotron2Loss(ConfigurableLoss): gate_target.requires_grad = False gate_target = gate_target.view(-1, 1) - mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[self.mel_output_postnet_key], state[self.gate_output_key] + mel_out, mel_out_postnet, gate_out = state[self.mel_output_key], state[ + self.mel_output_postnet_key], state[self.gate_output_key] gate_out = gate_out.view(-1, 1) mel_loss = nn.MSELoss()(mel_out, mel_target) + \ nn.MSELoss()(mel_out_postnet, mel_target) @@ -61,4 +62,4 @@ class Tacotron2LossRaw(nn.Module): return { 'mel_loss': self.last_mel_loss, 'gate_loss': self.last_gate_loss - } \ No newline at end of file + } diff --git a/dlas/models/audio/tts/tacotron2/stft.py b/dlas/models/audio/tts/tacotron2/stft.py index f2c3186a..ccc23b31 100644 --- a/dlas/models/audio/tts/tacotron2/stft.py +++ b/dlas/models/audio/tts/tacotron2/stft.py @@ -30,17 +30,19 @@ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ -import torch import numpy as np +import torch import torch.nn.functional as F -from torch.autograd import Variable -from scipy.signal import get_window from librosa.util import pad_center, tiny -from models.audio.tts.tacotron2.audio_processing import window_sumsquare +from scipy.signal import get_window +from torch.autograd import Variable + +from dlas.models.audio.tts.tacotron2.audio_processing import window_sumsquare class STFT(torch.nn.Module): """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'): super(STFT, self).__init__() @@ -61,7 +63,7 @@ class STFT(torch.nn.Module): np.linalg.pinv(scale * fourier_basis).T[:, None, :]) if window is not None: - assert(filter_length >= win_length) + assert (filter_length >= win_length) # get window and zero center pad it to filter_length fft_window = get_window(window, win_length, fftbins=True) fft_window = pad_center(fft_window, filter_length) @@ -125,17 +127,19 @@ class STFT(torch.nn.Module): window_sum = torch.autograd.Variable( torch.from_numpy(window_sum), requires_grad=False) window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum - inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + inverse_transform[:, :, + approx_nonzero_indices] /= window_sum[approx_nonzero_indices] # scale by hop ratio inverse_transform *= float(self.filter_length) / self.hop_length inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] - inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] + inverse_transform = inverse_transform[:, + :, :-int(self.filter_length/2):] return inverse_transform def forward(self, input_data): self.magnitude, self.phase = self.transform(input_data) reconstruction = self.inverse(self.magnitude, self.phase) - return reconstruction \ No newline at end of file + return reconstruction diff --git a/dlas/models/audio/tts/tacotron2/taco_utils.py b/dlas/models/audio/tts/tacotron2/taco_utils.py index 13c63fa5..84e01018 100644 --- a/dlas/models/audio/tts/tacotron2/taco_utils.py +++ b/dlas/models/audio/tts/tacotron2/taco_utils.py @@ -8,7 +8,8 @@ from scipy.io.wavfile import read def get_mask_from_lengths(lengths, max_len=None): if max_len is None: max_len = torch.max(lengths).item() - ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).to(lengths.device) + ids = torch.arange(0, max_len, out=torch.LongTensor( + 
max_len)).to(lengths.device) mask = (ids < lengths.unsqueeze(1)).bool() return mask @@ -22,24 +23,29 @@ def load_wav_to_torch(full_path): elif data.dtype == np.float16 or data.dtype == np.float32: norm_fix = 1. else: - raise NotImplemented(f"Provided data dtype not supported: {data.dtype}") + raise NotImplemented( + f"Provided data dtype not supported: {data.dtype}") return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate) def load_filepaths_and_text_type(filename, type, split="|"): with open(filename, encoding='utf-8') as f: - filepaths_and_text = [list(line.strip().split(split)) + [type] for line in f] + filepaths_and_text = [ + list(line.strip().split(split)) + [type] for line in f] base = os.path.dirname(filename) for j in range(len(filepaths_and_text)): - filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0]) + filepaths_and_text[j][0] = os.path.join( + base, filepaths_and_text[j][0]) return filepaths_and_text + def load_filepaths_and_text(filename, split="|"): with open(filename, encoding='utf-8') as f: filepaths_and_text = [line.strip().split(split) for line in f] base = os.path.dirname(filename) for j in range(len(filepaths_and_text)): - filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0]) + filepaths_and_text[j][0] = os.path.join( + base, filepaths_and_text[j][0]) return filepaths_and_text @@ -48,4 +54,4 @@ def to_gpu(x): if torch.cuda.is_available(): x = x.cuda(non_blocking=True) - return torch.autograd.Variable(x) \ No newline at end of file + return torch.autograd.Variable(x) diff --git a/dlas/models/audio/tts/tacotron2/tacotron2.py b/dlas/models/audio/tts/tacotron2/tacotron2.py index f8f64cb7..c9ae6170 100644 --- a/dlas/models/audio/tts/tacotron2/tacotron2.py +++ b/dlas/models/audio/tts/tacotron2/tacotron2.py @@ -1,14 +1,16 @@ from math import sqrt + import torch from munch import munchify -from torch.autograd import Variable from torch import nn +from torch.autograd import Variable from torch.nn import functional as F -from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm -from models.audio.tts.tacotron2.hparams import create_hparams -from trainer.networks import register_model -from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths -import torch_intermediary as ml + +import dlas.torch_intermediary as ml +from dlas.models.audio.tts.tacotron2.hparams import create_hparams +from dlas.models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm +from dlas.models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths +from dlas.trainer.networks import register_model class LocationLayer(nn.Module): @@ -59,7 +61,8 @@ class Attention(nn.Module): """ processed_query = self.query_layer(query.unsqueeze(1)) - processed_attention_weights = self.location_layer(attention_weights_cat) + processed_attention_weights = self.location_layer( + attention_weights_cat) energies = self.v(torch.tanh( processed_query + processed_attention_weights + processed_memory)) @@ -128,7 +131,8 @@ class Postnet(nn.Module): ConvNorm(hparams.postnet_embedding_dim, hparams.postnet_embedding_dim, kernel_size=hparams.postnet_kernel_size, stride=1, - padding=int((hparams.postnet_kernel_size - 1) / 2), + padding=int( + (hparams.postnet_kernel_size - 1) / 2), dilation=1, w_init_gain='tanh'), nn.BatchNorm1d(hparams.postnet_embedding_dim)) ) @@ -140,11 +144,12 @@ class Postnet(nn.Module): padding=int((hparams.postnet_kernel_size - 1) / 2), dilation=1, w_init_gain='linear'), nn.BatchNorm1d(hparams.n_mel_channels)) - ) + ) def 
forward(self, x): for i in range(len(self.convolutions) - 1): - x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) + x = F.dropout(torch.tanh( + self.convolutions[i](x)), 0.5, self.training) x = F.dropout(self.convolutions[-1](x), 0.5, self.training) return x @@ -155,6 +160,7 @@ class Encoder(nn.Module): - Three 1-d convolution banks - Bidirectional LSTM """ + def __init__(self, hparams): super(Encoder, self).__init__() @@ -408,7 +414,8 @@ class Decoder(nn.Module): mel_outputs, gate_outputs, alignments = [], [], [] while len(mel_outputs) < decoder_inputs.size(0) - 1: decoder_input = decoder_inputs[len(mel_outputs)] - mel_output, gate_output, attention_weights = self.decode(decoder_input) + mel_output, gate_output, attention_weights = self.decode( + decoder_input) mel_outputs += [mel_output.squeeze(1)] gate_outputs += [gate_output.squeeze(1)] alignments += [attention_weights] @@ -529,9 +536,9 @@ def register_nv_tacotron2(opt_net, opt): if __name__ == '__main__': tron = register_nv_tacotron2({}, {}) - inputs = torch.randint(high=24, size=(1,12)), \ - torch.tensor([12]), \ - torch.randn((1,80,749)), \ - torch.tensor([749]) + inputs = torch.randint(high=24, size=(1, 12)), \ + torch.tensor([12]), \ + torch.randn((1, 80, 749)), \ + torch.tensor([749]) out = tron(*inputs) - print(out) \ No newline at end of file + print(out) diff --git a/dlas/models/audio/tts/tacotron2/text/__init__.py b/dlas/models/audio/tts/tacotron2/text/__init__.py index 245432aa..c77ccbeb 100644 --- a/dlas/models/audio/tts/tacotron2/text/__init__.py +++ b/dlas/models/audio/tts/tacotron2/text/__init__.py @@ -3,9 +3,8 @@ import re import torch -from models.audio.tts.tacotron2.text import cleaners -from models.audio.tts.tacotron2.text.symbols import symbols - +from dlas.models.audio.tts.tacotron2.text import cleaners +from dlas.models.audio.tts.tacotron2.text.symbols import symbols # Mappings from symbol to numeric ID and vice versa: _symbol_to_id = {s: i for i, s in enumerate(symbols)} @@ -16,72 +15,73 @@ _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') def text_to_sequence(text, cleaner_names=['english_cleaners']): - '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - The text can optionally have ARPAbet sequences enclosed in curly braces embedded - in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
- Args: - text: string to convert to a sequence - cleaner_names: names of the cleaner functions to run the text through + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through - Returns: - List of integers corresponding to the symbols in the text - ''' - sequence = [] + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [] - # Check for curly braces and treat their contents as ARPAbet: - while len(text): - m = _curly_re.match(text) - if not m: - sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) - break - sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) - sequence += _arpabet_to_sequence(m.group(2)) - text = m.group(3) + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) + break + sequence += _symbols_to_sequence( + _clean_text(m.group(1), cleaner_names)) + sequence += _arpabet_to_sequence(m.group(2)) + text = m.group(3) - return sequence + return sequence def sequence_to_text(sequence): - '''Converts a sequence of IDs back to a string''' - result = '' - for symbol_id in sequence: - if isinstance(symbol_id, torch.Tensor): - symbol_id = symbol_id.item() - if symbol_id in _id_to_symbol: - s = _id_to_symbol[symbol_id] - # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == '@': - s = '{%s}' % s[1:] - result += s - return result.replace('}{', ' ') + '''Converts a sequence of IDs back to a string''' + result = '' + for symbol_id in sequence: + if isinstance(symbol_id, torch.Tensor): + symbol_id = symbol_id.item() + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] + # Enclose ARPAbet back in curly braces: + if len(s) > 1 and s[0] == '@': + s = '{%s}' % s[1:] + result += s + return result.replace('}{', ' ') def tacotron_symbols(): - return list(_symbol_to_id.keys()) + return list(_symbol_to_id.keys()) def tacotron_symbol_mapping(): - return _symbol_to_id.copy() + return _symbol_to_id.copy() def _clean_text(text, cleaner_names): - for name in cleaner_names: - cleaner = getattr(cleaners, name) - if not cleaner: - raise Exception('Unknown cleaner: %s' % name) - text = cleaner(text) - return text + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception('Unknown cleaner: %s' % name) + text = cleaner(text) + return text def _symbols_to_sequence(symbols): - return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] + return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] def _arpabet_to_sequence(text): - return _symbols_to_sequence(['@' + s for s in text.split()]) + return _symbols_to_sequence(['@' + s for s in text.split()]) def _should_keep_symbol(s): - return s in _symbol_to_id and s != '_' and s != '~' + return s in _symbol_to_id and s != '_' and s != '~' diff --git a/dlas/models/audio/tts/tacotron2/text/cleaners.py b/dlas/models/audio/tts/tacotron2/text/cleaners.py index 1277d87f..ba69db11 100644 --- a/dlas/models/audio/tts/tacotron2/text/cleaners.py +++ b/dlas/models/audio/tts/tacotron2/text/cleaners.py @@ -12,80 +12,79 @@ hyperparameter. Some cleaners are English-specific. You'll typically want to use the symbols in symbols.py to match your data). 
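To make the text_to_sequence / sequence_to_text API above concrete, a short usage sketch built from the function's own docstring example (assuming the dlas package is importable); the default english_cleaners pipeline lower-cases, expands numbers and abbreviations, and collapses whitespace before symbols are looked up, and {...} spans are treated as ARPAbet:

from dlas.models.audio.tts.tacotron2.text import sequence_to_text, text_to_sequence

seq = text_to_sequence('Turn left on {HH AW1 S S T AH0 N} Street.')  # default cleaner: english_cleaners
print(seq)                    # list of integer symbol IDs
print(sequence_to_text(seq))  # round trip; the ARPAbet span comes back wrapped in {...}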
''' + +# Regular expression matching whitespace: import re from unidecode import unidecode from .numbers import normalize_numbers - - -# Regular expression matching whitespace: _whitespace_re = re.compile(r'\s+') # List of (regular expression, replacement) pairs for abbreviations: _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), ]] def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text def expand_numbers(text): - return normalize_numbers(text) + return normalize_numbers(text) def lowercase(text): - return text.lower() + return text.lower() def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) + return re.sub(_whitespace_re, ' ', text) def convert_to_ascii(text): - return unidecode(text) + return unidecode(text) def basic_cleaners(text): - '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' - text = lowercase(text) - text = collapse_whitespace(text) - return text + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text def transliteration_cleaners(text): - '''Pipeline for non-English text that transliterates to ASCII.''' - text = convert_to_ascii(text) - text = lowercase(text) - text = collapse_whitespace(text) - return text + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text def english_cleaners(text): - '''Pipeline for English text, including number and abbreviation expansion.''' - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_numbers(text) - text = expand_abbreviations(text) - text = collapse_whitespace(text) - text = text.replace('"', '') - return text + '''Pipeline for English text, including number and abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + text = text.replace('"', '') + return text diff --git a/dlas/models/audio/tts/tacotron2/text/cmudict.py b/dlas/models/audio/tts/tacotron2/text/cmudict.py index 62bfef74..97040e00 100644 --- a/dlas/models/audio/tts/tacotron2/text/cmudict.py +++ b/dlas/models/audio/tts/tacotron2/text/cmudict.py @@ -2,64 +2,62 @@ import re - valid_symbols = [ - 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', - 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', - 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 
'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', - 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', - 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', - 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', - 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' + 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', + 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', + 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', + 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', + 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', + 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', + 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' ] _valid_symbol_set = set(valid_symbols) class CMUDict: - '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' - def __init__(self, file_or_path, keep_ambiguous=True): - if isinstance(file_or_path, str): - with open(file_or_path, encoding='latin-1') as f: - entries = _parse_cmudict(f) - else: - entries = _parse_cmudict(file_or_path) - if not keep_ambiguous: - entries = {word: pron for word, pron in entries.items() if len(pron) == 1} - self._entries = entries + '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding='latin-1') as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, + pron in entries.items() if len(pron) == 1} + self._entries = entries - def __len__(self): - return len(self._entries) - - - def lookup(self, word): - '''Returns list of ARPAbet pronunciations of the given word.''' - return self._entries.get(word.upper()) + def __len__(self): + return len(self._entries) + def lookup(self, word): + '''Returns list of ARPAbet pronunciations of the given word.''' + return self._entries.get(word.upper()) _alt_re = re.compile(r'\([0-9]+\)') def _parse_cmudict(file): - cmudict = {} - for line in file: - if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): - parts = line.split(' ') - word = re.sub(_alt_re, '', parts[0]) - pronunciation = _get_pronunciation(parts[1]) - if pronunciation: - if word in cmudict: - cmudict[word].append(pronunciation) - else: - cmudict[word] = [pronunciation] - return cmudict + cmudict = {} + for line in file: + if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): + parts = line.split(' ') + word = re.sub(_alt_re, '', parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict def _get_pronunciation(s): - parts = s.strip().split(' ') - for part in parts: - if part not in _valid_symbol_set: - return None - return ' '.join(parts) + parts = s.strip().split(' ') + for part in parts: + if part not in _valid_symbol_set: + return None + return ' '.join(parts) diff --git a/dlas/models/audio/tts/tacotron2/text/numbers.py b/dlas/models/audio/tts/tacotron2/text/numbers.py index 0d5f7fa8..101e3eda 100644 --- a/dlas/models/audio/tts/tacotron2/text/numbers.py +++ b/dlas/models/audio/tts/tacotron2/text/numbers.py @@ -1,8 +1,8 @@ """ from 
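A brief, hypothetical usage sketch for the CMUDict wrapper above; the dictionary path is a placeholder, lookups are upper-cased internally, and ambiguous words keep all pronunciations unless keep_ambiguous=False:

from dlas.models.audio.tts.tacotron2.text.cmudict import CMUDict

cmu = CMUDict('/path/to/cmudict-0.7b')     # placeholder path to a CMUdict file
prons = cmu.lookup('street')               # list of ARPAbet strings, or None if absent
print(len(cmu), prons)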
https://github.com/keithito/tacotron """ -import inflect import re +import inflect _inflect = inflect.engine() _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') @@ -14,58 +14,58 @@ _number_re = re.compile(r'[0-9]+') def _remove_commas(m): - return m.group(1).replace(',', '') + return m.group(1).replace(',', '') def _expand_decimal_point(m): - return m.group(1).replace('.', ' point ') + return m.group(1).replace('.', ' point ') def _expand_dollars(m): - match = m.group(1) - parts = match.split('.') - if len(parts) > 2: - return match + ' dollars' # Unexpected format - dollars = int(parts[0]) if parts[0] else 0 - cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if dollars and cents: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) - elif dollars: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - return '%s %s' % (dollars, dollar_unit) - elif cents: - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s' % (cents, cent_unit) - else: - return 'zero dollars' + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' def _expand_ordinal(m): - return _inflect.number_to_words(m.group(0)) + return _inflect.number_to_words(m.group(0)) def _expand_number(m): - num = int(m.group(0)) - if num > 1000 and num < 3000: - if num == 2000: - return 'two thousand' - elif num > 2000 and num < 2010: - return 'two thousand ' + _inflect.number_to_words(num % 100) - elif num % 100 == 0: - return _inflect.number_to_words(num // 100) + ' hundred' + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') else: - return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') - else: - return _inflect.number_to_words(num, andword='') + return _inflect.number_to_words(num, andword='') def normalize_numbers(text): - text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r'\1 pounds', text) - text = re.sub(_dollars_re, _expand_dollars, text) - text = re.sub(_decimal_number_re, _expand_decimal_point, text) - text = re.sub(_ordinal_re, _expand_ordinal, text) - text = re.sub(_number_re, _expand_number, text) - return text + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/dlas/models/audio/tts/tacotron2/text/symbols.py 
b/dlas/models/audio/tts/tacotron2/text/symbols.py index 071395be..28c616c3 100644 --- a/dlas/models/audio/tts/tacotron2/text/symbols.py +++ b/dlas/models/audio/tts/tacotron2/text/symbols.py @@ -4,8 +4,9 @@ Defines the set of symbols used in text input to the model. The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' -from models.audio.tts.tacotron2.text import cmudict + +from dlas.models.audio.tts.tacotron2.text import cmudict _pad = '_' _punctuation = '!\'(),.:;? ' _special = '-' @@ -15,4 +16,5 @@ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' _arpabet = ['@' + s for s in cmudict.valid_symbols] # Export all symbols: -symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet +symbols = [_pad] + list(_special) + list(_punctuation) + \ + list(_letters) + _arpabet diff --git a/dlas/models/audio/tts/tacotron2/wave_tacotron.py b/dlas/models/audio/tts/tacotron2/wave_tacotron.py index 8e73e93d..6c4a6ad0 100644 --- a/dlas/models/audio/tts/tacotron2/wave_tacotron.py +++ b/dlas/models/audio/tts/tacotron2/wave_tacotron.py @@ -1,20 +1,20 @@ from math import sqrt + import torch from munch import munchify -from torch.autograd import Variable from torch import nn +from torch.autograd import Variable from torch.nn import functional as F -from models.arch_util import ConvGnSilu -from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d -from models.audio.tts.tacotron2.layers import LinearNorm -from models.audio.tts.tacotron2.hparams import create_hparams -from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder -from trainer.networks import register_model -from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths -from utils.util import checkpoint -import torch_intermediary as ml - +import dlas.torch_intermediary as ml +from dlas.models.arch_util import ConvGnSilu +from dlas.models.audio.tts.tacotron2.hparams import create_hparams +from dlas.models.audio.tts.tacotron2.layers import LinearNorm +from dlas.models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths +from dlas.models.audio.tts.tacotron2.tacotron2 import Attention, Encoder +from dlas.models.diffusion.unet_diffusion import AttentionPool2d, UNetModel +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint class WavDecoder(nn.Module): @@ -24,26 +24,33 @@ class WavDecoder(nn.Module): self.K = int(sample_rate * (K_ms/1000)) self.clarifier = UNetModel(image_size=self.K, in_channels=1, - model_channels=dec_channels // 4, # This is a requirement to enable to load the embedding produced by the decoder into the unet model. + # This is a requirement to enable to load the embedding produced by the decoder into the unet model. + model_channels=dec_channels // 4, out_channels=2, # 2 channels: eps_pred and variance_pred num_res_blocks=2, attention_resolutions=(8,), dims=1, dropout=.1, - channel_mult=(1,2,4,8), + channel_mult=(1, 2, 4, 8), use_raw_y_as_embedding=True) assert self.K % 64 == 0 # Otherwise the UNetModel breaks. 
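For reference, symbols.py above builds one shared vocabulary from the pad, special, punctuation and letter characters plus '@'-prefixed ARPAbet entries, so character and phoneme tokens never collide. A small sanity-check sketch:

from dlas.models.audio.tts.tacotron2.text import cmudict
from dlas.models.audio.tts.tacotron2.text.symbols import symbols

arpabet = [s for s in symbols if s.startswith('@')]
assert len(arpabet) == len(cmudict.valid_symbols)  # every ARPAbet symbol appears once, '@'-prefixed
assert '_' in symbols and '-' in symbols           # pad and special characters are included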
- self.pre_rnn = nn.Sequential(ConvGnSilu(1,32,kernel_size=5,convnd=nn.Conv1d), - ConvGnSilu(32,64,kernel_size=5,stride=4,convnd=nn.Conv1d), - ConvGnSilu(64,128,kernel_size=5,stride=4,convnd=nn.Conv1d), - ConvGnSilu(128,256,kernel_size=5,stride=4,convnd=nn.Conv1d), - ConvGnSilu(256,dec_channels,kernel_size=1,convnd=nn.Conv1d), - AttentionPool2d(self.K//64,dec_channels,dec_channels//4)) + self.pre_rnn = nn.Sequential(ConvGnSilu(1, 32, kernel_size=5, convnd=nn.Conv1d), + ConvGnSilu(32, 64, kernel_size=5, + stride=4, convnd=nn.Conv1d), + ConvGnSilu(64, 128, kernel_size=5, + stride=4, convnd=nn.Conv1d), + ConvGnSilu(128, 256, kernel_size=5, + stride=4, convnd=nn.Conv1d), + ConvGnSilu(256, dec_channels, + kernel_size=1, convnd=nn.Conv1d), + AttentionPool2d(self.K//64, dec_channels, dec_channels//4)) self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels) - self.attention_layer = Attention(dec_channels, dec_channels, dec_channels) + self.attention_layer = Attention( + dec_channels, dec_channels, dec_channels) self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1) self.linear_projection = LinearNorm(dec_channels*2, self.dec_channels) - self.gate_layer = LinearNorm(self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid') + self.gate_layer = LinearNorm( + self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid') self.dropout_probability = dropout_probability def chunk_wav(self, wav): @@ -51,16 +58,18 @@ class WavDecoder(nn.Module): # Pad the last chunk as needed. padding_needed = self.K - wavs[-1].shape[-1] if padding_needed > 0: - wavs[-1] = F.pad(wavs[-1], (0,padding_needed)) + wavs[-1] = F.pad(wavs[-1], (0, padding_needed)) - wavs = torch.stack(wavs, dim=1) # wavs.shape = (b,s,K) where s=decoder sequence length + # wavs.shape = (b,s,K) where s=decoder sequence length + wavs = torch.stack(wavs, dim=1) return wavs, padding_needed - + def prepare_decoder_inputs(self, inp): # inp.shape = (b,s,K) chunked waveform. - b,s,K = inp.shape - first_frame = torch.zeros(b,1,K).to(inp.device) - x = torch.cat([first_frame, inp[:,:-1]], dim=1) # It is now aligned for teacher forcing. + b, s, K = inp.shape + first_frame = torch.zeros(b, 1, K).to(inp.device) + # It is now aligned for teacher forcing. 
+ x = torch.cat([first_frame, inp[:, :-1]], dim=1) return x def initialize_decoder_states(self, memory, mask): @@ -75,18 +84,25 @@ class WavDecoder(nn.Module): B = memory.size(0) MAX_TIME = memory.size(1) - self.attention_hidden = Variable(memory.data.new(B, self.dec_channels).zero_()) - self.attention_cell = Variable(memory.data.new(B, self.dec_channels).zero_()) + self.attention_hidden = Variable( + memory.data.new(B, self.dec_channels).zero_()) + self.attention_cell = Variable( + memory.data.new(B, self.dec_channels).zero_()) - self.decoder_hidden = Variable(memory.data.new(B, self.dec_channels).zero_()) - self.decoder_cell = Variable(memory.data.new(B, self.dec_channels).zero_()) + self.decoder_hidden = Variable( + memory.data.new(B, self.dec_channels).zero_()) + self.decoder_cell = Variable( + memory.data.new(B, self.dec_channels).zero_()) self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_()) - self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_()) - self.attention_context = Variable(memory.data.new(B, self.dec_channels).zero_()) + self.attention_weights_cum = Variable( + memory.data.new(B, MAX_TIME).zero_()) + self.attention_context = Variable( + memory.data.new(B, self.dec_channels).zero_()) self.memory = memory - self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory) + self.processed_memory = checkpoint( + self.attention_layer.memory_layer, memory) self.mask = mask def teardown_states(self): @@ -113,20 +129,28 @@ class WavDecoder(nn.Module): attention_weights: """ cell_input = torch.cat((decoder_input, self.attention_context), -1) - self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell)) - self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training) + self.attention_hidden, self.attention_cell = self.attention_rnn( + cell_input, (self.attention_hidden, self.attention_cell)) + self.attention_hidden = F.dropout( + self.attention_hidden, self.dropout_probability, self.training) - attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1) + attention_weights_cat = torch.cat((self.attention_weights.unsqueeze( + 1), self.attention_weights_cum.unsqueeze(1)), dim=1) self.attention_context, self.attention_weights = checkpoint(self.attention_layer, self.attention_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask) self.attention_weights_cum += self.attention_weights - decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1) - self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell)) - self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training) + decoder_input = torch.cat( + (self.attention_hidden, self.attention_context), -1) + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + decoder_input, (self.decoder_hidden, self.decoder_cell)) + self.decoder_hidden = F.dropout( + self.decoder_hidden, self.dropout_probability, self.training) - decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1) - decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context) + decoder_hidden_attention_context = torch.cat( + (self.decoder_hidden, self.attention_context), dim=1) + decoder_output = checkpoint( + self.linear_projection, decoder_hidden_attention_context) gate_prediction 
= self.gate_layer(decoder_hidden_attention_context) return decoder_output, gate_prediction, self.attention_weights @@ -137,11 +161,11 @@ class WavDecoder(nn.Module): # (T_out, B) -> (B, T_out) gate_outputs = torch.stack(gate_outputs, dim=1).repeat(1, self.K) - b,s,_,K = diffusion_eps.shape + b, s, _, K = diffusion_eps.shape # (B, S, 2, K) -> (B, 2, S*K) - diffusion_eps = diffusion_eps.permute(0,2,1,3).reshape(b, 2, s*K) + diffusion_eps = diffusion_eps.permute(0, 2, 1, 3).reshape(b, 2, s*K) - return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added] + return diffusion_eps[:, :, :-padding_added], gate_outputs[:, :-padding_added], alignments[:, :-padding_added] def forward(self, wav_noised, wav_real, timesteps, text_enc, memory_lengths): ''' @@ -155,14 +179,17 @@ class WavDecoder(nn.Module): wav_noised, padding_added = self.chunk_wav(wav_noised) wav_real, _ = self.chunk_wav(wav_real) wav_real = self.prepare_decoder_inputs(wav_real) - b,s,K = wav_real.shape - wav_real = checkpoint(self.pre_rnn, wav_real.reshape(b*s,1,K)).reshape(b,s,self.dec_channels) + b, s, K = wav_real.shape + wav_real = checkpoint(self.pre_rnn, wav_real.reshape( + b*s, 1, K)).reshape(b, s, self.dec_channels) - self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths)) + self.initialize_decoder_states( + text_enc, mask=~get_mask_from_lengths(memory_lengths)) decoder_contexts, gate_outputs, alignments = [], [], [] while len(decoder_contexts) < wav_real.size(1): decoder_input = wav_real[:, len(decoder_contexts)] - dec_context, gate_output, attention_weights = self.produce_context(decoder_input) + dec_context, gate_output, attention_weights = self.produce_context( + decoder_input) decoder_contexts += [dec_context.squeeze(1)] gate_outputs += [gate_output.squeeze(1)] alignments += [attention_weights] @@ -170,12 +197,14 @@ class WavDecoder(nn.Module): # diffusion_inputs and wavs needs to have the sequence and batch dimensions combined, and needs a channel dimension diffusion_emb = torch.stack(decoder_contexts, dim=1) - b,s,c = diffusion_emb.shape - diffusion_emb = diffusion_emb.reshape(b*s,c) - wav_noised = wav_noised.reshape(b*s,1,self.K) - diffusion_eps = self.clarifier(wav_noised, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K) + b, s, c = diffusion_emb.shape + diffusion_emb = diffusion_emb.reshape(b*s, c) + wav_noised = wav_noised.reshape(b*s, 1, self.K) + diffusion_eps = self.clarifier(wav_noised, timesteps.repeat( + s), diffusion_emb).reshape(b, s, 2, self.K) # Recombine diffusion outputs across the sequence into a single prediction. - diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added) + diffusion_eps, gate_outputs, alignments = self.recombine( + diffusion_eps, gate_outputs, alignments, padding_added) return diffusion_eps, gate_outputs, alignments @@ -199,11 +228,11 @@ class WaveTacotron2(nn.Module): if self.mask_padding and output_lengths is not None: mask_fill = outputs[0].shape[-1] mask = ~get_mask_from_lengths(output_lengths, mask_fill) - mask = mask.unsqueeze(1).repeat(1,2,1) + mask = mask.unsqueeze(1).repeat(1, 2, 1) outputs[0].data.masked_fill_(mask, 0.0) outputs[0] = outputs[0].unsqueeze(1) # Re-add channel dimension. 
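# A minimal sketch of the reshape performed in WavDecoder.recombine above: per-chunk
# diffusion outputs of shape (B, S, 2, K) are flattened back into a single waveform-aligned
# (B, 2, S*K) prediction before the chunk padding is trimmed (dummy sizes for illustration):
import torch

b, s, K = 2, 3, 8
diffusion_eps = torch.randn(b, s, 2, K)
flat = diffusion_eps.permute(0, 2, 1, 3).reshape(b, 2, s * K)
assert flat.shape == (b, 2, s * K)
# Any samples added by chunk_wav are then dropped from the tail, e.g. flat[:, :, :-padding_added].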
- outputs[1].data.masked_fill_(mask[:,0], 1e3) # gate energies + outputs[1].data.masked_fill_(mask[:, 0], 1e3) # gate energies return outputs @@ -214,7 +243,8 @@ class WaveTacotron2(nn.Module): text_lengths, output_lengths = text_lengths.data, output_lengths.data embedded_inputs = self.embedding(text_inputs).transpose(1, 2) - encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths) + encoder_outputs = checkpoint( + self.encoder, embedded_inputs, text_lengths) eps_pred, gate_outputs, alignments = self.decoder( wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths) @@ -234,7 +264,7 @@ if __name__ == '__main__': out = tron(wavs_diffused=torch.randn(2, 1, 22000), wavs_corrected=torch.randn(2, 1, 22000), timesteps=torch.LongTensor([555, 543]), - text_inputs=torch.randint(high=24, size=(2,12)), + text_inputs=torch.randint(high=24, size=(2, 12)), text_lengths=torch.tensor([12, 12]), output_lengths=torch.tensor([21995])) - print([o.shape for o in out]) \ No newline at end of file + print([o.shape for o in out]) diff --git a/dlas/models/audio/tts/transformer_builders.py b/dlas/models/audio/tts/transformer_builders.py index ce88d1a7..0c49e052 100644 --- a/dlas/models/audio/tts/transformer_builders.py +++ b/dlas/models/audio/tts/transformer_builders.py @@ -23,11 +23,13 @@ Returns: import functools import random from time import time + import torch import torch.nn as nn -import torch_intermediary as ml from tqdm import tqdm +import dlas.torch_intermediary as ml + def null_position_embeddings(range, dim): return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) @@ -61,13 +63,13 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text """ from transformers import GPT2Config, GPT2Model gpt_config = GPT2Config(vocab_size=256, # Unused. - n_positions=max_mel_seq_len+max_text_seq_len, - n_ctx=max_mel_seq_len+max_text_seq_len, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing) + n_positions=max_mel_seq_len+max_text_seq_len, + n_ctx=max_mel_seq_len+max_text_seq_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing) gpt = GPT2Model(gpt_config) # Override the built in positional embeddings del gpt.wpe @@ -75,8 +77,10 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text # Built-in token embeddings are unused. 
del gpt.wte - mel_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim) - text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim) + mel_pos_emb = LearnedPositionEmbeddings( + max_mel_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim) + text_pos_emb = LearnedPositionEmbeddings( + max_text_seq_len, model_dim) if max_mel_seq_len != -1 else functools.partial(null_position_embeddings, dim=model_dim) return gpt, mel_pos_emb, text_pos_emb, None, None @@ -85,7 +89,8 @@ def build_lr_performer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_l lucidrains Performer implementation, https://github.com/lucidrains/performer-pytorch """ from models.lucidrains.performer.performer_pytorch import Performer - model = Performer(dim=model_dim, depth=layers, heads=heads, dim_head=model_dim, causal=True) + model = Performer(dim=model_dim, depth=layers, + heads=heads, dim_head=model_dim, causal=True) return model @@ -104,14 +109,14 @@ def build_lr_xformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len def test_all_performance(**kwargs): - transformer_builders = [#build_hf_gpt_transformer, - build_lr_performer,] - # build_lr_reformer, - # build_lr_xformer] + transformer_builders = [ # build_hf_gpt_transformer, + build_lr_performer,] + # build_lr_reformer, + # build_lr_xformer] for builder in transformer_builders: model = builder(**kwargs) start = time() - args = torch.randint(0, 8192, (16,450)) + args = torch.randint(0, 8192, (16, 450)) for k in tqdm(range(10)): model(args) stop = time() @@ -119,4 +124,5 @@ def test_all_performance(**kwargs): if __name__ == '__main__': - test_all_performance(layers=12, model_dim=512, heads=8, num_tokens=8192, max_seq_len=1000, checkpointing=False) \ No newline at end of file + test_all_performance(layers=12, model_dim=512, heads=8, + num_tokens=8192, max_seq_len=1000, checkpointing=False) diff --git a/dlas/models/audio/tts/transformer_diffusion_tts.py b/dlas/models/audio/tts/transformer_diffusion_tts.py index bb78a008..1d5acb74 100644 --- a/dlas/models/audio/tts/transformer_diffusion_tts.py +++ b/dlas/models/audio/tts/transformer_diffusion_tts.py @@ -2,17 +2,23 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding -from trainer.networks import register_model -from utils.util import checkpoint -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (TimestepBlock, + TimestepEmbedSequential) +from dlas.models.lucidrains.x_transformers import (Attention, Encoder, + FeedForward, + RMSScaleShiftNorm, + RotaryEmbedding) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint def is_latent(t): return t.dtype == torch.float + def is_sequence(t): return t.dtype == torch.long @@ -21,7 +27,8 @@ class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() # nn.Embedding - 
self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)]) + self.m = nn.ModuleList( + [ml.Embedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -41,13 +48,16 @@ class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): class AttentionBlock(TimestepBlock): def __init__(self, dim, heads, dropout): super().__init__() - self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout, zero_init_output=False) - self.ff = FeedForward(dim, mult=1, dropout=dropout, zero_init_output=True) + self.attn = Attention(dim, heads=heads, causal=False, + dropout=dropout, zero_init_output=False) + self.ff = FeedForward( + dim, mult=1, dropout=dropout, zero_init_output=True) self.rms_scale_norm = RMSScaleShiftNorm(dim) def forward(self, x, timestep_emb, rotary_emb): h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) + h, _, _, _ = checkpoint(self.attn, h, None, None, + None, None, None, rotary_emb) h = checkpoint(self.ff, h) return h + x @@ -56,6 +66,7 @@ class TransformerDiffusionTTS(nn.Module): """ A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? """ + def __init__( self, model_channels=512, @@ -71,7 +82,8 @@ class TransformerDiffusionTTS(nn.Module): dropout=0, use_fp16=False, # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, ): super().__init__() @@ -91,17 +103,17 @@ class TransformerDiffusionTTS(nn.Module): linear(model_channels, model_channels), ) self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels // 2, 3, padding=1, stride=2), - nn.Conv1d(model_channels//2, model_channels,3,padding=1,stride=2)) + nn.Conv1d(model_channels//2, model_channels, 3, padding=1, stride=2)) self.conditioning_encoder = Encoder( - dim=model_channels, - depth=4, - heads=heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) + dim=model_channels, + depth=4, + heads=heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ) self.clvp_encoder = ml.Linear(clvp_in_dim, model_channels) # nn.Embedding self.type_embedding = ml.Embedding(types, model_channels) @@ -114,43 +126,48 @@ class TransformerDiffusionTTS(nn.Module): # nn.Embedding self.embeddings = ml.Embedding(token_count, model_channels) else: - self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels) + self.embeddings = MultiGroupEmbedding( + token_count, in_groups, model_channels) self.latent_conditioner = nn.Sequential( nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), Encoder( - dim=model_channels, - depth=2, - heads=heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) + dim=model_channels, + depth=2, + heads=heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ) ) - self.latent_fade = nn.Parameter(torch.zeros(1,1,model_channels)) + self.latent_fade = nn.Parameter(torch.zeros(1, 1, model_channels)) self.code_converter = Encoder( - dim=model_channels, - depth=3, - heads=heads, - 
ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) + dim=model_channels, + depth=3, + heads=heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels)) - self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, 1, model_channels)) + self.mel_head = nn.Conv1d( + model_channels, in_channels, kernel_size=3, padding=1) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) self.intg = ml.Linear(model_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)]) + self.layers = TimestepRotaryEmbedSequential( + *[AttentionBlock(model_channels, model_channels//64, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( normalization(model_channels), nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(1, model_channels, + out_channels, 3, padding=1)), ) self.debug_codes = {} @@ -165,7 +182,8 @@ class TransformerDiffusionTTS(nn.Module): return groups def timestep_independent(self, codes, conditioning_input, expected_seq_len, prenet_latent=None, return_code_pred=False): - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) + cond_emb = self.conditioning_embedder( + conditioning_input).permute(0, 2, 1) cond_emb = self.conditioning_encoder(cond_emb)[:, 0] code_emb = self.embeddings(codes) @@ -173,7 +191,8 @@ class TransformerDiffusionTTS(nn.Module): latent_conditioning = self.latent_conditioner(prenet_latent) code_emb = code_emb + latent_conditioning * self.latent_fade - unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) + unconditioned_batches = torch.zeros( + (code_emb.shape[0], 1, 1), device=code_emb.device) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), @@ -182,57 +201,65 @@ class TransformerDiffusionTTS(nn.Module): code_emb) code_emb = self.code_converter(code_emb) - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) + expanded_code_emb = F.interpolate(code_emb.permute( + 0, 2, 1), size=expected_seq_len, mode='nearest').permute(0, 2, 1) if not return_code_pred: return expanded_code_emb, cond_emb else: # Perform the mel_head computation on the pre-exanded code embeddings, then interpolate it separately. - mel_pred = self.mel_head(code_emb.permute(0,2,1)) - mel_pred = F.interpolate(mel_pred, size=expected_seq_len, mode='nearest') + mel_pred = self.mel_head(code_emb.permute(0, 2, 1)) + mel_pred = F.interpolate( + mel_pred, size=expected_seq_len, mode='nearest') # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. # This is because we don't want that gradient being used to train parameters through the codes_embedder as # it unbalances contributions to that network from the MSE loss. 
mel_pred = mel_pred * unconditioned_batches.logical_not() return expanded_code_emb, cond_emb, mel_pred - def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None, precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False): if precomputed_code_embeddings is not None: assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified" assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." - assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you." + assert not ( + return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you." assert type is not None, "Type is required." unused_params = [] if not return_code_pred: unused_params.extend(list(self.mel_head.parameters())) if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + code_emb = self.unconditioned_embedding.repeat( + x.shape[0], 1, x.shape[-1]) + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) unused_params.extend(list(self.latent_conditioner.parameters())) else: if precomputed_code_embeddings is not None: code_emb = precomputed_code_embeddings cond_emb = precomputed_cond_embeddings else: - code_emb, cond_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], prenet_latent, True) + code_emb, cond_emb, mel_pred = self.timestep_independent( + codes, conditioning_input, x.shape[-1], prenet_latent, True) if prenet_latent is None: - unused_params.extend(list(self.latent_conditioner.parameters()) + [self.latent_fade]) + unused_params.extend( + list(self.latent_conditioner.parameters()) + [self.latent_fade]) unused_params.append(self.unconditioned_embedding) - clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input) + clvp_emb = torch.zeros_like( + cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input) type_emb = self.type_embedding(type) if clvp_input is None: unused_params.extend(self.clvp_encoder.parameters()) - blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb - x = self.inp_block(x).permute(0,2,1) + blk_emb = self.time_embed(timestep_embedding( + timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb + x = self.inp_block(x).permute(0, 2, 1) rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) x = self.intg(torch.cat([x, code_emb], dim=-1)) x = self.layers(x, blk_emb, rotary_pos_emb) - x = x.float().permute(0,2,1) + x = x.float().permute(0, 2, 1) out = self.out(x) # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. 
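# A minimal sketch of the DDP workaround referenced by the comment above: parameters that a
# given branch skips are folded into the output with zero weight so they still receive a
# gradient and DistributedDataParallel does not flag them as unused. The toy module below is
# illustrative only; the models in this patch do the same thing via their unused_params lists
# and an `extraneous_addition * 0` term.
import torch
import torch.nn as nn

class ToyBranchingModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.always_used = nn.Linear(8, 8)
        self.sometimes_unused = nn.Linear(8, 8)

    def forward(self, x, conditioning_free=False):
        unused_params = []
        if conditioning_free:
            out = self.always_used(x)
            unused_params.extend(self.sometimes_unused.parameters())
        else:
            out = self.always_used(x) + self.sometimes_unused(x)
        extraneous_addition = sum(p.mean() for p in unused_params)
        return out + extraneous_addition * 0  # zero contribution, but keeps the autograd graph complete

print(ToyBranchingModel()(torch.randn(4, 8), conditioning_free=True).shape)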
@@ -253,13 +280,14 @@ def register_transformer_diffusion_tts(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 256, 400) - aligned_latent = torch.randn(2,100,512) - aligned_sequence = torch.randint(0,8,(2,100,8)) + aligned_latent = torch.randn(2, 100, 512) + aligned_sequence = torch.randint(0, 8, (2, 100, 8)) cond = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) - clvp = torch.randn(2,768) - type = torch.LongTensor([0,1]) - model = TransformerDiffusionTTS(512, unconditioned_percentage=.5, in_groups=8) - o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type, return_code_pred=True) - #o = model(clip, ts, aligned_sequence, cond, aligned_latent) - + clvp = torch.randn(2, 768) + type = torch.LongTensor([0, 1]) + model = TransformerDiffusionTTS( + 512, unconditioned_percentage=.5, in_groups=8) + o = model(clip, ts, aligned_sequence, cond, + clvp_input=clvp, type=type, return_code_pred=True) + # o = model(clip, ts, aligned_sequence, cond, aligned_latent) diff --git a/dlas/models/audio/tts/transformer_diffusion_tts2.py b/dlas/models/audio/tts/transformer_diffusion_tts2.py index e4cb3a6e..a88f0f90 100644 --- a/dlas/models/audio/tts/transformer_diffusion_tts2.py +++ b/dlas/models/audio/tts/transformer_diffusion_tts2.py @@ -1,18 +1,24 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock -from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding -from trainer.networks import register_model -from utils.util import checkpoint, print_network +import dlas.torch_intermediary as ml +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (TimestepBlock, + TimestepEmbedSequential) +from dlas.models.lucidrains.x_transformers import (Attention, Encoder, + FeedForward, + RMSScaleShiftNorm, + RotaryEmbedding) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, print_network def is_latent(t): return t.dtype == torch.float + def is_sequence(t): return t.dtype == torch.long @@ -21,7 +27,8 @@ class MultiGroupEmbedding(nn.Module): def __init__(self, tokens, groups, dim): super().__init__() # nn.Embedding - self.m = nn.ModuleList([ml.Embedding(tokens, dim // groups) for _ in range(groups)]) + self.m = nn.ModuleList( + [ml.Embedding(tokens, dim // groups) for _ in range(groups)]) def forward(self, x): h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] @@ -44,12 +51,14 @@ class DietAttentionBlock(TimestepBlock): self.rms_scale_norm = RMSScaleShiftNorm(in_dim) self.proj = ml.Linear(in_dim, dim) self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) - self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) + self.ff = FeedForward(dim, in_dim, mult=1, + dropout=dropout, zero_init_output=True) def forward(self, x, timestep_emb, rotary_emb): h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) h = self.proj(h) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) + h, _, _, _ = checkpoint(self.attn, h, None, None, + None, None, None, rotary_emb) h = checkpoint(self.ff, h) return h + x @@ -58,6 +67,7 @@ class TransformerDiffusionTTS(nn.Module): """ A diffusion model composed 
entirely of stacks of transformer layers. Why would you do it any other way? """ + def __init__( self, prenet_channels=256, @@ -75,7 +85,8 @@ class TransformerDiffusionTTS(nn.Module): dropout=0, use_fp16=False, # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, ): super().__init__() @@ -96,17 +107,17 @@ class TransformerDiffusionTTS(nn.Module): ) prenet_heads = prenet_channels//64 self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, prenet_channels // 2, 3, padding=1, stride=2), - nn.Conv1d(prenet_channels//2, prenet_channels,3,padding=1,stride=2)) + nn.Conv1d(prenet_channels//2, prenet_channels, 3, padding=1, stride=2)) self.conditioning_encoder = Encoder( - dim=prenet_channels, - depth=4, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) + dim=prenet_channels, + depth=4, + heads=prenet_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ) self.clvp_encoder = ml.Linear(clvp_in_dim, prenet_channels) # nn.Embedding self.type_embedding = ml.Embedding(types, prenet_channels) @@ -119,45 +130,48 @@ class TransformerDiffusionTTS(nn.Module): # nn.Embedding self.embeddings = ml.Embedding(token_count, prenet_channels) else: - self.embeddings = MultiGroupEmbedding(token_count, in_groups, prenet_channels) + self.embeddings = MultiGroupEmbedding( + token_count, in_groups, prenet_channels) self.latent_conditioner = nn.Sequential( nn.Conv1d(in_latent_channels, prenet_channels, 3, padding=1), Encoder( - dim=prenet_channels, - depth=2, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) + dim=prenet_channels, + depth=2, + heads=prenet_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ) ) - self.latent_fade = nn.Parameter(torch.zeros(1,1,prenet_channels)) + self.latent_fade = nn.Parameter(torch.zeros(1, 1, prenet_channels)) self.code_converter = Encoder( - dim=prenet_channels, - depth=3, - heads=prenet_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) + dim=prenet_channels, + depth=3, + heads=prenet_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, 1, prenet_channels)) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) self.cond_intg = ml.Linear(prenet_channels*4, model_channels) self.intg = ml.Linear(prenet_channels*2, model_channels) - self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) - + self.layers = TimestepRotaryEmbedSequential( + *[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) self.out = nn.Sequential( normalization(model_channels), nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(1, model_channels, + out_channels, 3, padding=1)), ) self.debug_codes = {} @@ -172,7 +186,8 @@ class 
TransformerDiffusionTTS(nn.Module): return groups def timestep_independent(self, codes, conditioning_input, expected_seq_len, prenet_latent=None): - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) + cond_emb = self.conditioning_embedder( + conditioning_input).permute(0, 2, 1) cond_emb = self.conditioning_encoder(cond_emb)[:, 0] code_emb = self.embeddings(codes) @@ -188,11 +203,11 @@ class TransformerDiffusionTTS(nn.Module): code_emb) code_emb = self.code_converter(code_emb) - expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) + expanded_code_emb = F.interpolate(code_emb.permute( + 0, 2, 1), size=expected_seq_len, mode='nearest').permute(0, 2, 1) return expanded_code_emb, cond_emb - def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None, precomputed_cond_embeddings=None, conditioning_free=False): if precomputed_code_embeddings is not None: @@ -202,32 +217,38 @@ class TransformerDiffusionTTS(nn.Module): unused_params = [] if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + code_emb = self.unconditioned_embedding.repeat( + x.shape[0], 1, x.shape[-1]) + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) unused_params.extend(list(self.latent_conditioner.parameters())) else: if precomputed_code_embeddings is not None: code_emb = precomputed_code_embeddings cond_emb = precomputed_cond_embeddings else: - code_emb, cond_emb = self.timestep_independent(codes, conditioning_input, x.shape[-1], prenet_latent) + code_emb, cond_emb = self.timestep_independent( + codes, conditioning_input, x.shape[-1], prenet_latent) if prenet_latent is None: - unused_params.extend(list(self.latent_conditioner.parameters()) + [self.latent_fade]) + unused_params.extend( + list(self.latent_conditioner.parameters()) + [self.latent_fade]) unused_params.append(self.unconditioned_embedding) - clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input) + clvp_emb = torch.zeros_like( + cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input) type_emb = self.type_embedding(type) if clvp_input is None: unused_params.extend(self.clvp_encoder.parameters()) - blk_emb = torch.cat([self.time_embed(timestep_embedding(timesteps, self.prenet_channels)), cond_emb, clvp_emb, type_emb], dim=-1) + blk_emb = torch.cat([self.time_embed(timestep_embedding( + timesteps, self.prenet_channels)), cond_emb, clvp_emb, type_emb], dim=-1) blk_emb = self.cond_intg(blk_emb) - x = self.inp_block(x).permute(0,2,1) + x = self.inp_block(x).permute(0, 2, 1) rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) x = self.intg(torch.cat([x, code_emb], dim=-1)) x = self.layers(x, blk_emb, rotary_pos_emb) - x = x.float().permute(0,2,1) + x = x.float().permute(0, 2, 1) out = self.out(x) # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. 
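# A minimal sketch of the unconditioned_percentage mechanism used by timestep_independent in
# the models above: during training, a random subset of batch items has its conditioning
# replaced with the learned unconditioned embedding, similar to classifier-free guidance
# training (shapes and the 10% rate below are illustrative):
import torch

def drop_conditioning(code_emb, unconditioned_embedding, unconditioned_percentage=0.1):
    # code_emb: (B, S, C); unconditioned_embedding: (1, 1, C), a learned nn.Parameter in the real models.
    unconditioned_batches = torch.rand(
        (code_emb.shape[0], 1, 1), device=code_emb.device) < unconditioned_percentage
    return torch.where(unconditioned_batches,
                       unconditioned_embedding.repeat(code_emb.shape[0], 1, 1),
                       code_emb)

code_emb = torch.randn(4, 100, 512)
uncond = torch.randn(1, 1, 512)
assert drop_conditioning(code_emb, uncond).shape == code_emb.shape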
@@ -246,15 +267,15 @@ def register_transformer_diffusion_tts2(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 256, 400) - aligned_latent = torch.randn(2,100,512) - aligned_sequence = torch.randint(0,8,(2,100,8)) + aligned_latent = torch.randn(2, 100, 512) + aligned_sequence = torch.randint(0, 8, (2, 100, 8)) cond = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) - clvp = torch.randn(2,768) - type = torch.LongTensor([0,1]) - model = TransformerDiffusionTTS(model_channels=3072, num_layers=16, unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024) + clvp = torch.randn(2, 768) + type = torch.LongTensor([0, 1]) + model = TransformerDiffusionTTS(model_channels=3072, num_layers=16, + unconditioned_percentage=.5, in_groups=8, prenet_channels=1024, block_channels=1024) print_network(model) o = model(clip, ts, aligned_sequence, cond, clvp_input=clvp, type=type) torch.save(model.state_dict(), 'test.pth') - #o = model(clip, ts, aligned_sequence, cond, aligned_latent) - + # o = model(clip, ts, aligned_sequence, cond, aligned_latent) diff --git a/dlas/models/audio/tts/unet_diffusion_tts7.py b/dlas/models/audio/tts/unet_diffusion_tts7.py index 02d25cb3..009b98da 100644 --- a/dlas/models/audio/tts/unet_diffusion_tts7.py +++ b/dlas/models/audio/tts/unet_diffusion_tts7.py @@ -5,16 +5,19 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast -import torch_intermediary as ml +from x_transformers import ContinuousTransformerWrapper, Encoder -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ - Downsample, Upsample, TimestepBlock -from models.audio.tts.mini_encoder import AudioMiniEncoder -from scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from utils.util import checkpoint -from x_transformers import Encoder, ContinuousTransformerWrapper +import dlas.torch_intermediary as ml +from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample, + TimestepBlock, + TimestepEmbedSequential, + Upsample) +from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3, inverted=False): @@ -567,4 +570,3 @@ if __name__ == '__main__': o.sum().backward() model.before_step(0) torch.save(model.state_dict(), 'test_out.pth') - diff --git a/dlas/models/audio/tts/unet_diffusion_tts9.py b/dlas/models/audio/tts/unet_diffusion_tts9.py index a00b2758..9ca75cbb 100644 --- a/dlas/models/audio/tts/unet_diffusion_tts9.py +++ b/dlas/models/audio/tts/unet_diffusion_tts9.py @@ -5,21 +5,26 @@ import torch.nn as nn import torch.nn.functional as F from torch import autocast from x_transformers import Encoder -import torch_intermediary as ml -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ - Downsample, Upsample, TimestepBlock -from models.audio.tts.mini_encoder import AudioMiniEncoder -from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder -from 
scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from utils.util import checkpoint +import dlas.torch_intermediary as ml +from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder +from dlas.models.audio.tts.unet_diffusion_tts7 import \ + CheckpointedXTransformerEncoder +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample, + TimestepBlock, + TimestepEmbedSequential, + Upsample) +from dlas.scripts.audio.gen.use_diffuse_tts import ceil_multiple +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint def is_latent(t): return t.dtype == torch.float + def is_sequence(t): return t.dtype == torch.long @@ -49,7 +54,8 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding), + conv_nd(dims, channels, self.out_channels, + eff_kernel, padding=eff_padding), ) self.emb_layers = nn.Sequential( @@ -64,14 +70,16 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, eff_kernel, padding=eff_padding) def forward(self, x, emb): """ @@ -100,6 +108,7 @@ class ResBlock(TimestepBlock): h = self.out_layers(h) return self.skip_connection(x) + h + class DiffusionTts(nn.Module): """ The full UNet model with attention and timestep embedding. @@ -143,12 +152,12 @@ class DiffusionTts(nn.Module): out_channels=2, # mean and variance dropout=0, # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), + channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48), num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - token_conditioning_resolutions=(1,16,), - attention_resolutions=(512,1024,2048), + token_conditioning_resolutions=(1, 16,), + attention_resolutions=(512, 1024, 2048), conv_resample=True, dims=1, use_fp16=False, @@ -159,10 +168,12 @@ class DiffusionTts(nn.Module): scale_factor=2, time_embed_dim_multiplier=4, freeze_main_net=False, - efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3. + # Uses kernels with width of 1 in several places rather than 3. + efficient_convs=True, use_scale_shift_norm=True, # Parameters for regularization. - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. + unconditioned_percentage=.1, # Parameters for super-sampling. super_sampling=False, super_sampling_max_noising_factor=.1, @@ -173,7 +184,8 @@ class DiffusionTts(nn.Module): num_heads_upsample = num_heads if super_sampling: - in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input. + # In super-sampling mode, the LR input is concatenated directly onto the input. 
+ in_channels *= 2 self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels @@ -224,10 +236,12 @@ class DiffusionTts(nn.Module): rotary_pos_emb=True, ) )) - self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1) - self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1)) + self.latent_converter = nn.Conv1d( + in_latent_channels, conditioning_dim, 1) + self.aligned_latent_padding_embedding = nn.Parameter( + torch.randn(1, in_latent_channels, 1)) if in_channels > 60: # It's a spectrogram. - self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2), + self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, conditioning_dim, 3, padding=1, stride=2), CheckpointedXTransformerEncoder( needs_permute=True, max_seq_len=-1, @@ -242,25 +256,33 @@ class DiffusionTts(nn.Module): ff_glu=True, rotary_pos_emb=True, ) - )) + )) else: self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1, attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5) - self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1)) + self.conditioning_conv = nn.Conv1d( + conditioning_dim*2, conditioning_dim, 1) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, conditioning_dim, 1)) self.conditioning_timestep_integrator = TimestepEmbedSequential( - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), - AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels), - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), - AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels), - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + AttentionBlock(conditioning_dim, num_heads=num_heads, + num_head_channels=num_head_channels), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + AttentionBlock(conditioning_dim, num_heads=num_heads, + num_head_channels=num_head_channels), + ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, + dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), ) self.conditioning_expansion = conditioning_expansion self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) + conv_nd(dims, in_channels, model_channels, + kernel_size, padding=padding) ) ] ) @@ -371,7 +393,8 @@ class DiffusionTts(nn.Module): if level and i == num_blocks: out_ch = ch layers.append( - Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor) + Upsample(ch, conv_resample, dims=dims, + out_channels=out_ch, factor=scale_factor) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) @@ -380,7 +403,8 @@ class DiffusionTts(nn.Module): 
self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), + zero_module(conv_nd(dims, model_channels, out_channels, + kernel_size, padding=padding)), ) if self.freeze_main_net: @@ -410,13 +434,14 @@ class DiffusionTts(nn.Module): cm = ceil_multiple(x.shape[-1], self.alignment_size) if cm != 0: pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) + x = F.pad(x, (0, cm-x.shape[-1])) # Also fix aligned_latent, which is aligned to x. if is_latent(aligned_conditioning): aligned_conditioning = torch.cat([aligned_conditioning, self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1) else: - aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) + aligned_conditioning = F.pad( + aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1]))) return x, aligned_conditioning def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False): @@ -435,9 +460,12 @@ class DiffusionTts(nn.Module): if self.super_sampling_enabled: assert lr_input is not None if self.training and self.super_sampling_max_noising_factor > 0: - noising_factor = random.uniform(0,self.super_sampling_max_noising_factor) - lr_input = torch.randn_like(lr_input) * noising_factor + lr_input - lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest') + noising_factor = random.uniform( + 0, self.super_sampling_max_noising_factor) + lr_input = torch.randn_like( + lr_input) * noising_factor + lr_input + lr_input = F.interpolate( + lr_input, size=(x.shape[-1],), mode='nearest') x = torch.cat([x, lr_input], dim=1) # Shuffle aligned_latent to BxCxS format @@ -451,11 +479,13 @@ class DiffusionTts(nn.Module): with autocast(x.device.type, enabled=self.enable_fp16): hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + time_emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels)) # Note: this block does not need to repeated on inference, since it is not timestep-dependent. if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) + code_emb = self.unconditioned_embedding.repeat( + x.shape[0], 1, 1) else: cond_emb = self.contextual_embedder(conditioning_input) if len(cond_emb.shape) == 3: # Just take the first element. @@ -464,8 +494,10 @@ class DiffusionTts(nn.Module): code_emb = self.latent_converter(aligned_conditioning) else: code_emb = self.code_converter(aligned_conditioning) - cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1]) - code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1)) + cond_emb = cond_emb.unsqueeze(-1).repeat(1, + 1, code_emb.shape[-1]) + code_emb = self.conditioning_conv( + torch.cat([cond_emb, code_emb], dim=1)) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), @@ -474,15 +506,18 @@ class DiffusionTts(nn.Module): code_emb) # Everything after this comment is timestep dependent. 
- code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1) - code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + code_emb = torch.repeat_interleave( + code_emb, self.conditioning_expansion, dim=-1) + code_emb = self.conditioning_timestep_integrator( + code_emb, time_emb) first = True time_emb = time_emb.float() h = x for k, module in enumerate(self.input_blocks): if isinstance(module, nn.Conv1d): - h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') + h_tok = F.interpolate(module(code_emb), size=( + h.shape[-1]), mode='nearest') h = h + h_tok else: with autocast(x.device.type, enabled=self.enable_fp16 and not first): @@ -501,7 +536,8 @@ class DiffusionTts(nn.Module): # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. extraneous_addition = 0 - params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters()) + params = [self.aligned_latent_padding_embedding, + self.unconditioned_embedding] + list(self.latent_converter.parameters()) for p in params: extraneous_addition = extraneous_addition + p.mean() out = out + extraneous_addition * 0 @@ -516,14 +552,14 @@ def register_diffusion_tts9(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 1, 32868) - aligned_latent = torch.randn(2,388,1024) - aligned_sequence = torch.randint(0,8192,(2,388)) + aligned_latent = torch.randn(2, 388, 1024) + aligned_sequence = torch.randint(0, 8192, (2, 388)) cond = torch.randn(2, 1, 44000) ts = torch.LongTensor([600, 600]) model = DiffusionTts(128, - channel_mult=[1,1.5,2, 3, 4, 6, 8], + channel_mult=[1, 1.5, 2, 3, 4, 6, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 1], - token_conditioning_resolutions=[1,4,16,64], + token_conditioning_resolutions=[1, 4, 16, 64], attention_resolutions=[], num_heads=8, kernel_size=3, @@ -535,4 +571,3 @@ if __name__ == '__main__': o = model(clip, ts, aligned_latent, cond) # Test with sequence aligned conditioning o = model(clip, ts, aligned_sequence, cond) - diff --git a/dlas/models/audio/tts/unet_diffusion_tts_flat.py b/dlas/models/audio/tts/unet_diffusion_tts_flat.py index 34017011..aeaf4813 100644 --- a/dlas/models/audio/tts/unet_diffusion_tts_flat.py +++ b/dlas/models/audio/tts/unet_diffusion_tts_flat.py @@ -5,18 +5,22 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast -import torch_intermediary as ml -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy -from models.lucidrains.x_transformers import RelativePositionBias -from trainer.networks import register_model -from utils.util import checkpoint +import dlas.torch_intermediary as ml +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (QKVAttentionLegacy, + TimestepBlock, + TimestepEmbedSequential) +from dlas.models.lucidrains.x_transformers import RelativePositionBias +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint def is_latent(t): return t.dtype == torch.float + def is_sequence(t): return t.dtype == torch.long @@ -54,7 +58,8 @@ class AttentionBlock(nn.Module): self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) if relative_pos_embeddings: - self.relative_pos_embeddings = RelativePositionBias(scale=(channels // 
self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) + self.relative_pos_embeddings = RelativePositionBias(scale=( + channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) else: self.relative_pos_embeddings = None @@ -92,7 +97,8 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding), + conv_nd(dims, channels, self.out_channels, + eff_kernel, padding=eff_padding), ) self.emb_layers = nn.Sequential( @@ -107,14 +113,16 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, eff_kernel, padding=eff_padding) def forward(self, x, emb): """ @@ -147,8 +155,10 @@ class ResBlock(TimestepBlock): class DiffusionLayer(TimestepBlock): def __init__(self, model_channels, dropout, num_heads): super().__init__() - self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True) - self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) + self.resblk = ResBlock(model_channels, model_channels, dropout, + model_channels, dims=1, use_scale_shift_norm=True) + self.attn = AttentionBlock( + model_channels, num_heads, relative_pos_embeddings=True) def forward(self, x, time_emb): y = self.resblk(x, time_emb) @@ -170,7 +180,8 @@ class DiffusionTtsFlat(nn.Module): freeze_everything_except_autoregressive_inputs=False, # Parameters for regularization. layer_drop=.1, - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + # This implements a mechanism similar to what is used in classifier-free training. 
+ unconditioned_percentage=.1, ): super().__init__() @@ -198,33 +209,48 @@ class DiffusionTtsFlat(nn.Module): # nn.Embedding self.code_embedding = ml.Embedding(in_tokens, model_channels) self.code_converter = nn.Sequential( - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), ) self.code_norm = normalization(model_channels) self.latent_conditioner = nn.Sequential( nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, + relative_pos_embeddings=True), ) - self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2), - nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), - AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2), + nn.Conv1d( + model_channels, model_channels*2, 3, padding=1, stride=2), + AttentionBlock( + model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock( + model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock( + model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock( + model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False)) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1)) + self.unconditioned_embedding = nn.Parameter( + torch.randn(1, model_channels, 1)) self.conditioning_timestep_integrator = TimestepEmbedSequential( DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads), ) - self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1) - self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) + self.integrating_conv = nn.Conv1d( + model_channels*2, model_channels, kernel_size=1) + self.mel_head = nn.Conv1d( + model_channels, in_channels, kernel_size=3, padding=1) self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + [ResBlock(model_channels, 
model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)]) @@ -232,7 +258,8 @@ class DiffusionTtsFlat(nn.Module): self.out = nn.Sequential( normalization(model_channels), nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(1, model_channels, + out_channels, 3, padding=1)), ) if freeze_everything_except_autoregressive_inputs: @@ -263,25 +290,30 @@ class DiffusionTtsFlat(nn.Module): conditioning_input.shape) == 3 else conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) + conds.append(self.contextual_embedder( + speech_conditioning_input[:, j])) conds = torch.cat(conds, dim=-1) cond_emb = conds.mean(dim=-1) cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1) if is_latent(aligned_conditioning): code_emb = self.latent_conditioner(aligned_conditioning) else: - code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) + code_emb = self.code_embedding( + aligned_conditioning).permute(0, 2, 1) code_emb = self.code_converter(code_emb) - code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) + code_emb = self.code_norm( + code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) - unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) + unconditioned_batches = torch.zeros( + (code_emb.shape[0], 1, 1), device=code_emb.device) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. if self.training and self.unconditioned_percentage > 0: unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), code_emb) - expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest') + expanded_code_emb = F.interpolate( + code_emb, size=expected_seq_len, mode='nearest') if not return_code_pred: return expanded_code_emb @@ -291,7 +323,6 @@ class DiffusionTtsFlat(nn.Module): mel_pred = mel_pred * unconditioned_batches.logical_not() return expanded_code_emb, mel_pred - def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False): """ Apply the model to an input batch. @@ -304,27 +335,36 @@ class DiffusionTtsFlat(nn.Module): :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. :return: an [N x C x ...] Tensor of outputs. """ - assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None) - assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive. + assert precomputed_aligned_embeddings is not None or ( + aligned_conditioning is not None and conditioning_input is not None) + # These two are mutually exclusive. 
+ assert not ( + return_code_pred and precomputed_aligned_embeddings is not None) unused_params = list(self.mel_head.parameters()) if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + code_emb = self.unconditioned_embedding.repeat( + x.shape[0], 1, x.shape[-1]) + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) unused_params.extend(list(self.latent_conditioner.parameters())) else: if precomputed_aligned_embeddings is not None: code_emb = precomputed_aligned_embeddings else: - code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True) + code_emb, mel_pred = self.timestep_independent( + aligned_conditioning, conditioning_input, x.shape[-1], True) if is_latent(aligned_conditioning): - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) else: - unused_params.extend(list(self.latent_conditioner.parameters())) + unused_params.extend( + list(self.latent_conditioner.parameters())) unused_params.append(self.unconditioned_embedding) - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + time_emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels)) code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) x = self.inp_block(x) x = torch.cat([x, code_emb], dim=1) @@ -356,10 +396,12 @@ class DiffusionTtsFlat(nn.Module): conditioning_input.shape) == 3 else conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) + conds.append(self.contextual_embedder( + speech_conditioning_input[:, j])) conds = torch.cat(conds, dim=-1) return conds.mean(dim=-1) + @register_model def register_diffusion_tts_flat(opt_net, opt): return DiffusionTtsFlat(**opt_net['kwargs']) @@ -367,17 +409,17 @@ def register_diffusion_tts_flat(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 100, 400) - aligned_latent = torch.randn(2,388,512) - aligned_sequence = torch.randint(0,8192,(2,100)) + aligned_latent = torch.randn(2, 388, 512) + aligned_sequence = torch.randint(0, 8192, (2, 100)) cond = torch.randn(2, 100, 400) ts = torch.LongTensor([600, 600]) model = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, - in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=True, num_heads=16, - layer_drop=0, unconditioned_percentage=0) + in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=True, num_heads=16, + layer_drop=0, unconditioned_percentage=0) # Test with latent aligned conditioning - #o = model(clip, ts, aligned_latent, cond) + # o = model(clip, ts, aligned_latent, cond) # Test with sequence aligned conditioning - #o = model(clip, ts, aligned_sequence, cond) + # o = model(clip, ts, aligned_sequence, cond) with torch.no_grad(): proj = torch.randn(2, 100, 1024).cuda() @@ -389,4 +431,3 @@ if __name__ == '__main__': for k in range(1000): model(clip, ts, precomputed_aligned_embeddings=ti) print(f"Elapsed: {time()-start}") - diff --git a/dlas/models/audio/tts/unet_diffusion_vocoder.py b/dlas/models/audio/tts/unet_diffusion_vocoder.py index 8472b658..16b10779 100644 --- 
a/dlas/models/audio/tts/unet_diffusion_vocoder.py +++ b/dlas/models/audio/tts/unet_diffusion_vocoder.py @@ -1,11 +1,16 @@ -from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16 -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \ - Downsample, Upsample import torch import torch.nn as nn -from trainer.networks import register_model +from dlas.models.diffusion.fp16_util import (convert_module_to_f16, + convert_module_to_f32) +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (AttentionBlock, + AttentionPool2d, Downsample, + ResBlock, + TimestepEmbedSequential, + Upsample) +from dlas.trainer.networks import register_model class DiffusionVocoder(nn.Module): @@ -46,11 +51,12 @@ class DiffusionVocoder(nn.Module): in_channels=1, out_channels=2, # mean and variance spectrogram_channels=80, - spectrogram_conditioning_level=3, # Level at which spectrogram conditioning is applied to the waveform. + # Level at which spectrogram conditioning is applied to the waveform. + spectrogram_conditioning_level=3, dropout=0, # 106496 -> 26624 -> 6656 -> 1664 -> 416 -> 104 -> 26 for ~5secs@22050Hz channel_mult=(1, 2, 4, 8, 16, 32, 64), - attention_resolutions=(16,32,64), + attention_resolutions=(16, 32, 64), conv_resample=True, dims=1, num_classes=None, @@ -95,7 +101,8 @@ class DiffusionVocoder(nn.Module): self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) + conv_nd(dims, in_channels, model_channels, + kernel_size, padding=padding) ) ] ) @@ -104,7 +111,8 @@ class DiffusionVocoder(nn.Module): ch = model_channels ds = 1 - spec_chs = channel_mult[spectrogram_conditioning_level] * model_channels + spec_chs = channel_mult[spectrogram_conditioning_level] * \ + model_channels self.spectrogram_conditioner = nn.Sequential( conv_nd(dims, self.spectrogram_channels, spec_chs, 1), normalization(spec_chs), @@ -119,7 +127,8 @@ class DiffusionVocoder(nn.Module): for level, mult in enumerate(channel_mult): if level == spectrogram_conditioning_level+1: - ch *= 2 # At this level, the spectrogram is concatenated onto the input. + # At this level, the spectrogram is concatenated onto the input. + ch *= 2 for _ in range(num_res_blocks): layers = [ @@ -248,7 +257,8 @@ class DiffusionVocoder(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), + zero_module(conv_nd(dims, model_channels, out_channels, + kernel_size, padding=padding)), ) def convert_to_fp16(self): @@ -278,14 +288,16 @@ class DiffusionVocoder(nn.Module): """ assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(timestep_embedding( + timesteps, self.model_channels)) conditioning = self.spectrogram_conditioner(spectrogram) h = x.type(self.dtype) for k, module in enumerate(self.input_blocks): h = module(h, emb) if k == self.input_block_injection_point: - cond = nn.functional.interpolate(conditioning, size=h.shape[-self.dims:], mode='nearest') + cond = nn.functional.interpolate( + conditioning, size=h.shape[-self.dims:], mode='nearest') h = torch.cat([h, cond], dim=1) h = self.convergence_conv(h) hs.append(h) diff --git a/dlas/models/audio/tts/unet_diffusion_vocoder_with_ref.py b/dlas/models/audio/tts/unet_diffusion_vocoder_with_ref.py index 66dfd52b..6c22ac98 100644 --- a/dlas/models/audio/tts/unet_diffusion_vocoder_with_ref.py +++ b/dlas/models/audio/tts/unet_diffusion_vocoder_with_ref.py @@ -1,11 +1,14 @@ -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, ResBlock, TimestepEmbedSequential, \ - Downsample, Upsample import torch import torch.nn as nn -from models.audio.tts.mini_encoder import AudioMiniEncoder -from trainer.networks import register_model +from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder +from dlas.models.diffusion.nn import (conv_nd, linear, normalization, + timestep_embedding, zero_module) +from dlas.models.diffusion.unet_diffusion import (AttentionBlock, Downsample, + ResBlock, + TimestepEmbedSequential, + Upsample) +from dlas.trainer.networks import register_model class DiscreteSpectrogramConditioningBlock(nn.Module): @@ -23,6 +26,7 @@ class DiscreteSpectrogramConditioningBlock(nn.Module): :param x: bxcxS waveform latent :param codes: bxN discrete codes, N <= S """ + def forward(self, x, dvae_in): b, c, S = x.shape _, q, N = dvae_in.shape @@ -70,12 +74,12 @@ class DiffusionVocoderWithRef(nn.Module): discrete_codes=512, dropout=0, # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), + channel_mult=(1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48), num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 spectrogram_conditioning_resolutions=(512,), - attention_resolutions=(512,1024,2048), + attention_resolutions=(512, 1024, 2048), conv_resample=True, dims=1, use_fp16=False, @@ -122,10 +126,11 @@ class DiffusionVocoderWithRef(nn.Module): self.conditioning_enabled = conditioning_inputs_provided if conditioning_inputs_provided: self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1, - attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) + attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) seqlyr = TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) + conv_nd(dims, in_channels, model_channels, + kernel_size, padding=padding) ) seqlyr.level = 0 self.input_blocks = nn.ModuleList([seqlyr]) @@ -137,7 +142,8 @@ class DiffusionVocoderWithRef(nn.Module): for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)): if ds in spectrogram_conditioning_resolutions: - spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch, 2 ** level) + spec_cond_block = DiscreteSpectrogramConditioningBlock( + discrete_codes, ch, 2 ** level) 
self.input_blocks.append(spec_cond_block) spectrogram_blocks.append(spec_cond_block) ch *= 2 @@ -172,21 +178,21 @@ class DiffusionVocoderWithRef(nn.Module): if level != len(channel_mult) - 1: out_ch = ch upblk = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - kernel_size=kernel_size, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor - ) + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + kernel_size=kernel_size, ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor + ) + ) upblk.level = 2 ** level self.input_blocks.append(upblk) ch = out_ch @@ -270,7 +276,8 @@ class DiffusionVocoderWithRef(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), + zero_module(conv_nd(dims, model_channels, out_channels, + kernel_size, padding=padding)), ) if freeze_layers_below is not None: @@ -297,7 +304,8 @@ class DiffusionVocoderWithRef(nn.Module): del p.DO_NOT_TRAIN p.requires_grad = True unfrozen_params += 1 - print(f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.") + print( + f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.") def forward(self, x, timesteps, spectrogram, conditioning_input=None): """ @@ -313,7 +321,8 @@ class DiffusionVocoderWithRef(nn.Module): assert conditioning_input is not None hs = [] - emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb1 = self.time_embed(timestep_embedding( + timesteps, self.model_channels)) if self.conditioning_enabled: emb2 = self.contextual_embedder(conditioning_input) emb = emb1 + emb2 @@ -363,19 +372,20 @@ if __name__ == '__main__': move_all_layers_down(path, 'diffuse_new_lyr.pth', layers_to_be_added=2) clip = torch.randn(2, 1, 40960) - spec = torch.randn(2,80,160) + spec = torch.randn(2, 80, 160) cond = torch.randn(2, 1, 40960) ts = torch.LongTensor([555, 556]) - model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1,1,1.5,2, 3, 4, 6, 8, 8, 8, 8 ], - num_res_blocks=[1,2, 2, 2, 2, 2, 2, 2, 2, 1, 1 ], spectrogram_conditioning_resolutions=[2,512], - dropout=.05, attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2, + model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], + num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1], spectrogram_conditioning_resolutions=[2, 512], + dropout=.05, attention_resolutions=[512, 1024], num_heads=4, kernel_size=3, scale_factor=2, conditioning_inputs_provided=True, conditioning_input_dim=80, time_embed_dim_multiplier=4, discrete_codes=80, freeze_layers_below=1) - loading_errors = model.load_state_dict(torch.load('diffuse_new_lyr.pth'), strict=False) + loading_errors = model.load_state_dict( + torch.load('diffuse_new_lyr.pth'), strict=False) new_params = loading_errors.missing_keys new_params_trained = [] existing_params_trained = [] - for n,p in model.named_parameters(): + for n, p in model.named_parameters(): if not hasattr(p, 'DO_NOT_TRAIN'): if n in new_params: new_params_trained.append(n) diff --git a/dlas/models/audio/tts/unified_voice2.py b/dlas/models/audio/tts/unified_voice2.py index f03f69e2..dfcfcb29 100644 --- 
a/dlas/models/audio/tts/unified_voice2.py +++ b/dlas/models/audio/tts/unified_voice2.py @@ -1,24 +1,26 @@ import torch import torch.nn as nn import torch.nn.functional as F - from transformers import GPT2Config, GPT2PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.models.gpt2.modeling_gpt2 import GPT2Attention -from transformers.utils.model_parallel_utils import get_device_map, assert_device_map +from transformers.utils.model_parallel_utils import (assert_device_map, + get_device_map) -from models.arch_util import AttentionBlock -from models.audio.tts.transformer_builders import build_hf_gpt_transformer -from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb -from trainer.networks import register_model -from utils.util import opt_get +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock +from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer +from dlas.models.lucidrains.x_transformers import (RotaryEmbedding, + apply_rotary_pos_emb) +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get -import torch_intermediary as ml class ResBlock(nn.Module): """ Basic residual convolutional block that uses GroupNorm. """ + def __init__(self, chan): super().__init__() self.net = nn.Sequential( @@ -48,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): def parallelize(self, device_map=None): self.device_map = ( - get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + get_device_map(len(self.transformer.h), + range(torch.cuda.device_count())) if device_map is None else device_map ) @@ -121,7 +124,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): ): assert self.cached_mel_emb is not None assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. + # Training not supported by this inference model. + assert labels is None return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Create embedding @@ -131,13 +135,15 @@ class GPT2InferenceModel(GPT2PreTrainedModel): text_emb = self.embeddings(text_inputs) text_emb = text_emb + self.text_pos_embedding(text_emb) if self.cached_mel_emb.shape[0] != text_emb.shape[0]: - mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0) + mel_emb = self.cached_mel_emb.repeat_interleave( + text_emb.shape[0]//self.cached_mel_emb.shape[0], 0) else: mel_emb = self.cached_mel_emb emb = torch.cat([mel_emb, text_emb], dim=1) else: emb = self.embeddings(input_ids) - emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device) + emb = emb + self.text_pos_embedding.get_fixed_embedding( + attention_mask.shape[1]-mel_len, attention_mask.device) transformer_outputs = self.transformer( inputs_embeds=emb, @@ -182,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
""" return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past) for layer_past in past ) @@ -199,7 +206,8 @@ class ConditioningEncoder(nn.Module): attn = [] self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) + attn.append(AttentionBlock(embedding_dim, + num_attn_heads, do_checkpoint=do_checkpointing)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim self.do_checkpointing = do_checkpointing @@ -219,23 +227,27 @@ class MelEncoder(nn.Module): super().__init__() self.channels = channels self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1), - nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1), + nn.Sequential( + *[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels//4, channels//2, + kernel_size=3, stride=2, padding=1), nn.GroupNorm(channels//16, channels//2), nn.ReLU(), - nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1), + nn.Sequential( + *[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels//2, channels, + kernel_size=3, stride=2, padding=1), nn.GroupNorm(channels//8, channels), nn.ReLU(), - nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), + nn.Sequential( + *[ResBlock(channels) for _ in range(resblocks_per_reduction)]), ) self.reduction = 4 - def forward(self, x): for e in self.encoder: x = e(x) - return x.permute(0,2,1) + return x.permute(0, 2, 1) class UnifiedVoice(nn.Module): @@ -276,25 +288,32 @@ class UnifiedVoice(nn.Module): self.layers = layers self.heads = heads self.max_conditioning_inputs = max_conditioning_inputs - self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs + self.max_mel_tokens = -1 if max_mel_tokens == - \ + 1 else max_mel_tokens+2+self.max_conditioning_inputs self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2 self.model_dim = model_dim self.mel_length_compression = mel_length_compression - self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.conditioning_encoder = ConditioningEncoder( + 80, model_dim, num_attn_heads=heads) self.average_conditioning_embeddings = average_conditioning_embeddings - self.tortoise_compat = tortoise_compat # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b + # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b + self.tortoise_compat = tortoise_compat # nn.Embedding self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim) if use_mel_codes_as_input: # nn.Embedding self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) else: - self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + self.mel_embedding = MelEncoder( + model_dim, resblocks_per_reduction=1) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ - build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, 
checkpointing) + build_hf_gpt_transformer( + layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) if train_solo_embeddings: - self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) - self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) + self.mel_solo_embedding = nn.Parameter( + torch.randn(1, 1, model_dim) * .02, requires_grad=True) + self.text_solo_embedding = nn.Parameter( + torch.randn(1, 1, model_dim) * .02, requires_grad=True) else: self.mel_solo_embedding = 0 self.text_solo_embedding = 0 @@ -303,7 +322,6 @@ class UnifiedVoice(nn.Module): self.text_head = ml.Linear(model_dim, self.number_text_tokens) self.mel_head = ml.Linear(model_dim, self.number_mel_codes) - # Initialize the embeddings per the GPT-2 scheme embeddings = [self.text_embedding] if use_mel_codes_as_input: @@ -328,8 +346,8 @@ class UnifiedVoice(nn.Module): } def build_aligned_inputs_and_targets(self, input, start_token, stop_token): - inp = F.pad(input, (1,0), value=start_token) - tar = F.pad(input, (0,1), value=stop_token) + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) return inp, tar def set_mel_padding(self, mel_input_tokens, wav_lengths): @@ -341,22 +359,26 @@ class UnifiedVoice(nn.Module): # Set padding areas within MEL (currently it is coded with the MEL code for ). mel_lengths = wav_lengths // self.mel_length_compression for b in range(len(mel_lengths)): - actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + actual_end = mel_lengths[b] + 1 if actual_end < mel_input_tokens.shape[-1]: mel_input_tokens[b, actual_end:] = self.stop_mel_token return mel_input_tokens def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False): if second_inputs is not None: - emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + emb = torch.cat([speech_conditioning_inputs, + first_inputs, second_inputs], dim=1) else: emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) - gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, + output_attentions=get_attns) if get_attns: return gpt_out.attentions - enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input + # The first logit is tied to the speech_conditioning_input + enc = gpt_out.last_hidden_state[:, 1:] enc = self.final_norm(enc) if return_latent: @@ -364,11 +386,11 @@ class UnifiedVoice(nn.Module): first_logits = enc[:, :first_inputs.shape[1]] first_logits = first_head(first_logits) - first_logits = first_logits.permute(0,2,1) + first_logits = first_logits.permute(0, 2, 1) if second_inputs is not None: second_logits = enc[:, -second_inputs.shape[1]:] second_logits = second_head(second_logits) - second_logits = second_logits.permute(0,2,1) + second_logits = second_logits.permute(0, 2, 1) return first_logits, second_logits else: return first_logits @@ -394,24 +416,31 @@ class UnifiedVoice(nn.Module): # This model will receive micro-batches with a ton of padding for both the text and MELs. 
Ameliorate this by # chopping the inputs by the maximum actual length. max_text_len = text_lengths.max() - text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) + text_inputs = F.pad( + text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) + mel_codes = F.pad(mel_codes[:, :max_mel_len], + (0, 1), value=self.stop_mel_token) if raw_mels is not None: raw_mels = raw_mels[:, :, :max_mel_len*4] mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder( + speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) if self.average_conditioning_embeddings: conds = conds.mean(dim=1).unsqueeze(1) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding( + text_inputs) + self.text_pos_embedding(text_inputs) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token) if raw_mels is not None: mel_inp = F.pad(raw_mels, (0, 8)) else: @@ -421,13 +450,17 @@ class UnifiedVoice(nn.Module): sub = -2 if self.tortoise_compat else -1 if text_first: - text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) + text_logits, mel_logits = self.get_logits( + conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) if return_latent: - return mel_logits[:, :sub] # Despite the name, these are not logits. + # Despite the name, these are not logits. + return mel_logits[:, :sub] else: - mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) + mel_logits, text_logits = self.get_logits( + conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) if return_latent: - return text_logits[:, :sub] # Despite the name, these are not logits + # Despite the name, these are not logits + return text_logits[:, :sub] if return_attentions: return mel_logits @@ -443,18 +476,23 @@ class UnifiedVoice(nn.Module): # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by # chopping the inputs by the maximum actual length. 
max_text_len = text_lengths.max() - text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) + text_inputs = F.pad( + text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder( + speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) if self.average_conditioning_embeddings: conds = conds.mean(dim=1).unsqueeze(1) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding( + text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding text_logits = self.get_logits(conds, text_emb, self.text_head) loss_text = F.cross_entropy(text_logits, text_targets.long()) return loss_text.mean() @@ -468,26 +506,31 @@ class UnifiedVoice(nn.Module): # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by # chopping the inputs by the maximum actual length. max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) + mel_codes = F.pad(mel_codes[:, :max_mel_len], + (0, 1), value=self.stop_mel_token) mel_codes = self.set_mel_padding(mel_codes, wav_lengths) if raw_mels is not None: raw_mels = raw_mels[:, :, :max_mel_len*4] - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder( + speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) if self.average_conditioning_embeddings: conds = conds.mean(dim=1).unsqueeze(1) - mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token) if raw_mels is not None: mel_inp = F.pad(raw_mels, (0, 4)) else: mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) - mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding + mel_emb = mel_emb + \ + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding mel_logits = self.get_logits(conds, mel_emb, self.mel_head) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_mel.mean() @@ -507,17 +550,22 @@ class UnifiedVoice(nn.Module): n_head=self.heads, gradient_checkpointing=False, use_cache=True) - self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, 
self.mel_embedding, self.final_norm, self.mel_head) + self.inference_model = GPT2InferenceModel( + gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) self.gpt.wte = self.mel_embedding text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding( + text_inputs) + self.text_pos_embedding(text_inputs) - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder( + speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) if self.average_conditioning_embeddings: conds = conds.mean(dim=1).unsqueeze(1) @@ -525,8 +573,9 @@ class UnifiedVoice(nn.Module): emb = torch.cat([conds, text_emb], dim=1) self.inference_model.store_mel_emb(emb) - fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.start_mel_token + fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), + fill_value=1, dtype=torch.long, device=text_inputs.device) + fake_inputs[:, -1] = self.start_mel_token gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, max_length=seq_length, output_attentions=return_attentions, return_dict_in_generate=True, **hf_generate_kwargs) @@ -535,15 +584,16 @@ class UnifiedVoice(nn.Module): else: return gen.sequences[:, fake_inputs.shape[1]:] - # Turns the (utterly insane) output of HF.generate() into a far more sane output: # [tensors(B,H,S,S)]. 
Outer=layers, B=batch,H=head,S=sequence + def make_hf_generate_attentions_sane(self, attentions): layers = [[] for _ in range(len(attentions[0]))] full_attention_size = attentions[-1][0].shape[-1] for i, gen in enumerate(attentions): for j, lyr in enumerate(gen): - layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1]))) + layers[j].append( + F.pad(lyr, (0, full_attention_size - lyr.shape[-1]))) catted = [] for lyr in layers: catted.append(torch.cat(lyr, dim=2)) @@ -562,18 +612,21 @@ class UnifiedVoice(nn.Module): for l, layer in enumerate(attentions): dec_context = layer[:, :, num_context:, :] # Mask out everything that isn't text (including the start token, which gets a LOT of attention) - dec_context[:,:,:,:text_padding+1] = 0 - dec_context[:,:,:,num_context:] = 0 + dec_context[:, :, :, :text_padding+1] = 0 + dec_context[:, :, :, num_context:] = 0 for h in range(dec_context.shape[1]): - dec_context_indices = torch.argmax(dec_context[0,h], dim=-1) + dec_context_indices = torch.argmax(dec_context[0, h], dim=-1) print(f'layer_{l};head_{h}: ' + str(dec_context_indices)) for t, att_tok in enumerate(attentions): - combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device) + combined_attention_weights = torch.zeros( + (codes.shape[0], num_text), device=codes.device) for lyr in att_tok: - token_to_text_attentions = lyr[:, :, -1, text_padding:(text_padding + num_text)].sum(dim=1) + token_to_text_attentions = lyr[:, :, -1, + text_padding:(text_padding + num_text)].sum(dim=1) combined_attention_weights = combined_attention_weights + token_to_text_attentions break - most_attended_text_token = combined_attention_weights.argmax(dim=-1) + most_attended_text_token = combined_attention_weights.argmax( + dim=-1) results[:, t] = most_attended_text_token eos_token_mask = (codes != self.stop_mel_token) return results * eos_token_mask @@ -585,10 +638,11 @@ def register_unified_voice2(opt_net, opt): if __name__ == '__main__': - gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True) + gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, + max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True) l = gpt(torch.randn(2, 3, 80, 800), - torch.randint(high=256, size=(2,120)), + torch.randint(high=256, size=(2, 120)), torch.tensor([32, 120]), - torch.randint(high=8192, size=(2,250)), - torch.tensor([250*256,195*256])) - #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80])) + torch.randint(high=8192, size=(2, 250)), + torch.tensor([250*256, 195*256])) + # gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80])) diff --git a/dlas/models/audio/tts/unified_voice3.py b/dlas/models/audio/tts/unified_voice3.py index 49e2258c..8aa80c47 100644 --- a/dlas/models/audio/tts/unified_voice3.py +++ b/dlas/models/audio/tts/unified_voice3.py @@ -1,24 +1,26 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml - from transformers import GPT2Config, GPT2PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.models.gpt2.modeling_gpt2 import GPT2Attention -from transformers.utils.model_parallel_utils import get_device_map, assert_device_map +from transformers.utils.model_parallel_utils import (assert_device_map, + 
get_device_map) -from models.arch_util import AttentionBlock -from models.audio.tts.transformer_builders import build_hf_gpt_transformer -from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb -from trainer.networks import register_model -from utils.util import opt_get +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock +from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer +from dlas.models.lucidrains.x_transformers import (RotaryEmbedding, + apply_rotary_pos_emb) +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get class ResBlock(nn.Module): """ Basic residual convolutional block that uses GroupNorm. """ + def __init__(self, chan): super().__init__() self.net = nn.Sequential( @@ -48,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): def parallelize(self, device_map=None): self.device_map = ( - get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + get_device_map(len(self.transformer.h), + range(torch.cuda.device_count())) if device_map is None else device_map ) @@ -120,7 +123,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): ): assert self.cached_prior_emb is not None assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. + # Training not supported by this inference model. + assert labels is None return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Create embedding @@ -128,15 +132,18 @@ class GPT2InferenceModel(GPT2PreTrainedModel): if input_ids.shape[1] != 1: posterior_inputs = input_ids[:, prior_len:] posterior_emb = self.embeddings(posterior_inputs) - posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb) + posterior_emb = posterior_emb + \ + self.posterior_pos_embedding(posterior_emb) if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]: - prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0) + prior_emb = self.cached_prior_emb.repeat_interleave( + posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0) else: prior_emb = self.cached_prior_emb emb = torch.cat([prior_emb, posterior_emb], dim=1) else: emb = self.embeddings(input_ids) - emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device) + emb = emb + self.posterior_pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - prior_len, attention_mask.device) transformer_outputs = self.transformer( inputs_embeds=emb, @@ -181,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
""" return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past) for layer_past in past ) @@ -198,7 +206,8 @@ class ConditioningEncoder(nn.Module): attn = [] self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) + attn.append(AttentionBlock(embedding_dim, + num_attn_heads, do_checkpoint=do_checkpointing)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim self.do_checkpointing = do_checkpointing @@ -218,23 +227,27 @@ class MelEncoder(nn.Module): super().__init__() self.channels = channels self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1), - nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1), + nn.Sequential( + *[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels//4, channels//2, + kernel_size=3, stride=2, padding=1), nn.GroupNorm(channels//16, channels//2), nn.ReLU(), - nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1), + nn.Sequential( + *[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels//2, channels, + kernel_size=3, stride=2, padding=1), nn.GroupNorm(channels//8, channels), nn.ReLU(), - nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), + nn.Sequential( + *[ResBlock(channels) for _ in range(resblocks_per_reduction)]), ) self.reduction = 4 - def forward(self, x): for e in self.encoder: x = e(x) - return x.permute(0,2,1) + return x.permute(0, 2, 1) class UnifiedVoice(nn.Module): @@ -260,7 +273,8 @@ class UnifiedVoice(nn.Module): super().__init__() self.number_text_tokens = number_text_tokens - self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token + self.start_text_token = number_text_tokens * \ + types if start_text_token is None else start_text_token self.stop_text_token = 0 self.number_mel_codes = number_mel_codes self.start_mel_token = start_mel_token @@ -268,17 +282,21 @@ class UnifiedVoice(nn.Module): self.layers = layers self.heads = heads self.max_conditioning_inputs = max_conditioning_inputs - self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs + self.max_mel_tokens = -1 if max_mel_tokens == - \ + 1 else max_mel_tokens+2+self.max_conditioning_inputs self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2 self.model_dim = model_dim self.mel_length_compression = mel_length_compression - self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.conditioning_encoder = ConditioningEncoder( + 80, model_dim, num_attn_heads=heads) # nn.Embedding - self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim) + self.text_embedding = ml.Embedding( + self.number_text_tokens*types+1, model_dim) # nn.Embedding self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ - build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) + 
build_hf_gpt_transformer( + layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) self.final_norm = nn.LayerNorm(model_dim) self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1) @@ -306,8 +324,8 @@ class UnifiedVoice(nn.Module): } def build_aligned_inputs_and_targets(self, input, start_token, stop_token): - inp = F.pad(input, (1,0), value=start_token) - tar = F.pad(input, (0,1), value=stop_token) + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) return inp, tar def set_mel_padding(self, mel_input_tokens, wav_lengths): @@ -319,45 +337,48 @@ class UnifiedVoice(nn.Module): # Set padding areas within MEL (currently it is coded with the MEL code for ). mel_lengths = wav_lengths // self.mel_length_compression for b in range(len(mel_lengths)): - actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + actual_end = mel_lengths[b] + 1 if actual_end < mel_input_tokens.shape[-1]: mel_input_tokens[b, actual_end:] = self.stop_mel_token return mel_input_tokens def get_logits(self, speech_conditioning_inputs, text_inputs, text_head, mel_inputs, mel_head, aligned_head, return_latent=False): if mel_inputs is not None: - emb = torch.cat([speech_conditioning_inputs, text_inputs, mel_inputs], dim=1) + emb = torch.cat([speech_conditioning_inputs, + text_inputs, mel_inputs], dim=1) else: emb = torch.cat([speech_conditioning_inputs, text_inputs], dim=1) gpt_out = self.gpt(inputs_embeds=emb, return_dict=True) - enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input + # The first logit is tied to the speech_conditioning_input + enc = gpt_out.last_hidden_state[:, 1:] enc = self.final_norm(enc) if return_latent: return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1] + text_inputs.shape[1]], enc[:, -mel_inputs.shape[1]:] text_logits = enc[:, :text_inputs.shape[1]] - text_logits = text_head(text_logits).permute(0,2,1) + text_logits = text_head(text_logits).permute(0, 2, 1) mel_logits = enc[:, -mel_inputs.shape[1]:] - aligned_logits = aligned_head(mel_logits).permute(0,2,1) - mel_logits = mel_head(mel_logits).permute(0,2,1) + aligned_logits = aligned_head(mel_logits).permute(0, 2, 1) + mel_logits = mel_head(mel_logits).permute(0, 2, 1) return text_logits, mel_logits, aligned_logits - def get_conditioning_latent(self, speech_conditioning_input): - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder( + speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) conds = conds.mean(dim=1).unsqueeze(1) return conds - def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, aligned_codes, types=None, return_latent=False): """ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode @@ -376,19 
+397,22 @@ class UnifiedVoice(nn.Module): if types is not None: text_inputs = text_inputs * (1+types).unsqueeze(-1) - conds = self.get_conditioning_latent(speech_conditioning_input) ac_expansion_factor = mel_codes.shape[-1] / aligned_codes.shape[-1] aligned_codes = aligned_codes.repeat(1, ac_expansion_factor) - _, aligned_targets = self.build_aligned_inputs_and_targets(aligned_codes, 0, 0) + _, aligned_targets = self.build_aligned_inputs_and_targets( + aligned_codes, 0, 0) - text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding( + text_inputs) + self.text_pos_embedding(text_inputs) mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token) mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) @@ -396,7 +420,8 @@ class UnifiedVoice(nn.Module): text_logits, mel_logits, aligned_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, self.aligned_head, return_latent=return_latent) if return_latent: - return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. 
+ return mel_logits[:, :-2] loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) @@ -418,25 +443,31 @@ class UnifiedVoice(nn.Module): n_head=self.heads, gradient_checkpointing=False, use_cache=True) - self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.inference_model = GPT2InferenceModel( + gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) self.gpt.wte = self.mel_embedding text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding( + text_inputs) + self.text_pos_embedding(text_inputs) - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder( + speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) conds = conds.mean(dim=1).unsqueeze(1) emb = torch.cat([conds, text_emb], dim=1) self.inference_model.store_prior_emb(emb) - fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.start_mel_token + fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), + fill_value=1, dtype=torch.long, device=text_inputs.device) + fake_inputs[:, -1] = self.start_mel_token gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs) @@ -449,11 +480,12 @@ def register_unified_voice3(opt_net, opt): if __name__ == '__main__': - gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2) - mel = torch.randint(high=8192, size=(2,250)) - ac = torch.randint(high=256, size=(2,250*1024//443)) + gpt = UnifiedVoice(model_dim=256, heads=4, + max_conditioning_inputs=4, types=2) + mel = torch.randint(high=8192, size=(2, 250)) + ac = torch.randint(high=256, size=(2, 250*1024//443)) l = gpt(torch.randn(2, 3, 80, 800), - torch.randint(high=256, size=(2,120)), + torch.randint(high=256, size=(2, 120)), torch.tensor([32, 120]), - mel, torch.tensor([250*256,195*256]), ac, + mel, torch.tensor([250*256, 195*256]), ac, types=torch.tensor([0, 1])) diff --git a/dlas/models/audio/tts/unified_voice4.py b/dlas/models/audio/tts/unified_voice4.py index 9d8a8568..b7f1724c 100644 --- a/dlas/models/audio/tts/unified_voice4.py +++ b/dlas/models/audio/tts/unified_voice4.py @@ -4,20 +4,23 @@ import torch.nn.functional as F from transformers import GPT2Config, GPT2PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.models.gpt2.modeling_gpt2 import GPT2Attention -from 
transformers.utils.model_parallel_utils import get_device_map, assert_device_map +from transformers.utils.model_parallel_utils import (assert_device_map, + get_device_map) -from models.arch_util import AttentionBlock -from models.audio.tts.transformer_builders import build_hf_gpt_transformer -from models.lucidrains.x_transformers import RotaryEmbedding, apply_rotary_pos_emb -from trainer.networks import register_model -from utils.util import opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock +from dlas.models.audio.tts.transformer_builders import build_hf_gpt_transformer +from dlas.models.lucidrains.x_transformers import (RotaryEmbedding, + apply_rotary_pos_emb) +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get class ResBlock(nn.Module): """ Basic residual convolutional block that uses GroupNorm. """ + def __init__(self, chan): super().__init__() self.net = nn.Sequential( @@ -47,7 +50,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): def parallelize(self, device_map=None): self.device_map = ( - get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + get_device_map(len(self.transformer.h), + range(torch.cuda.device_count())) if device_map is None else device_map ) @@ -119,7 +123,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): ): assert self.cached_prior_emb is not None assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. + # Training not supported by this inference model. + assert labels is None return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Create embedding @@ -127,15 +132,18 @@ class GPT2InferenceModel(GPT2PreTrainedModel): if input_ids.shape[1] != 1: posterior_inputs = input_ids[:, prior_len:] posterior_emb = self.embeddings(posterior_inputs) - posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb) + posterior_emb = posterior_emb + \ + self.posterior_pos_embedding(posterior_emb) if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]: - prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0) + prior_emb = self.cached_prior_emb.repeat_interleave( + posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0) else: prior_emb = self.cached_prior_emb emb = torch.cat([prior_emb, posterior_emb], dim=1) else: emb = self.embeddings(input_ids) - emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device) + emb = emb + self.posterior_pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - prior_len, attention_mask.device) transformer_outputs = self.transformer( inputs_embeds=emb, @@ -180,7 +188,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel): called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
""" return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past) for layer_past in past ) @@ -197,7 +206,8 @@ class ConditioningEncoder(nn.Module): attn = [] self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) + attn.append(AttentionBlock(embedding_dim, + num_attn_heads, do_checkpoint=do_checkpointing)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim self.do_checkpointing = do_checkpointing @@ -217,23 +227,27 @@ class MelEncoder(nn.Module): super().__init__() self.channels = channels self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1), - nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1), + nn.Sequential( + *[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels//4, channels//2, + kernel_size=3, stride=2, padding=1), nn.GroupNorm(channels//16, channels//2), nn.ReLU(), - nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1), + nn.Sequential( + *[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels//2, channels, + kernel_size=3, stride=2, padding=1), nn.GroupNorm(channels//8, channels), nn.ReLU(), - nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), + nn.Sequential( + *[ResBlock(channels) for _ in range(resblocks_per_reduction)]), ) self.reduction = 4 - def forward(self, x): for e in self.encoder: x = e(x) - return x.permute(0,2,1) + return x.permute(0, 2, 1) class UnifiedVoice(nn.Module): @@ -243,7 +257,8 @@ class UnifiedVoice(nn.Module): super().__init__() self.number_text_tokens = number_text_tokens - self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token + self.start_text_token = number_text_tokens * \ + types if start_text_token is None else start_text_token self.stop_text_token = 0 self.number_mel_codes = number_mel_codes self.start_mel_token = start_mel_token @@ -251,17 +266,21 @@ class UnifiedVoice(nn.Module): self.layers = layers self.heads = heads self.max_conditioning_inputs = max_conditioning_inputs - self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens+2+self.max_conditioning_inputs + self.max_mel_tokens = -1 if max_mel_tokens == - \ + 1 else max_mel_tokens+2+self.max_conditioning_inputs self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens+2 self.model_dim = model_dim self.mel_length_compression = mel_length_compression - self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.conditioning_encoder = ConditioningEncoder( + 80, model_dim, num_attn_heads=heads) # nn.Embedding - self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim) + self.text_embedding = ml.Embedding( + self.number_text_tokens*types+1, model_dim) # nn.Embedding self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ - build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) + 
build_hf_gpt_transformer( + layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) self.final_norm = nn.LayerNorm(model_dim) self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1) @@ -289,8 +308,8 @@ class UnifiedVoice(nn.Module): } def build_aligned_inputs_and_targets(self, input, start_token, stop_token): - inp = F.pad(input, (1,0), value=start_token) - tar = F.pad(input, (0,1), value=stop_token) + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) return inp, tar def set_mel_padding(self, mel_input_tokens, wav_lengths): @@ -302,17 +321,20 @@ class UnifiedVoice(nn.Module): # Set padding areas within MEL (currently it is coded with the MEL code for ). mel_lengths = wav_lengths // self.mel_length_compression for b in range(len(mel_lengths)): - actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + actual_end = mel_lengths[b] + 1 if actual_end < mel_input_tokens.shape[-1]: mel_input_tokens[b, actual_end:] = self.stop_mel_token return mel_input_tokens def get_logits(self, speech_conditioning_inputs, first_inputs, second_inputs, return_latent=False): - emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + emb = torch.cat([speech_conditioning_inputs, + first_inputs, second_inputs], dim=1) gpt_out = self.gpt(inputs_embeds=emb, return_dict=True) - enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input + # The first logit is tied to the speech_conditioning_input + enc = gpt_out.last_hidden_state[:, 1:] enc = self.final_norm(enc) if return_latent: @@ -320,29 +342,29 @@ class UnifiedVoice(nn.Module): text_logits = enc[:, :first_inputs.shape[1]] text_logits = self.text_head(text_logits) - text_logits = text_logits.permute(0,2,1) + text_logits = text_logits.permute(0, 2, 1) mel_logits = enc[:, -second_inputs.shape[1]:] mel_logits = self.mel_head(mel_logits) - mel_logits = mel_logits.permute(0,2,1) + mel_logits = mel_logits.permute(0, 2, 1) alignment_logits = enc[:, -second_inputs.shape[1]:] alignment_logits = self.alignment_head(alignment_logits) - alignment_logits = alignment_logits.permute(0,2,1) + alignment_logits = alignment_logits.permute(0, 2, 1) return text_logits, mel_logits, alignment_logits - def get_conditioning_latent(self, speech_conditioning_input): - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder( + speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) conds = conds.mean(dim=1).unsqueeze(1) return conds - def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, ctc_codes, wav_lengths, types=None, return_latent=False): """ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode @@ -368,24 +390,30 @@ class UnifiedVoice(nn.Module): ctc_codes[b][j] = last_code else: 
last_code = ctc_codes[b][j] - alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(), size=(mel_codes.shape[-1],), mode='nearest').long().squeeze() + alignment_targets = F.interpolate(ctc_codes.unsqueeze(1).float(), size=( + mel_codes.shape[-1],), mode='nearest').long().squeeze() mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) - mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) - alignment_targets = F.pad(alignment_targets, (0,2), value=0) + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token) + alignment_targets = F.pad(alignment_targets, (0, 2), value=0) conds = self.get_conditioning_latent(speech_conditioning_input) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding( + text_inputs) + self.text_pos_embedding(text_inputs) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token) mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) - text_logits, mel_logits, alignment_logits = self.get_logits(conds, text_emb, mel_emb, return_latent=return_latent) + text_logits, mel_logits, alignment_logits = self.get_logits( + conds, text_emb, mel_emb, return_latent=return_latent) if return_latent: - return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. 
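For reference, the pad-and-shift pattern used by build_aligned_inputs_and_targets() and the (batch, classes, time) layout that F.cross_entropy expects can be reproduced in a few standalone lines. This is an illustrative sketch assuming only torch; the token values and vocabulary size are made up and are not taken from the diff.

import torch
import torch.nn.functional as F

codes = torch.tensor([[5, 6, 7]])                    # (B=1, T=3) token ids
start_token, stop_token = 1, 0
inp = F.pad(codes, (1, 0), value=start_token)        # tensor([[1, 5, 6, 7]]) -> model input
tar = F.pad(codes, (0, 1), value=stop_token)         # tensor([[5, 6, 7, 0]]) -> training target

# Sequence-level cross entropy wants logits shaped (B, num_classes, T), which
# is why the heads' (B, T, num_classes) outputs are permuted before the loss.
num_classes = 10
logits = torch.randn(1, tar.shape[1], num_classes)   # (B, T, C)
loss = F.cross_entropy(logits.permute(0, 2, 1), tar) # (B, C, T) vs (B, T)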
+ return mel_logits[:, :-2] loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) @@ -407,25 +435,31 @@ class UnifiedVoice(nn.Module): n_head=self.heads, gradient_checkpointing=False, use_cache=True) - self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.inference_model = GPT2InferenceModel( + gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) self.gpt.wte = self.mel_embedding text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding( + text_inputs) + self.text_pos_embedding(text_inputs) - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder( + speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) conds = conds.mean(dim=1).unsqueeze(1) emb = torch.cat([conds, text_emb], dim=1) self.inference_model.store_prior_emb(emb) - fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.start_mel_token + fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), + fill_value=1, dtype=torch.long, device=text_inputs.device) + fake_inputs[:, -1] = self.start_mel_token gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs) @@ -438,10 +472,11 @@ def register_unified_voice4(opt_net, opt): if __name__ == '__main__': - gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2) + gpt = UnifiedVoice(model_dim=256, heads=4, + max_conditioning_inputs=4, types=2) l = gpt(torch.randn(2, 3, 80, 800), - torch.randint(high=256, size=(2,120)), + torch.randint(high=256, size=(2, 120)), torch.tensor([32, 120]), - torch.randint(high=8192, size=(2,250)), - torch.tensor([250*256,195*256]), + torch.randint(high=8192, size=(2, 250)), + torch.tensor([250*256, 195*256]), types=torch.tensor([0, 1])) diff --git a/dlas/models/audio/tts/voice_voice_clip.py b/dlas/models/audio/tts/voice_voice_clip.py index 842f8b17..12efd109 100644 --- a/dlas/models/audio/tts/voice_voice_clip.py +++ b/dlas/models/audio/tts/voice_voice_clip.py @@ -1,14 +1,15 @@ import random + import torch import torch.nn as nn import torch.nn.functional as F from torch import einsum -from models.audio.tts.mini_encoder import AudioMiniEncoder -from trainer.injectors.spec_augment import spec_augment -from trainer.networks import register_model -from utils.util import opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from 
dlas.models.audio.tts.mini_encoder import AudioMiniEncoder +from dlas.trainer.injectors.spec_augment import spec_augment +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def exists(val): @@ -17,7 +18,7 @@ def exists(val): def masked_mean(t, mask, dim=1): t = t.masked_fill(~mask[:, :, None], 0.) - return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] + return t.sum(dim=1) / mask.sum(dim=1)[..., None] class VoiceCLIP(nn.Module): @@ -36,7 +37,8 @@ class VoiceCLIP(nn.Module): super().__init__() self.encoder = AudioMiniEncoder(80, encoder_output) if pretrained_encoder_dict_path is not None: - self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path)) + self.encoder.load_state_dict( + torch.load(pretrained_encoder_dict_path)) self.to_latent = ml.Linear(encoder_output, dim_latent, bias=False) self.temperature = nn.Parameter(torch.tensor(1.)) self.mel_compression_ratio = mel_compression_ratio @@ -47,7 +49,8 @@ class VoiceCLIP(nn.Module): speech_lengths, return_loss=True ): - half_length = min(speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2 + half_length = min( + speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2 # Extract two speech MELs from the same clip, apply some random noise to them and also apply specaugment to them. first_half = speech_mels[:, :, :half_length] @@ -81,7 +84,8 @@ class VoiceCLIP(nn.Module): second_emb = self.encoder(second_half) second_latents = self.to_latent(second_emb) - first_latents, second_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (first_latents, second_latents)) + first_latents, second_latents = map(lambda t: F.normalize( + t, p=2, dim=-1), (first_latents, second_latents)) temp = self.temperature.exp() @@ -90,8 +94,10 @@ class VoiceCLIP(nn.Module): return sim sim = einsum('i d, j d -> i j', first_latents, second_latents) * temp - labels = torch.arange(first_latents.shape[0], device=first_latents.device) - loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + labels = torch.arange( + first_latents.shape[0], device=first_latents.device) + loss = (F.cross_entropy(sim, labels) + + F.cross_entropy(sim.t(), labels)) / 2 return loss def inference(self, speech_mels): @@ -111,6 +117,6 @@ def register_voice_to_voice_clip(opt_net, opt): if __name__ == '__main__': clip = VoiceCLIP() for k in range(1000): - clip(torch.randn((2,80,156)), - torch.randint(130*1024,156*1024,(2,)), - return_loss=True) \ No newline at end of file + clip(torch.randn((2, 80, 156)), + torch.randint(130*1024, 156*1024, (2,)), + return_loss=True) diff --git a/dlas/models/audio/tts/w2v_matcher.py b/dlas/models/audio/tts/w2v_matcher.py index 636a1ca0..dde1d6ac 100644 --- a/dlas/models/audio/tts/w2v_matcher.py +++ b/dlas/models/audio/tts/w2v_matcher.py @@ -3,11 +3,11 @@ import functools import torch import torch.nn as nn import torch.nn.functional as F -from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper +from x_transformers import ContinuousTransformerWrapper, Decoder, Encoder -from models.audio.tts.mini_encoder import AudioMiniEncoder -from trainer.networks import register_model -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.audio.tts.mini_encoder import AudioMiniEncoder +from dlas.trainer.networks import register_model class CheckpointedLayer(nn.Module): @@ -15,13 +15,15 @@ class CheckpointedLayer(nn.Module): Wraps a module. 
When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses checkpoint for all other args. """ + def __init__(self, wrap): super().__init__() self.wrap = wrap def forward(self, x, *args, **kwargs): for k, v in kwargs.items(): - assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing. + # This would screw up checkpointing. + assert not (isinstance(v, torch.Tensor) and v.requires_grad) partial = functools.partial(self.wrap, **kwargs) return torch.utils.checkpoint.checkpoint(partial, x, *args) @@ -31,20 +33,22 @@ class CheckpointedXTransformer(nn.Module): Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid to channels-last that XTransformer expects. """ + def __init__(self, **xtransformer_kwargs): super().__init__() self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) for i in range(len(self.transformer.attn_layers.layers)): n, b, r = self.transformer.attn_layers.layers[i] - self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) + self.transformer.attn_layers.layers[i] = nn.ModuleList( + [n, CheckpointedLayer(b), r]) def forward(self, x, **kwargs): return self.transformer(x, **kwargs) class Wav2VecMatcher(nn.Module): - W2V_COMPRESSION=320 + W2V_COMPRESSION = 320 def __init__(self, model_dim, @@ -56,41 +60,42 @@ class Wav2VecMatcher(nn.Module): WAV2VEC_CHANNELS = 1024 self.conditioning_encoder = AudioMiniEncoder(1, model_dim, base_channels=32, depth=6, resnet_blocks=1, - attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) + attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) # nn.Embedding self.text_embedding = ml.Embedding(num_text_tokens, model_dim) self.encoder = CheckpointedXTransformer( - max_seq_len=-1, - use_pos_emb=False, - attn_layers=Encoder( - dim=model_dim, - depth=encoder_depth, - heads=model_dim//64, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_emb_dim=True, - ) + max_seq_len=-1, + use_pos_emb=False, + attn_layers=Encoder( + dim=model_dim, + depth=encoder_depth, + heads=model_dim//64, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_emb_dim=True, ) - self.decoder_start_embedding = nn.Parameter(torch.randn(1,1,model_dim)) - self.decoder_stop_embedding = nn.Parameter(torch.randn(1,model_dim)) + ) + self.decoder_start_embedding = nn.Parameter( + torch.randn(1, 1, model_dim)) + self.decoder_stop_embedding = nn.Parameter(torch.randn(1, model_dim)) self.w2v_query_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim) self.w2v_value_encoder = ml.Linear(WAV2VEC_CHANNELS, model_dim) self.decoder = CheckpointedXTransformer( - max_seq_len=-1, # Should be unused - use_pos_emb=False, - attn_layers=Decoder( - dim=model_dim, - depth=decoder_depth, - heads=model_dim//64, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - cross_attend=True, - ) + max_seq_len=-1, # Should be unused + use_pos_emb=False, + attn_layers=Decoder( + dim=model_dim, + depth=decoder_depth, + heads=model_dim//64, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + cross_attend=True, + ) ) def get_grad_norm_parameter_groups(self): @@ -111,28 +116,32 @@ class Wav2VecMatcher(nn.Module): enc_inputs = torch.cat([cond_emb.unsqueeze(1), text_emb], dim=1) dec_context = self.encoder(enc_inputs) 
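The CheckpointedLayer / CheckpointedXTransformer wrappers defined above exist to trade compute for activation memory via torch.utils.checkpoint. A minimal standalone sketch of that mechanism, assuming a recent PyTorch and using a toy nn.Sequential in place of a transformer layer:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
x = torch.randn(4, 512, requires_grad=True)

# Activations inside `block` are dropped after the forward pass and recomputed
# during backward, lowering peak memory for deep stacks of such layers.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()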
w2v_values = self.w2v_value_encoder(w2v_logits) - dec_inputs = torch.cat([self.decoder_start_embedding.repeat(w2v_values.shape[0],1,1), w2v_values], dim=1) + dec_inputs = torch.cat([self.decoder_start_embedding.repeat( + w2v_values.shape[0], 1, 1), w2v_values], dim=1) dec_out = self.decoder(dec_inputs, context=dec_context)[:, :-1] w2v_queries = self.w2v_query_encoder(w2v_logits) # Compute losses, A CLIP-like dot product matcher and a mechanism to force pad prediction. - b,l,c = dec_out.shape + b, l, c = dec_out.shape keys_uncompressed = dec_out.reshape(b*l, c) queries_uncompressed = w2v_queries.reshape(b*l, c) - dot = torch.einsum("i c, j c -> i j", keys_uncompressed, queries_uncompressed) + dot = torch.einsum("i c, j c -> i j", + keys_uncompressed, queries_uncompressed) labels = torch.arange(0, b*l, 1, device=dot.device) ce_loss1 = F.cross_entropy(dot, labels, reduction="none") ce_loss2 = F.cross_entropy(dot.t(), labels, reduction="none") - mse_pad_loss = F.mse_loss(keys_uncompressed, self.decoder_stop_embedding.repeat(b*l,1), reduction="none").sum(dim=-1) + mse_pad_loss = F.mse_loss(keys_uncompressed, self.decoder_stop_embedding.repeat( + b*l, 1), reduction="none").sum(dim=-1) # Create a mask based on w2v_lengths that will be used to ensure the encodings of padding tokens are not considered in the cross entropy loss - loss_mask = torch.ones((b,l), device=ce_loss1.device) + loss_mask = torch.ones((b, l), device=ce_loss1.device) w2v_lengths = clip_lengths // self.W2V_COMPRESSION for i in range(b): loss_mask[i, w2v_lengths[i]:] = 0 loss_mask_collapsed = loss_mask.reshape(b*l) - ce_loss = (ce_loss1 * loss_mask_collapsed + ce_loss2 * loss_mask_collapsed).mean() + ce_loss = (ce_loss1 * loss_mask_collapsed + + ce_loss2 * loss_mask_collapsed).mean() mse_loss = (mse_pad_loss * (loss_mask_collapsed == 0)).mean() return ce_loss, mse_loss @@ -151,15 +160,18 @@ class Wav2VecMatcher(nn.Module): dec_out = self.decoder(dec_inputs, context=dec_context) # Check if that was EOS. - l2 = F.mse_loss(dec_out[:,-1], self.decoder_stop_embedding) + l2 = F.mse_loss(dec_out[:, -1], self.decoder_stop_embedding) if l2 < .1: # TODO: fix threshold. break # Find a matching w2v logit from the given iterable. 
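The "CLIP-like dot product matcher" loss computed in forward() above (the same pattern appears in VoiceCLIP and CLVP elsewhere in this diff) is a symmetric cross-entropy over an all-pairs similarity matrix, with the true matches on the diagonal. A self-contained sketch, assuming only torch and unit-normalised embeddings of an arbitrary size:

import torch
import torch.nn.functional as F

keys = F.normalize(torch.randn(8, 256), dim=-1)      # e.g. decoder outputs
queries = F.normalize(torch.randn(8, 256), dim=-1)   # e.g. w2v query projections
temperature = 1.0

sim = torch.einsum('i d, j d -> i j', keys, queries) * temperature
labels = torch.arange(keys.shape[0])                  # matching pairs sit on the diagonal
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2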
- matching_logit_index = self.find_matching_w2v_logit(dec_out[:,-1], w2v_logit_iterable) + matching_logit_index = self.find_matching_w2v_logit( + dec_out[:, -1], w2v_logit_iterable) matching_logit = w2v_logit_iterable[matching_logit_index] - dec_inputs = torch.cat([dec_inputs, self.w2v_value_encoder(matching_logit).unsqueeze(1)], dim=1) - produced_audio = torch.cat([produced_audio, audio_clip_iterable[matching_logit_index]], dim=-1) + dec_inputs = torch.cat( + [dec_inputs, self.w2v_value_encoder(matching_logit).unsqueeze(1)], dim=1) + produced_audio = torch.cat( + [produced_audio, audio_clip_iterable[matching_logit_index]], dim=-1) return produced_audio @@ -170,9 +182,9 @@ def register_w2v_matcher(opt_net, opt): if __name__ == '__main__': model = Wav2VecMatcher(512, 8, 8) - toks = torch.randint(0, 100, (4,100)) - tok_lens = torch.tensor([50,60,70,80]) - cond = torch.randn(4,1,44000) - logits = torch.randn(4,120,1024) - logit_lens = torch.tensor([60,70,80,90]) - model(toks, cond, logits, tok_lens, logit_lens) \ No newline at end of file + toks = torch.randint(0, 100, (4, 100)) + tok_lens = torch.tensor([50, 60, 70, 80]) + cond = torch.randn(4, 1, 44000) + logits = torch.randn(4, 120, 1024) + logit_lens = torch.tensor([60, 70, 80, 90]) + model(toks, cond, logits, tok_lens, logit_lens) diff --git a/dlas/models/audio/vocoders/univnet/generator.py b/dlas/models/audio/vocoders/univnet/generator.py index c469cc75..5112e121 100644 --- a/dlas/models/audio/vocoders/univnet/generator.py +++ b/dlas/models/audio/vocoders/univnet/generator.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from omegaconf import OmegaConf -from models.audio.vocoders.univnet.lvcnet import LVCBlock -from trainer.networks import register_model +from dlas.models.audio.vocoders.univnet.lvcnet import LVCBlock +from dlas.trainer.networks import register_model MAX_WAV_VALUE = 32768.0 @@ -11,7 +11,7 @@ MAX_WAV_VALUE = 32768.0 class UnivNetGenerator(nn.Module): """UnivNet Generator""" - def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3, + def __init__(self, noise_dim=64, channel_size=32, dilations=[1, 3, 9, 27], strides=[8, 8, 4], lReLU_slope=.2, kpnet_conv_size=3, # Below are MEL configurations options that this generator requires. 
hop_length=256, n_mel_channels=100): super(UnivNetGenerator, self).__init__() @@ -38,11 +38,13 @@ class UnivNetGenerator(nn.Module): ) self.conv_pre = \ - nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect')) + nn.utils.weight_norm( + nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect')) self.conv_post = nn.Sequential( nn.LeakyReLU(lReLU_slope), - nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')), + nn.utils.weight_norm( + nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')), nn.Tanh(), ) @@ -84,11 +86,13 @@ class UnivNetGenerator(nn.Module): def inference(self, c, z=None): # pad input mel with zeros to cut artifact # see https://github.com/seungwonpark/melgan/issues/8 - zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device) + zero = torch.full( + (c.shape[0], self.mel_channel, 10), -11.5129).to(c.device) mel = torch.cat((c, zero), dim=2) if z is None: - z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device) + z = torch.randn(c.shape[0], self.noise_dim, + mel.size(2)).to(mel.device) audio = self.forward(mel, z) audio = audio[:, :, :-(self.hop_length * 10)] @@ -112,5 +116,6 @@ if __name__ == '__main__': print(y.shape) assert y.shape == torch.Size([3, 1, 2560]) - pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + pytorch_total_params = sum(p.numel() + for p in model.parameters() if p.requires_grad) print(pytorch_total_params) diff --git a/dlas/models/audio/vocoders/univnet/lvcnet.py b/dlas/models/audio/vocoders/univnet/lvcnet.py index af0ff8af..6ec03d52 100644 --- a/dlas/models/audio/vocoders/univnet/lvcnet.py +++ b/dlas/models/audio/vocoders/univnet/lvcnet.py @@ -35,12 +35,15 @@ class KernelPredictor(torch.nn.Module): self.conv_kernel_size = conv_kernel_size self.conv_layers = conv_layers - kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_kernel_channels = conv_in_channels * \ + conv_out_channels * conv_kernel_size * conv_layers # l_w kpnet_bias_channels = conv_out_channels * conv_layers # l_b self.input_conv = nn.Sequential( - nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.weight_norm( + nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), + getattr(nn, kpnet_nonlinear_activation)( + **kpnet_nonlinear_activation_params), ) self.residual_convs = nn.ModuleList() @@ -52,11 +55,13 @@ class KernelPredictor(torch.nn.Module): nn.utils.weight_norm( nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True)), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + getattr(nn, kpnet_nonlinear_activation)( + **kpnet_nonlinear_activation_params), nn.utils.weight_norm( nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True)), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + getattr(nn, kpnet_nonlinear_activation)( + **kpnet_nonlinear_activation_params), ) ) self.kernel_conv = nn.utils.weight_norm( @@ -170,7 +175,8 @@ class LVCBlock(torch.nn.Module): for i, conv in enumerate(self.conv_blocks): output = conv(x) # (B, c_g, stride * L') - k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + # (B, 2 * c_g, c_g, kernel_size, cond_length) + k 
= kernels[:, i, :, :, :, :] b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) output = self.location_variable_convolution(output, k, b, @@ -194,23 +200,29 @@ class LVCBlock(torch.nn.Module): ''' batch, _, in_length = x.shape batch, _, out_channels, kernel_size, kernel_length = kernel.shape - assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + assert in_length == ( + kernel_length * hop_size), "length of (x, kernel) is not matched" padding = dilation * int((kernel_size - 1) / 2) - x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding) - x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + # (batch, in_channels, in_length + 2*padding) + x = F.pad(x, (padding, padding), 'constant', 0) + # (batch, in_channels, kernel_length, hop_size + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) if hop_size < dilation: x = F.pad(x, (0, dilation), 'constant', 0) x = x.unfold(3, dilation, dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) x = x[:, :, :, :, :hop_size] - x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) - x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.transpose(3, 4) + # (batch, in_channels, kernel_length, dilation, _, kernel_size) + x = x.unfold(4, kernel_size, 1) o = torch.einsum('bildsk,biokl->bolsd', x, kernel) o = o.to(memory_format=torch.channels_last_3d) - bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to( + memory_format=torch.channels_last_3d) o = o + bias o = o.contiguous().view(batch, out_channels, -1) @@ -220,4 +232,4 @@ class LVCBlock(torch.nn.Module): self.kernel_predictor.remove_weight_norm() nn.utils.remove_weight_norm(self.convt_pre[1]) for block in self.conv_blocks: - nn.utils.remove_weight_norm(block[1]) \ No newline at end of file + nn.utils.remove_weight_norm(block[1]) diff --git a/dlas/models/audio/vocoders/waveglow/denoiser.py b/dlas/models/audio/vocoders/waveglow/denoiser.py index a9e05967..ed78b3d8 100644 --- a/dlas/models/audio/vocoders/waveglow/denoiser.py +++ b/dlas/models/audio/vocoders/waveglow/denoiser.py @@ -1,9 +1,10 @@ import sys -from models.audio.tts.tacotron2.stft import STFT +import torch + +from dlas.models.audio.tts.tacotron2.stft import STFT sys.path.append('tacotron2') -import torch class Denoiser(torch.nn.Module): diff --git a/dlas/models/audio/vocoders/waveglow/waveglow.py b/dlas/models/audio/vocoders/waveglow/waveglow.py index f45ddbe3..af3b6911 100644 --- a/dlas/models/audio/vocoders/waveglow/waveglow.py +++ b/dlas/models/audio/vocoders/waveglow/waveglow.py @@ -25,11 +25,12 @@ # # ***************************************************************************** import copy -import torch -from torch.autograd import Variable -import torch.nn.functional as F -from trainer.networks import register_model +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +from dlas.trainer.networks import register_model @torch.jit.script @@ -57,7 +58,8 @@ class WaveGlowLoss(torch.nn.Module): log_s_total = log_s_total + torch.sum(log_s) log_det_W_total += log_det_W_list[i] - loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total + loss = 
torch.sum(z*z)/(2*self.sigma*self.sigma) - \ + log_s_total - log_det_W_total return loss/(z.size(0)*z.size(1)*z.size(2)) @@ -67,6 +69,7 @@ class Invertible1x1Conv(torch.nn.Module): of its weight matrix. If reverse=True it does convolution with inverse """ + def __init__(self, c): super(Invertible1x1Conv, self).__init__() self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, @@ -77,7 +80,7 @@ class Invertible1x1Conv(torch.nn.Module): # Ensure determinant is 1.0 not -1.0 if torch.det(W) < 0: - W[:,0] = -1*W[:,0] + W[:, 0] = -1*W[:, 0] W = W.view(c, c, 1) self.conv.weight.data = W @@ -110,11 +113,12 @@ class WN(torch.nn.Module): from WaveNet is the convolutions need not be causal. There is also no dilation size reset. The dilation only doubles on each layer """ + def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size): super(WN, self).__init__() - assert(kernel_size % 2 == 1) - assert(n_channels % 2 == 0) + assert (kernel_size % 2 == 1) + assert (n_channels % 2 == 0) self.n_layers = n_layers self.n_channels = n_channels self.in_layers = torch.nn.ModuleList() @@ -142,14 +146,14 @@ class WN(torch.nn.Module): in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') self.in_layers.append(in_layer) - # last one is not necessary if i < n_layers - 1: res_skip_channels = 2*n_channels else: res_skip_channels = n_channels res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + res_skip_layer = torch.nn.utils.weight_norm( + res_skip_layer, name='weight') self.res_skip_layers.append(res_skip_layer) def forward(self, forward_input): @@ -164,13 +168,13 @@ class WN(torch.nn.Module): spect_offset = i*2*self.n_channels acts = fused_add_tanh_sigmoid_multiply( self.in_layers[i](audio), - spect[:,spect_offset:spect_offset+2*self.n_channels,:], + spect[:, spect_offset:spect_offset+2*self.n_channels, :], n_channels_tensor) res_skip_acts = self.res_skip_layers[i](acts) if i < self.n_layers - 1: - audio = audio + res_skip_acts[:,:self.n_channels,:] - output = output + res_skip_acts[:,self.n_channels:,:] + audio = audio + res_skip_acts[:, :self.n_channels, :] + output = output + res_skip_acts[:, self.n_channels:, :] else: output = output + res_skip_acts @@ -185,7 +189,7 @@ class WaveGlow(torch.nn.Module): self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, n_mel_channels, 1024, stride=256) - assert(n_group % 2 == 0) + assert (n_group % 2 == 0) self.n_flows = n_flows self.n_group = n_group self.n_early_every = n_early_every @@ -215,7 +219,7 @@ class WaveGlow(torch.nn.Module): # Upsample spectrogram to size of audio spect = self.upsample(spect) - assert(spect.size(2) >= audio.size(1)) + assert (spect.size(2) >= audio.size(1)) if spect.size(2) > audio.size(1): spect = spect[:, :, :audio.size(1)] @@ -229,15 +233,15 @@ class WaveGlow(torch.nn.Module): for k in range(self.n_flows): if k % self.n_early_every == 0 and k > 0: - output_audio.append(audio[:,:self.n_early_size,:]) - audio = audio[:,self.n_early_size:,:] + output_audio.append(audio[:, :self.n_early_size, :]) + audio = audio[:, self.n_early_size:, :] audio, log_det_W = self.convinv[k](audio) log_det_W_list.append(log_det_W) n_half = int(audio.size(1)/2) - audio_0 = audio[:,:n_half,:] - audio_1 = audio[:,n_half:,:] + audio_0 = audio[:, :n_half, :] + audio_1 = audio[:, n_half:, :] output = self.WN[k]((audio_0, spect)) log_s = output[:, n_half:, :] @@ -245,10 +249,10 @@ class WaveGlow(torch.nn.Module): audio_1 = 
torch.exp(log_s)*audio_1 + b log_s_list.append(log_s) - audio = torch.cat([audio_0, audio_1],1) + audio = torch.cat([audio_0, audio_1], 1) output_audio.append(audio) - return torch.cat(output_audio,1), log_s_list, log_det_W_list + return torch.cat(output_audio, 1), log_s_list, log_det_W_list def infer(self, spect, sigma=1.0): spect = self.upsample(spect) @@ -272,26 +276,29 @@ class WaveGlow(torch.nn.Module): for k in reversed(range(self.n_flows)): n_half = int(audio.size(1)/2) - audio_0 = audio[:,:n_half,:] - audio_1 = audio[:,n_half:,:] + audio_0 = audio[:, :n_half, :] + audio_1 = audio[:, n_half:, :] output = self.WN[k]((audio_0, spect)) s = output[:, n_half:, :] b = output[:, :n_half, :] audio_1 = (audio_1 - b)/torch.exp(s) - audio = torch.cat([audio_0, audio_1],1) + audio = torch.cat([audio_0, audio_1], 1) audio = self.convinv[k](audio, reverse=True) if k % self.n_early_every == 0 and k > 0: if spect.type() == 'torch.cuda.HalfTensor': - z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() + z = torch.cuda.HalfTensor(spect.size( + 0), self.n_early_size, spect.size(2)).normal_() else: - z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() - audio = torch.cat((sigma*z, audio),1) + z = torch.cuda.FloatTensor(spect.size( + 0), self.n_early_size, spect.size(2)).normal_() + audio = torch.cat((sigma*z, audio), 1) - audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data + audio = audio.permute(0, 2, 1).contiguous().view( + audio.size(0), -1).data return audio @staticmethod @@ -315,4 +322,4 @@ def remove(conv_list): @register_model def register_nv_waveglow(opt_net, opt): - return WaveGlow(**opt_net['args']) \ No newline at end of file + return WaveGlow(**opt_net['args']) diff --git a/dlas/models/classifiers/cifar_resnet.py b/dlas/models/classifiers/cifar_resnet.py index 86a7af0e..2e7e03ff 100644 --- a/dlas/models/classifiers/cifar_resnet.py +++ b/dlas/models/classifiers/cifar_resnet.py @@ -8,11 +8,11 @@ https://arxiv.org/abs/1512.03385v1 """ -import torch -import torch.nn as nn -import torch_intermediary as ml -from trainer.networks import register_model +import torch.nn as nn + +import dlas.torch_intermediary as ml +from dlas.trainer.networks import register_model class BasicBlock(nn.Module): @@ -20,53 +20,60 @@ class BasicBlock(nn.Module): """ - #BasicBlock and BottleNeck block - #have different output size - #we use class attribute expansion - #to distinct + # BasicBlock and BottleNeck block + # have different output size + # we use class attribute expansion + # to distinct expansion = 1 def __init__(self, in_channels, out_channels, stride=1): super().__init__() - #residual function + # residual function self.residual_function = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), + nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=stride, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), + nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, + kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) - #shortcut + # shortcut self.shortcut = nn.Sequential() - #the shortcut output dimension is not the same with residual function - #use 1*1 convolution to match the dimension + # the shortcut output dimension is not the same with residual function + # use 1*1 convolution to match 
the dimension if stride != 1 or in_channels != BasicBlock.expansion * out_channels: self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), + nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, + kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) def forward(self, x): return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) + class BottleNeck(nn.Module): """Residual block for resnet over 50 layers """ expansion = 4 + def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.residual_function = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), + nn.Conv2d(out_channels, out_channels, stride=stride, + kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), + nn.Conv2d(out_channels, out_channels * + BottleNeck.expansion, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels * BottleNeck.expansion), ) @@ -74,13 +81,15 @@ class BottleNeck(nn.Module): if stride != 1 or in_channels != out_channels * BottleNeck.expansion: self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), + nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, + stride=stride, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels * BottleNeck.expansion) ) def forward(self, x): return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) + class ResNet(nn.Module): def __init__(self, block, num_block, num_classes=100): @@ -92,8 +101,8 @@ class ResNet(nn.Module): nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True)) - #we use a different inputsize than the original paper - #so conv2_x's stride is 1 + # we use a different inputsize than the original paper + # so conv2_x's stride is 1 self.conv2_x = self._make_layer(block, 32, num_block[0], 1) self.conv3_x = self._make_layer(block, 64, num_block[1], 2) self.conv4_x = self._make_layer(block, 128, num_block[2], 2) @@ -138,30 +147,33 @@ class ResNet(nn.Module): return output + @register_model def register_cifar_resnet18(opt_net, opt): """ return a ResNet 18 object """ return ResNet(BasicBlock, [2, 2, 2, 2]) + def resnet34(): """ return a ResNet 34 object """ return ResNet(BasicBlock, [3, 4, 6, 3]) + def resnet50(): """ return a ResNet 50 object """ return ResNet(BottleNeck, [3, 4, 6, 3]) + def resnet101(): """ return a ResNet 101 object """ return ResNet(BottleNeck, [3, 4, 23, 3]) + def resnet152(): """ return a ResNet 152 object """ return ResNet(BottleNeck, [3, 8, 36, 3]) - - diff --git a/dlas/models/classifiers/resnet_with_checkpointing.py b/dlas/models/classifiers/resnet_with_checkpointing.py index f6f6c5e7..b1b5b6bb 100644 --- a/dlas/models/classifiers/resnet_with_checkpointing.py +++ b/dlas/models/classifiers/resnet_with_checkpointing.py @@ -1,18 +1,17 @@ # A direct copy of torchvision's resnet.py modified to support gradient checkpointing. 
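The shortcut branches in the BasicBlock/BottleNeck classes above only add a projection when the residual branch changes shape: if stride != 1 or the channel count grows, a 1x1 convolution plus BatchNorm brings the identity path to the same shape before the addition. A small sketch with illustrative sizes, not taken from the diff:

import torch
import torch.nn as nn

in_channels, out_channels, stride = 32, 64, 2
shortcut = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
    nn.BatchNorm2d(out_channels),
)
x = torch.randn(1, in_channels, 32, 32)
print(shortcut(x).shape)   # torch.Size([1, 64, 16, 16]) -- now matches the residual branch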
import torch -import torch.nn as nn -from torchvision.models.resnet import BasicBlock, Bottleneck import torchvision -import torch_intermediary as ml +from torchvision.models.resnet import BasicBlock, Bottleneck +import dlas.torch_intermediary as ml __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'] -from trainer.networks import register_model -from utils.util import checkpoint +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', diff --git a/dlas/models/classifiers/torch_models.py b/dlas/models/classifiers/torch_models.py index 998c304c..c9fab4e7 100644 --- a/dlas/models/classifiers/torch_models.py +++ b/dlas/models/classifiers/torch_models.py @@ -1,7 +1,7 @@ from torchvision.models import vgg16 -from trainer.networks import register_model -from utils.util import opt_get +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get @register_model diff --git a/dlas/models/classifiers/twin_cifar_resnet.py b/dlas/models/classifiers/twin_cifar_resnet.py index bf4d1214..9363dbbb 100644 --- a/dlas/models/classifiers/twin_cifar_resnet.py +++ b/dlas/models/classifiers/twin_cifar_resnet.py @@ -11,9 +11,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torch_intermediary as ml -from trainer.networks import register_model +import dlas.torch_intermediary as ml +from dlas.trainer.networks import register_model class BasicBlock(nn.Module): @@ -21,32 +21,35 @@ class BasicBlock(nn.Module): """ - #BasicBlock and BottleNeck block - #have different output size - #we use class attribute expansion - #to distinct + # BasicBlock and BottleNeck block + # have different output size + # we use class attribute expansion + # to distinct expansion = 1 def __init__(self, in_channels, out_channels, stride=1): super().__init__() - #residual function + # residual function self.residual_function = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), + nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=stride, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), + nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, + kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) - #shortcut + # shortcut self.shortcut = nn.Sequential() - #the shortcut output dimension is not the same with residual function - #use 1*1 convolution to match the dimension + # the shortcut output dimension is not the same with residual function + # use 1*1 convolution to match the dimension if stride != 1 or in_channels != BasicBlock.expansion * out_channels: self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), + nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, + kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) @@ -59,16 +62,19 @@ class BottleNeck(nn.Module): """ expansion = 4 + def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.residual_function = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels), 
nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), + nn.Conv2d(out_channels, out_channels, stride=stride, + kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), + nn.Conv2d(out_channels, out_channels * + BottleNeck.expansion, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels * BottleNeck.expansion), ) @@ -76,7 +82,8 @@ class BottleNeck(nn.Module): if stride != 1 or in_channels != out_channels * BottleNeck.expansion: self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), + nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, + stride=stride, kernel_size=1, bias=False), nn.BatchNorm2d(out_channels * BottleNeck.expansion) ) @@ -95,8 +102,8 @@ class ResNet(nn.Module): nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True)) - #we use a different inputsize than the original paper - #so conv2_x's stride is 1 + # we use a different inputsize than the original paper + # so conv2_x's stride is 1 self.conv2_x = self._make_layer(block, 32, num_block[0], 1) self.conv3_x = self._make_layer(block, 64, num_block[1], 2) self.conv4_x = self._make_layer(block, 128, num_block[2], 2) @@ -143,7 +150,7 @@ class ResNet(nn.Module): class SymbolicLoss: - def __init__(self, category_depths=[3,5,5,3], convergence_weighting=[1,.6,.3,.1], divergence_weighting=[.1,.3,.6,1]): + def __init__(self, category_depths=[3, 5, 5, 3], convergence_weighting=[1, .6, .3, .1], divergence_weighting=[.1, .3, .6, 1]): self.depths = category_depths self.total_classes = 1 for c in category_depths: @@ -175,7 +182,8 @@ class SymbolicLoss: level_logits = level_logits.sum(dim=-1) level_labels = collaboratorLabels.div(epc, rounding_mode='trunc') # Convergence - convergence_loss = convergence_loss + F.cross_entropy(level_logits, level_labels) * cw + convergence_loss = convergence_loss + \ + F.cross_entropy(level_logits, level_labels) * cw # Divergence div_label_indices = level_logits.argmax(dim=-1) # TODO: find the torch-y way of doing this. 
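The level_labels = collaboratorLabels.div(epc, rounding_mode='trunc') line above derives a coarser label from a flat class index by integer division. The snippet below only illustrates that one operation, with a hypothetical elements-per-group value standing in for epc; it is not a re-implementation of SymbolicLoss.

import torch

elements_per_group = 5                 # hypothetical stand-in for `epc`
fine_labels = torch.tensor([0, 4, 5, 9, 12])
coarse_labels = fine_labels.div(elements_per_group, rounding_mode='trunc')
print(coarse_labels)                   # tensor([0, 0, 1, 1, 2])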
@@ -184,7 +192,8 @@ class SymbolicLoss: dp.append(level_logits[:, i]) div_preds = torch.stack(dp, dim=0) div_labels = torch.arange(0, b, device=logits.device) - divergence_loss = divergence_loss + F.cross_entropy(div_preds, div_labels) + divergence_loss = divergence_loss + \ + F.cross_entropy(div_preds, div_labels) return convergence_loss, divergence_loss @@ -199,15 +208,19 @@ class TwinnedCifar(nn.Module): def __init__(self): super().__init__() self.loss = SymbolicLoss() - self.netA = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes) - self.netB = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes) + self.netA = ResNet( + BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes) + self.netB = ResNet( + BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes) def forward(self, x): y1 = self.netA(x) y2 = self.netB(x) b = x.shape[0] - convergenceA, divergenceA = self.loss(y1[:b//2], y2.argmax(dim=-1)[:b//2]) - convergenceB, divergenceB = self.loss(y2[b//2:], y1.argmax(dim=-1)[b//2:]) + convergenceA, divergenceA = self.loss( + y1[:b//2], y2.argmax(dim=-1)[:b//2]) + convergenceB, divergenceB = self.loss( + y2[b//2:], y1.argmax(dim=-1)[b//2:]) return convergenceA + convergenceB, divergenceA + divergenceB @@ -217,24 +230,26 @@ def register_twin_cifar(opt_net, opt): """ return TwinnedCifar() + def resnet34(): """ return a ResNet 34 object """ return ResNet(BasicBlock, [3, 4, 6, 3]) + def resnet50(): """ return a ResNet 50 object """ return ResNet(BottleNeck, [3, 4, 6, 3]) + def resnet101(): """ return a ResNet 101 object """ return ResNet(BottleNeck, [3, 4, 23, 3]) + def resnet152(): """ return a ResNet 152 object """ return ResNet(BottleNeck, [3, 8, 36, 3]) - - diff --git a/dlas/models/classifiers/weighted_conv_resnet.py b/dlas/models/classifiers/weighted_conv_resnet.py index 5b5beaf6..40c1098f 100644 --- a/dlas/models/classifiers/weighted_conv_resnet.py +++ b/dlas/models/classifiers/weighted_conv_resnet.py @@ -1,17 +1,19 @@ +from typing import (Any, Callable, Iterator, List, Optional, OrderedDict, Type, + Union) + import torch +import torch.nn as nn import torchvision from torch import Tensor -import torch.nn as nn -from typing import Type, Any, Callable, Union, List, Optional, OrderedDict, Iterator __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'] -from models.vqvae.scaled_weight_conv import ScaledWeightConv -from trainer.networks import register_model -from utils.util import checkpoint -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.vqvae.scaled_weight_conv import ScaledWeightConv +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', @@ -29,7 +31,7 @@ model_urls = { def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1, breadth: int = 8) -> ScaledWeightConv: """3x3 convolution with padding""" return ScaledWeightConv(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation, breadth=breadth) + padding=dilation, groups=groups, bias=False, dilation=dilation, breadth=breadth) def conv1x1(in_planes: int, out_planes: int, stride: int = 1, breadth: int = 8) -> ScaledWeightConv: @@ -80,9 +82,11 @@ class BasicBlock(nn.Module): if norm_layer is None: norm_layer = nn.BatchNorm2d if 
groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride, breadth=breadth) self.bn1 = norm_layer(planes) @@ -139,7 +143,8 @@ class Bottleneck(nn.Module): # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width, breadth=breadth) self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation, breadth=breadth) + self.conv2 = conv3x3(width, width, stride, groups, + dilation, breadth=breadth) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion, breadth=breadth) self.bn3 = norm_layer(planes * self.expansion) @@ -202,7 +207,7 @@ class ResNet(nn.Module): self.groups = groups self.base_width = width_per_group self.conv1 = ScaledWeightConv(3, self.inplanes, kernel_size=7, stride=2, padding=3, - bias=False, breadth=breadth) + bias=False, breadth=breadth) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -218,7 +223,8 @@ class ResNet(nn.Module): for m in self.modules(): if isinstance(m, ScaledWeightConv): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -229,9 +235,11 @@ class ResNet(nn.Module): if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + # type: ignore[arg-type] + nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + # type: ignore[arg-type] + nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, breadth: int, stride: int = 1, dilate: bool = False) -> MaskedSequential: @@ -243,7 +251,8 @@ class ResNet(nn.Module): stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = MaskedSequential( - conv1x1(self.inplanes, planes * block.expansion, stride, breadth=breadth), + conv1x1(self.inplanes, planes * block.expansion, + stride, breadth=breadth), norm_layer(planes * block.expansion), ) @@ -435,7 +444,7 @@ if __name__ == '__main__': masks = {} for j in range(6): cdim = idim // (2 ** j) - masks[cdim] = torch.zeros((1,1,cdim,cdim), dtype=torch.long) - i = torch.rand(1,3,idim,idim) + masks[cdim] = torch.zeros((1, 1, cdim, cdim), dtype=torch.long) + i = torch.rand(1, 3, idim, idim) r1 = mod(i, masks) r2 = orig(i) diff --git a/dlas/models/classifiers/wide_kernel_vgg.py b/dlas/models/classifiers/wide_kernel_vgg.py index ab26fb0f..af3f975b 100644 --- a/dlas/models/classifiers/wide_kernel_vgg.py +++ b/dlas/models/classifiers/wide_kernel_vgg.py @@ -1,9 +1,10 @@ import torch import torch.nn as nn -from trainer.networks import register_model -from utils.util import opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.trainer.networks import register_model +from dlas.utils.util import 
opt_get + class WideKernelVgg(nn.Module): def __init__(self, nf=64, num_classes=2): @@ -55,18 +56,19 @@ class WideKernelVgg(nn.Module): ) # These normalization constants should be derived experimentally. - self.log_fft_mean = torch.tensor([-3.5184, -4.071]).view(1,1,1,2) - self.log_fft_std = torch.tensor([3.1660, 3.8042]).view(1,1,1,2) + self.log_fft_mean = torch.tensor([-3.5184, -4.071]).view(1, 1, 1, 2) + self.log_fft_std = torch.tensor([3.1660, 3.8042]).view(1, 1, 1, 2) def forward(self, x): - b,c,h,w = x.shape + b, c, h, w = x.shape x_c = x.view(c*b, h, w) x_c = torch.view_as_real(torch.fft.rfft(x_c)) # Log-normalize spectrogram x_c = (x_c.abs() ** 2).clip(min=1e-8, max=1e16) x_c = torch.log(x_c) - x_c = (x_c - self.log_fft_mean.to(x.device)) / self.log_fft_std.to(x.device) + x_c = (x_c - self.log_fft_mean.to(x.device)) / \ + self.log_fft_std.to(x.device) # Return to expected input shape (b,c,h,w) x_c = x_c.permute(0, 3, 1, 2).reshape(b, c * 2, h, w // 2 + 1) @@ -83,4 +85,4 @@ def register_wide_kernel_vgg(opt_net, opt): if __name__ == '__main__': vgg = WideKernelVgg() - vgg(torch.randn(1,3,256,256)) \ No newline at end of file + vgg(torch.randn(1, 3, 256, 256)) diff --git a/dlas/models/clip/clip.py b/dlas/models/clip/clip.py index e7f62a8b..13c43475 100644 --- a/dlas/models/clip/clip.py +++ b/dlas/models/clip/clip.py @@ -1,11 +1,12 @@ import torch import torch.nn as nn -from trainer.networks import register_model -from utils.util import opt_get + +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def encoder_for_type(type, master_dim, enc_kwargs): - from x_clip.x_clip import VisionTransformer, TextTransformer + from x_clip.x_clip import TextTransformer, VisionTransformer if type == 'image': # xclip_kwargs: image_size, patch_size, channels, depth, heads return VisionTransformer(dim=master_dim, **enc_kwargs) @@ -47,7 +48,7 @@ class XClipWrapper(nn.Module): def forward(self, seq1, seq2, return_loss=False): seq1_mask = torch.rand_like(seq1.float()) > self.mask_seq1_percentage # TODO: add support for seq2 mask.. 
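The seq1_mask line above draws a keep-mask by thresholding uniform noise, so roughly (1 - mask_seq1_percentage) of the positions survive. A tiny standalone sketch of that pattern; the percentage and shapes here are illustrative only:

import torch

mask_percentage = 0.2
tokens = torch.randint(0, 256, (2, 200))
keep_mask = torch.rand_like(tokens.float()) > mask_percentage   # True = kept
print(keep_mask.shape, keep_mask.float().mean())                # mean is roughly 0.8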
- #seq2_mask = torch.rand_like(seq2.float()) > self.mask_seq2_percentage + # seq2_mask = torch.rand_like(seq2.float()) > self.mask_seq2_percentage return self.clip(seq1, seq2, seq1_mask, return_loss=return_loss) @@ -55,9 +56,12 @@ class XClipWrapper(nn.Module): def register_clip(opt_net, opt): return XClipWrapper(**opt_get(opt_net, ['kwargs'], {})) + if __name__ == '__main__': model = XClipWrapper(enc1_type='tokens', enc2_type='tokens', - enc1_kwargs={'num_tokens': 256, 'max_seq_len': 200, 'depth': 8, 'heads': 8}, + enc1_kwargs={'num_tokens': 256, + 'max_seq_len': 200, 'depth': 8, 'heads': 8}, enc2_kwargs={'num_tokens': 8192, 'max_seq_len': 250, 'depth': 8, 'heads': 8}) - loss = model(torch.randint(0,256, (2,200)), torch.randint(0,8192, (2,250)), True) - print(loss) \ No newline at end of file + loss = model(torch.randint(0, 256, (2, 200)), + torch.randint(0, 8192, (2, 250)), True) + print(loss) diff --git a/dlas/models/clip/clvp.py b/dlas/models/clip/clvp.py index 9bfc7655..6e85f7ab 100644 --- a/dlas/models/clip/clvp.py +++ b/dlas/models/clip/clvp.py @@ -3,14 +3,15 @@ from random import random import torch import torch.nn as nn import torch.nn.functional as F -from torch import einsum, distributed +from torch import distributed, einsum from torch.distributed import get_world_size -from models.arch_util import AttentionBlock -from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder -from trainer.networks import register_model -from utils.util import opt_get, checkpoint -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock +from dlas.models.lucidrains.x_transformers import ( + ContinuousTransformerWrapper, Encoder) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, opt_get def exists(val): @@ -19,7 +20,7 @@ def exists(val): def masked_mean(t, mask): t = t.masked_fill(~mask, 0.) 
- return t.sum(dim = 1) / mask.sum(dim = 1) + return t.sum(dim=1) / mask.sum(dim=1) class CollapsingTransformer(nn.Module): @@ -41,14 +42,15 @@ class CollapsingTransformer(nn.Module): **encoder_kwargs, )) self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1), - AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False), + AttentionBlock( + output_dims, num_heads=heads, do_checkpoint=False), nn.Conv1d(output_dims, output_dims, 1)) self.mask_percentage = mask_percentage def forward(self, x, **transformer_kwargs): h = self.transformer(x, **transformer_kwargs) - h = h.permute(0,2,1) - h = checkpoint(self.pre_combiner, h).permute(0,2,1) + h = h.permute(0, 2, 1) + h = checkpoint(self.pre_combiner, h).permute(0, 2, 1) if self.training: mask = torch.rand_like(h.float()) > self.mask_percentage else: @@ -64,7 +66,7 @@ class ConvFormatEmbedding(nn.Module): def forward(self, x): y = self.emb(x) - return y.permute(0,2,1) + return y.permute(0, 2, 1) class CLVP(nn.Module): @@ -96,21 +98,26 @@ class CLVP(nn.Module): self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2), nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) - self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim*2, transformer_heads, dropout, conditioning_enc_depth, 0) - self.masked_conditioning_latent = nn.Parameter(torch.randn(1,model_dim*2), requires_grad=True) + self.conditioning_transformer = CollapsingTransformer( + model_dim, model_dim*2, transformer_heads, dropout, conditioning_enc_depth, 0) + self.masked_conditioning_latent = nn.Parameter( + torch.randn(1, model_dim*2), requires_grad=True) self.mask_conditioning_percentage = mask_conditioning_percentage # nn.Embedding self.text_emb = ml.Embedding(num_text_tokens, model_dim) - self.text_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True) + self.text_transformer = CollapsingTransformer( + model_dim, latent_dim, transformer_heads, dropout, text_enc_depth, text_mask_percentage, use_rms_scaleshift_norm=True) self.to_text_latent = ml.Linear(latent_dim, latent_dim, bias=False) self.distributed_collect = distributed_collect if mel_codes is None: - self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2) + self.speech_emb = nn.Conv1d( + mel_channels, model_dim, kernel_size=5, padding=2) else: self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) - self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) + self.speech_transformer = CollapsingTransformer( + model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False) def get_grad_norm_parameter_groups(self): @@ -130,30 +137,35 @@ class CLVP(nn.Module): device = text.device text_emb = self.text_emb(text) - speech_emb = self.speech_emb(mel_input).permute(0,2,1) + speech_emb = self.speech_emb(mel_input).permute(0, 2, 1) unused_params = [] if random() < self.mask_conditioning_percentage: enc_cond = self.masked_conditioning_latent - unused_params.extend(list(self.cond_emb.parameters()) + list(self.conditioning_transformer.parameters())) + unused_params.extend(list(self.cond_emb.parameters( + )) + list(self.conditioning_transformer.parameters())) else: - cond_emb = self.cond_emb(mel_cond).permute(0,2,1) + cond_emb = 
self.cond_emb(mel_cond).permute(0, 2, 1) enc_cond = self.conditioning_transformer(cond_emb) unused_params.append(self.masked_conditioning_latent) - enc_text = self.text_transformer(text_emb, norm_scale_shift_inp=enc_cond) + enc_text = self.text_transformer( + text_emb, norm_scale_shift_inp=enc_cond) enc_speech = self.speech_transformer(speech_emb) text_latents = self.to_text_latent(enc_text) speech_latents = self.to_speech_latent(enc_speech) - text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) + text_latents, speech_latents = map(lambda t: F.normalize( + t, p=2, dim=-1), (text_latents, speech_latents)) temp = self.temperature.exp() if self.distributed_collect: - collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())] + collective = [torch.zeros_like(text_latents) for _ in range( + torch.distributed.get_world_size())] torch.all_gather(collective, text_latents) text_latents = torch.cat(collective, dim=0) - collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())] + collective = [torch.zeros_like(speech_latents) for _ in range( + torch.distributed.get_world_size())] torch.all_gather(collective, speech_latents) speech_latents = torch.cat(collective, dim=0) @@ -163,7 +175,8 @@ class CLVP(nn.Module): sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp labels = torch.arange(text_latents.shape[0], device=device) - loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + loss = (F.cross_entropy(sim, labels) + + F.cross_entropy(sim.t(), labels)) / 2 # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. extraneous_addition = 0 @@ -181,17 +194,17 @@ def register_clvp(opt_net, opt): if __name__ == '__main__': clvp = CLVP() - clvp(torch.randint(0,256,(2,120)), - torch.randn(2,80,100), - torch.randn(2,80,95), + clvp(torch.randint(0, 256, (2, 120)), + torch.randn(2, 80, 100), + torch.randn(2, 80, 95), return_loss=True) - nonloss = clvp(torch.randint(0,256,(2,120)), - torch.randn(2,80,100), - torch.randn(2,80,95), - return_loss=False) + nonloss = clvp(torch.randint(0, 256, (2, 120)), + torch.randn(2, 80, 100), + torch.randn(2, 80, 95), + return_loss=False) clvp = CLVP(mel_codes=8192) - clvp(torch.randint(0,256,(2,120)), - torch.randint(0,8192,(2,150)), - torch.randn(2,80,95), + clvp(torch.randint(0, 256, (2, 120)), + torch.randint(0, 8192, (2, 150)), + torch.randn(2, 80, 95), return_loss=True) - print(nonloss.shape) \ No newline at end of file + print(nonloss.shape) diff --git a/dlas/models/clip/contrastive_audio.py b/dlas/models/clip/contrastive_audio.py index 85edfec6..94b93d19 100644 --- a/dlas/models/clip/contrastive_audio.py +++ b/dlas/models/clip/contrastive_audio.py @@ -3,13 +3,13 @@ from random import random import torch import torch.nn as nn import torch.nn.functional as F -from torch import einsum -from models.arch_util import AttentionBlock -from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder -from trainer.networks import register_model -from utils.util import opt_get, checkpoint -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock +from dlas.models.lucidrains.x_transformers import ( + ContinuousTransformerWrapper, Encoder) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, opt_get def exists(val): @@ -18,7 +18,7 @@ def exists(val): def masked_mean(t, 
mask): t = t.masked_fill(~mask, 0.) - return t.sum(dim = 1) / mask.sum(dim = 1) + return t.sum(dim=1) / mask.sum(dim=1) class InfoNCE(nn.Module): @@ -84,26 +84,33 @@ def info_nce(query, positive_key, negative_keys=None, temperature=0.1, reduction raise ValueError('<positive_key> must have 2 dimensions.') if negative_keys is not None: if negative_mode == 'unpaired' and negative_keys.dim() != 2: - raise ValueError("<negative_keys> must have 2 dimensions if <negative_mode> == 'unpaired'.") + raise ValueError( + "<negative_keys> must have 2 dimensions if <negative_mode> == 'unpaired'.") if negative_mode == 'paired' and negative_keys.dim() != 3: - raise ValueError("<negative_keys> must have 3 dimensions if <negative_mode> == 'paired'.") + raise ValueError( + "<negative_keys> must have 3 dimensions if <negative_mode> == 'paired'.") # Check matching number of samples. if len(query) != len(positive_key): - raise ValueError('<query> and <positive_key> must must have the same number of samples.') + raise ValueError( + '<query> and <positive_key> must must have the same number of samples.') if negative_keys is not None: if negative_mode == 'paired' and len(query) != len(negative_keys): - raise ValueError("If negative_mode == 'paired', then <negative_keys> must have the same number of samples as <query>.") + raise ValueError( + "If negative_mode == 'paired', then <negative_keys> must have the same number of samples as <query>.") # Embedding vectors should have same number of components. if query.shape[-1] != positive_key.shape[-1]: - raise ValueError('Vectors of <query> and <positive_key> should have the same number of components.') + raise ValueError( + 'Vectors of <query> and <positive_key> should have the same number of components.') if negative_keys is not None: if query.shape[-1] != negative_keys.shape[-1]: - raise ValueError('Vectors of <query> and <negative_keys> should have the same number of components.') + raise ValueError( + 'Vectors of <query> and <negative_keys> should have the same number of components.') # Normalize to unit vectors - query, positive_key, negative_keys = normalize(query, positive_key, negative_keys) + query, positive_key, negative_keys = normalize( + query, positive_key, negative_keys) if negative_keys is not None: # Explicit negative keys @@ -121,7 +128,8 @@ def info_nce(query, positive_key, negative_keys=None, temperature=0.1, reduction # First index in last dimension are the positive samples logits = torch.cat([positive_logit, negative_logits], dim=1) - labels = torch.zeros(len(logits), dtype=torch.long, device=query.device) + labels = torch.zeros( + len(logits), dtype=torch.long, device=query.device) else: # Negative keys are implicitly off-diagonal positive keys.
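# A minimal sketch of the implicit-negative branch that the comment above refers to,
# assuming unit-normalized embeddings; the function and tensor names below are
# illustrative and are not taken from contrastive_audio.py.
import torch
import torch.nn.functional as F

def info_nce_implicit_negatives(query, positive_key, temperature=0.1):
    # Compare every query against every positive key: the diagonal holds the
    # matched pairs and the off-diagonal entries act as in-batch negatives.
    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)
    logits = query @ positive_key.t()                       # (B, B) cosine similarities
    labels = torch.arange(len(query), device=query.device)  # positive for row i sits at column i
    return F.cross_entropy(logits / temperature, labels)

# Usage: info_nce_implicit_negatives(torch.randn(4, 512), torch.randn(4, 512))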
@@ -161,14 +169,15 @@ class CollapsingTransformer(nn.Module): **encoder_kwargs, )) self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1), - AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False), + AttentionBlock( + output_dims, num_heads=heads, do_checkpoint=False), nn.Conv1d(output_dims, output_dims, 1)) self.mask_percentage = mask_percentage def forward(self, x, **transformer_kwargs): h = self.transformer(x, **transformer_kwargs) - h = h.permute(0,2,1) - h = checkpoint(self.pre_combiner, h).permute(0,2,1) + h = h.permute(0, 2, 1) + h = checkpoint(self.pre_combiner, h).permute(0, 2, 1) if self.training: mask = torch.rand_like(h.float()) > self.mask_percentage else: @@ -184,7 +193,7 @@ class ConvFormatEmbedding(nn.Module): def forward(self, x): y = self.emb(x) - return y.permute(0,2,1) + return y.permute(0, 2, 1) class ContrastiveAudio(nn.Module): @@ -204,7 +213,8 @@ class ContrastiveAudio(nn.Module): self.emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2), nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) - self.transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent) + self.transformer = CollapsingTransformer( + model_dim, model_dim, transformer_heads, dropout, encoder_depth, mask_percent) self.to_latent = ml.Linear(latent_dim, latent_dim, bias=False) self.to_latent2 = ml.Linear(latent_dim, latent_dim, bias=False) @@ -219,7 +229,8 @@ class ContrastiveAudio(nn.Module): } def update_for_step(self, step, __): - self.to_latent2.weight.data = self.to_latent2.weight.data * .99 + self.to_latent.weight.data * .01 + self.to_latent2.weight.data = self.to_latent2.weight.data * .99 + \ + self.to_latent.weight.data * .01 def project(self, mel): h1 = self.emb(mel).permute(0, 2, 1) @@ -237,9 +248,9 @@ class ContrastiveAudio(nn.Module): if self.training: # Mask out big chunks of separate frequency bands for each clip. 
b, c, _ = mel_input1.shape - mask = torch.rand(b,c,1, device=mel_input1.device) > .3 + mask = torch.rand(b, c, 1, device=mel_input1.device) > .3 mel_input1 = mask * mel_input1 * (1-random()*.5) - mask = torch.rand(b,c,1, device=mel_input2.device) > .3 + mask = torch.rand(b, c, 1, device=mel_input2.device) > .3 mel_input2 = mask * mel_input2 * (1-random()*.5) h1 = self.emb(mel_input1).permute(0, 2, 1) @@ -261,8 +272,8 @@ def register_contrastive_audio(opt_net, opt): if __name__ == '__main__': clvp = ContrastiveAudio() - clvp(torch.randn(2,80,100), - torch.randn(2,80,95), + clvp(torch.randn(2, 80, 100), + torch.randn(2, 80, 95), return_loss=True) - v = torch.randn(2,512) - print(info_nce(v,v)) \ No newline at end of file + v = torch.randn(2, 512) + print(info_nce(v, v)) diff --git a/dlas/models/clip/cvvp.py b/dlas/models/clip/cvvp.py index e8a12b2a..5e022cef 100644 --- a/dlas/models/clip/cvvp.py +++ b/dlas/models/clip/cvvp.py @@ -1,16 +1,14 @@ -from random import random - import torch import torch.nn as nn import torch.nn.functional as F -from torch import einsum, distributed -from torch.distributed import get_world_size +from torch import einsum -from models.arch_util import AttentionBlock -from models.lucidrains.x_transformers import ContinuousTransformerWrapper, Encoder -from trainer.networks import register_model -from utils.util import opt_get, checkpoint -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.arch_util import AttentionBlock +from dlas.models.lucidrains.x_transformers import ( + ContinuousTransformerWrapper, Encoder) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, opt_get def exists(val): @@ -19,7 +17,7 @@ def exists(val): def masked_mean(t, mask): t = t.masked_fill(~mask, 0.) 
- return t.sum(dim = 1) / mask.sum(dim = 1) + return t.sum(dim=1) / mask.sum(dim=1) class CollapsingTransformer(nn.Module): @@ -41,14 +39,15 @@ class CollapsingTransformer(nn.Module): **encoder_kwargs, )) self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1), - AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False), + AttentionBlock( + output_dims, num_heads=heads, do_checkpoint=False), nn.Conv1d(output_dims, output_dims, 1)) self.mask_percentage = mask_percentage def forward(self, x, **transformer_kwargs): h = self.transformer(x, **transformer_kwargs) - h = h.permute(0,2,1) - h = checkpoint(self.pre_combiner, h).permute(0,2,1) + h = h.permute(0, 2, 1) + h = checkpoint(self.pre_combiner, h).permute(0, 2, 1) if self.training: mask = torch.rand_like(h.float()) > self.mask_percentage else: @@ -64,7 +63,7 @@ class ConvFormatEmbedding(nn.Module): def forward(self, x): y = self.emb(x) - return y.permute(0,2,1) + return y.permute(0, 2, 1) class CVVP(nn.Module): @@ -87,14 +86,18 @@ class CVVP(nn.Module): self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2), nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1)) - self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage) - self.to_conditioning_latent = ml.Linear(latent_dim, latent_dim, bias=False) + self.conditioning_transformer = CollapsingTransformer( + model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage) + self.to_conditioning_latent = ml.Linear( + latent_dim, latent_dim, bias=False) if mel_codes is None: - self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2) + self.speech_emb = nn.Conv1d( + mel_channels, model_dim, kernel_size=5, padding=2) else: self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) - self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) + self.speech_transformer = CollapsingTransformer( + model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage) self.to_speech_latent = ml.Linear(latent_dim, latent_dim, bias=False) def get_grad_norm_parameter_groups(self): @@ -109,16 +112,16 @@ class CVVP(nn.Module): mel_cond, return_loss=False ): - cond_emb = self.cond_emb(mel_cond).permute(0,2,1) + cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1) enc_cond = self.conditioning_transformer(cond_emb) cond_latents = self.to_conditioning_latent(enc_cond) - speech_emb = self.speech_emb(mel_input).permute(0,2,1) + speech_emb = self.speech_emb(mel_input).permute(0, 2, 1) enc_speech = self.speech_transformer(speech_emb) speech_latents = self.to_speech_latent(enc_speech) - - cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)) + cond_latents, speech_latents = map(lambda t: F.normalize( + t, p=2, dim=-1), (cond_latents, speech_latents)) temp = self.temperature.exp() if not return_loss: @@ -127,7 +130,8 @@ class CVVP(nn.Module): sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp labels = torch.arange(cond_latents.shape[0], device=mel_input.device) - loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + loss = (F.cross_entropy(sim, labels) + + F.cross_entropy(sim.t(), labels)) / 2 return loss @@ -139,6 +143,6 @@ def register_cvvp(opt_net, opt): if __name__ == '__main__': clvp = 
CVVP() - clvp(torch.randn(2,80,100), - torch.randn(2,80,95), - return_loss=True) \ No newline at end of file + clvp(torch.randn(2, 80, 100), + torch.randn(2, 80, 95), + return_loss=True) diff --git a/dlas/models/clip/mel_text_clip.py b/dlas/models/clip/mel_text_clip.py index c053547d..dcabee85 100644 --- a/dlas/models/clip/mel_text_clip.py +++ b/dlas/models/clip/mel_text_clip.py @@ -1,22 +1,21 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from torch import einsum -from models.lucidrains.dalle.transformer import Transformer -from trainer.networks import register_model -from utils.util import opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.lucidrains.dalle.transformer import Transformer +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def exists(val): return val is not None -def masked_mean(t, mask, dim = 1): +def masked_mean(t, mask, dim=1): t = t.masked_fill(~mask[:, :, None], 0.) - return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] + return t.sum(dim=1) / mask.sum(dim=1)[..., None] class MelTextCLIP(nn.Module): @@ -70,7 +69,8 @@ class MelTextCLIP(nn.Module): if text_mask is None: text_mask = torch.ones_like(text.float()).bool() text_emb = self.text_emb(text) - text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=text.device)) + text_emb += self.text_pos_emb( + torch.arange(text.shape[1], device=text.device)) with torch.autocast(text.device.type): enc_text = self.text_transformer(text_emb, mask=text_mask) text_latents = masked_mean(enc_text, text_mask, dim=1) @@ -78,9 +78,10 @@ class MelTextCLIP(nn.Module): def get_speech_projection(self, mel, voice_mask=None): if voice_mask is None: - voice_mask = torch.ones_like(mel[:,0,:].float()).bool() - speech_emb = self.speech_enc(mel).permute(0,2,1) - speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=mel.device)) + voice_mask = torch.ones_like(mel[:, 0, :].float()).bool() + speech_emb = self.speech_enc(mel).permute(0, 2, 1) + speech_emb += self.speech_pos_emb( + torch.arange(speech_emb.shape[1], device=mel.device)) with torch.autocast(speech_emb.device.type): enc_speech = self.speech_transformer(speech_emb, mask=voice_mask) speech_latents = masked_mean(enc_speech, voice_mask, dim=1) @@ -103,17 +104,21 @@ class MelTextCLIP(nn.Module): b, device = text.shape[0], text.device if self.training: - text_mask = torch.rand_like(text.float()) > self.text_mask_percentage - voice_mask = torch.rand_like(mel[:,0,:].float()) > self.voice_mask_percentage + text_mask = torch.rand_like( + text.float()) > self.text_mask_percentage + voice_mask = torch.rand_like( + mel[:, 0, :].float()) > self.voice_mask_percentage else: text_mask = torch.ones_like(text.float()).bool() - voice_mask = torch.ones_like(mel[:,0,:].float()).bool() + voice_mask = torch.ones_like(mel[:, 0, :].float()).bool() text_emb = self.text_emb(text) - text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) + text_emb += self.text_pos_emb( + torch.arange(text.shape[1], device=device)) - speech_emb = self.speech_enc(mel).permute(0,2,1) - speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) + speech_emb = self.speech_enc(mel).permute(0, 2, 1) + speech_emb += self.speech_pos_emb( + torch.arange(speech_emb.shape[1], device=device)) # Only autocast the transformer part. The MEL encoder loses accuracy if you autcast it. 
with torch.autocast(speech_emb.device.type): @@ -126,7 +131,8 @@ class MelTextCLIP(nn.Module): text_latents = self.to_text_latent(text_latents).float() speech_latents = self.to_speech_latent(speech_latents).float() - text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) + text_latents, speech_latents = map(lambda t: F.normalize( + t, p=2, dim=-1), (text_latents, speech_latents)) temp = self.temperature.exp() @@ -136,7 +142,8 @@ class MelTextCLIP(nn.Module): sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp labels = torch.arange(b, device=device) - loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + loss = (F.cross_entropy(sim, labels) + + F.cross_entropy(sim.t(), labels)) / 2 return loss @@ -147,8 +154,8 @@ def register_mel_text_clip(opt_net, opt): if __name__ == '__main__': clip = MelTextCLIP(text_mask_percentage=.2, voice_mask_percentage=.2) - clip(torch.randint(0,256,(2,120)), - torch.tensor([50,100]), - torch.randn(2,80,400), - torch.tensor([10100,10200]), - return_loss=True) \ No newline at end of file + clip(torch.randint(0, 256, (2, 120)), + torch.tensor([50, 100]), + torch.randn(2, 80, 400), + torch.tensor([10100, 10200]), + return_loss=True) diff --git a/dlas/models/clip/text_cond_clip.py b/dlas/models/clip/text_cond_clip.py index f221142c..d662d565 100644 --- a/dlas/models/clip/text_cond_clip.py +++ b/dlas/models/clip/text_cond_clip.py @@ -3,20 +3,20 @@ import torch.nn as nn import torch.nn.functional as F from torch import einsum -from models.audio.tts.unified_voice2 import ConditioningEncoder -from models.lucidrains.dalle.transformer import Transformer -from trainer.networks import register_model -from utils.util import opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.audio.tts.unified_voice2 import ConditioningEncoder +from dlas.models.lucidrains.dalle.transformer import Transformer +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def exists(val): return val is not None -def masked_mean(t, mask, dim = 1): +def masked_mean(t, mask, dim=1): t = t.masked_fill(~mask[:, :, None], 0.) 
- return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] + return t.sum(dim=1) / mask.sum(dim=1)[..., None] class VoiceCondCLIP(nn.Module): @@ -40,7 +40,8 @@ class VoiceCondCLIP(nn.Module): wav_token_compression=1024, ): super().__init__() - self.cond_encoder = ConditioningEncoder(80, dim_latent, do_checkpointing=True) + self.cond_encoder = ConditioningEncoder( + 80, dim_latent, do_checkpointing=True) self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech) self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) @@ -66,12 +67,14 @@ class VoiceCondCLIP(nn.Module): b, device = speech_tokens.shape[0], speech_tokens.device if self.training: - voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage + voice_mask = torch.rand_like( + speech_tokens.float()) > self.voice_mask_percentage else: voice_mask = torch.ones_like(speech_tokens.float()).bool() speech_emb = self.speech_emb(speech_tokens) - speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) + speech_emb += self.speech_pos_emb( + torch.arange(speech_emb.shape[1], device=device)) cond_latents = self.cond_encoder(cond_mel) @@ -79,7 +82,8 @@ class VoiceCondCLIP(nn.Module): speech_latents = masked_mean(enc_speech, voice_mask, dim=1) speech_latents = self.to_speech_latent(speech_latents) - cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)) + cond_latents, speech_latents = map(lambda t: F.normalize( + t, p=2, dim=-1), (cond_latents, speech_latents)) temp = self.temperature.exp() @@ -89,7 +93,8 @@ class VoiceCondCLIP(nn.Module): sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp labels = torch.arange(b, device=device) - loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + loss = (F.cross_entropy(sim, labels) + + F.cross_entropy(sim.t(), labels)) / 2 return loss @@ -100,13 +105,13 @@ def register_voice_cond_clip(opt_net, opt): if __name__ == '__main__': clip = VoiceCondCLIP(voice_mask_percentage=.2) - clip(torch.randn(2,80,400), - torch.randint(0,8192,(2,250)), - torch.tensor([101,102]), + clip(torch.randn(2, 80, 400), + torch.randint(0, 8192, (2, 250)), + torch.tensor([101, 102]), return_loss=True) nonloss = clip( - torch.randn(2, 80, 400), - torch.randint(0,8192,(2,250)), - torch.tensor([101,102]), - return_loss=False) - print(nonloss.shape) \ No newline at end of file + torch.randn(2, 80, 400), + torch.randint(0, 8192, (2, 250)), + torch.tensor([101, 102]), + return_loss=False) + print(nonloss.shape) diff --git a/dlas/models/clip/text_voice_clip.py b/dlas/models/clip/text_voice_clip.py index 26bac0c9..a00c9583 100644 --- a/dlas/models/clip/text_voice_clip.py +++ b/dlas/models/clip/text_voice_clip.py @@ -3,24 +3,24 @@ from random import randint import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from torch import einsum from x_transformers import Encoder -from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder -from models.lucidrains.dalle.transformer import Transformer -from trainer.networks import register_model -from utils.util import opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.audio.tts.unet_diffusion_tts7 import \ + CheckpointedXTransformerEncoder +from dlas.models.lucidrains.dalle.transformer import Transformer +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def exists(val): return val is not None -def masked_mean(t, 
mask, dim = 1): +def masked_mean(t, mask, dim=1): t = t.masked_fill(~mask[:, :, None], 0.) - return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] + return t.sum(dim=1) / mask.sum(dim=1)[..., None] class VoiceCLIP(nn.Module): @@ -50,7 +50,8 @@ class VoiceCLIP(nn.Module): wav_token_compression=1024, use_xformers=False, clip_mels=False, - min_mel_size=10, # Default is approximately .5sec with default mel specs. + # Default is approximately .5sec with default mel specs. + min_mel_size=10, distributed_collect=False, ): super().__init__() @@ -131,11 +132,15 @@ class VoiceCLIP(nn.Module): if self.training: if self.clip_mels: margin = speech_tokens.shape[-1] - self.min_mel_size - speech_tokens = speech_tokens[:, :self.min_mel_size+randint(0,margin)] - voice_mask = torch.ones_like(speech_tokens.float()).bool() # Disable voice masking in this case. + speech_tokens = speech_tokens[:, + :self.min_mel_size+randint(0, margin)] + # Disable voice masking in this case. + voice_mask = torch.ones_like(speech_tokens.float()).bool() else: - voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage - text_mask = torch.rand_like(text.float()) > self.text_mask_percentage + voice_mask = torch.rand_like( + speech_tokens.float()) > self.voice_mask_percentage + text_mask = torch.rand_like( + text.float()) > self.text_mask_percentage else: text_mask = torch.ones_like(text.float()).bool() voice_mask = torch.ones_like(speech_tokens.float()).bool() @@ -144,8 +149,10 @@ class VoiceCLIP(nn.Module): speech_emb = self.speech_emb(speech_tokens) if not self.xformers: - text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) - speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) + text_emb += self.text_pos_emb( + torch.arange(text.shape[1], device=device)) + speech_emb += self.speech_pos_emb( + torch.arange(speech_emb.shape[1], device=device)) enc_text = self.text_transformer(text_emb, mask=text_mask) enc_speech = self.speech_transformer(speech_emb, mask=voice_mask) @@ -157,17 +164,22 @@ class VoiceCLIP(nn.Module): speech_latents = self.to_speech_latent(speech_latents) if self.distributed_collect: - collective = [torch.zeros_like(text_latents) for _ in range(torch.distributed.get_world_size())] + collective = [torch.zeros_like(text_latents) for _ in range( + torch.distributed.get_world_size())] torch.distributed.all_gather(collective, text_latents) - collective[torch.distributed.get_rank()] = text_latents # For gradient propagation. + # For gradient propagation. + collective[torch.distributed.get_rank()] = text_latents text_latents = torch.cat(collective, dim=0) - collective = [torch.zeros_like(speech_latents) for _ in range(torch.distributed.get_world_size())] - collective[torch.distributed.get_rank()] = speech_latents # For gradient propagation. + collective = [torch.zeros_like(speech_latents) for _ in range( + torch.distributed.get_world_size())] + # For gradient propagation. 
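# The "For gradient propagation" comments refer to a standard trick: torch.distributed.all_gather
# fills its output list with tensors that carry no grad_fn, so the local rank's slot is overwritten
# with the original (grad-carrying) tensor to keep this rank's slice of the gathered batch trainable.
# A hedged, standalone sketch of that pattern follows; the helper name is illustrative and it
# assumes an initialized process group.
import torch
import torch.distributed as dist

def all_gather_keep_local_grad(local_latents: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_latents) for _ in range(world_size)]
    dist.all_gather(gathered, local_latents)
    # Put the local tensor back so gradients flow through this rank's portion.
    gathered[dist.get_rank()] = local_latents
    return torch.cat(gathered, dim=0)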
+ collective[torch.distributed.get_rank()] = speech_latents torch.distributed.all_gather(collective, speech_latents) speech_latents = torch.cat(collective, dim=0) b = text_latents.shape[0] - text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) + text_latents, speech_latents = map(lambda t: F.normalize( + t, p=2, dim=-1), (text_latents, speech_latents)) temp = self.temperature.exp() @@ -177,7 +189,8 @@ class VoiceCLIP(nn.Module): sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp labels = torch.arange(b, device=device) - loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + loss = (F.cross_entropy(sim, labels) + + F.cross_entropy(sim.t(), labels)) / 2 return loss @@ -187,11 +200,12 @@ def register_voice_clip(opt_net, opt): if __name__ == '__main__': - clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2, use_xformers=True) - clip(torch.randint(0,256,(2,120)), - torch.randint(0,8192,(2,250)), + clip = VoiceCLIP(text_mask_percentage=.2, + voice_mask_percentage=.2, use_xformers=True) + clip(torch.randint(0, 256, (2, 120)), + torch.randint(0, 8192, (2, 250)), return_loss=True) - nonloss = clip(torch.randint(0,256,(2,120)), - torch.randint(0,8192,(2,250)), - return_loss=False) + nonloss = clip(torch.randint(0, 256, (2, 120)), + torch.randint(0, 8192, (2, 250)), + return_loss=False) print(nonloss.shape) diff --git a/dlas/models/diffusion/fp16_util.py b/dlas/models/diffusion/fp16_util.py index 9033a09f..a792597e 100644 --- a/dlas/models/diffusion/fp16_util.py +++ b/dlas/models/diffusion/fp16_util.py @@ -68,7 +68,8 @@ def master_params_to_model_params(param_groups_and_shapes, master_params): # silently not copy any parameters. for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): for (_, param), unflat_master_param in zip( - param_group, unflatten_master_params(param_group, master_param.view(-1)) + param_group, unflatten_master_params( + param_group, master_param.view(-1)) ): param.detach().copy_(unflat_master_param) @@ -99,7 +100,8 @@ def master_params_to_state_dict( master_params, param_groups_and_shapes ): for (name, _), unflat_master_param in zip( - param_group, unflatten_master_params(param_group, master_param.view(-1)) + param_group, unflatten_master_params( + param_group, master_param.view(-1)) ): assert name in state_dict state_dict[name] = unflat_master_param @@ -116,10 +118,12 @@ def state_dict_to_master_params(model, state_dict, use_fp16): named_model_params = [ (name, state_dict[name]) for name, _ in model.named_parameters() ] - param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) + param_groups_and_shapes = get_param_groups_and_shapes( + named_model_params) master_params = make_master_params(param_groups_and_shapes) else: - master_params = [state_dict[name] for name, _ in model.named_parameters()] + master_params = [state_dict[name] + for name, _ in model.named_parameters()] return master_params @@ -165,7 +169,8 @@ class MixedPrecisionTrainer: self.param_groups_and_shapes = get_param_groups_and_shapes( self.model.named_parameters() ) - self.master_params = make_master_params(self.param_groups_and_shapes) + self.master_params = make_master_params( + self.param_groups_and_shapes) self.model.convert_to_fp16() def zero_grad(self): @@ -185,8 +190,10 @@ class MixedPrecisionTrainer: return self._optimize_normal(opt) def _optimize_fp16(self, opt: th.optim.Optimizer): - model_grads_to_master_grads(self.param_groups_and_shapes, 
self.master_params) - grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) + model_grads_to_master_grads( + self.param_groups_and_shapes, self.master_params) + grad_norm, param_norm = self._compute_norms( + grad_scale=2 ** self.lg_loss_scale) if check_overflow(grad_norm): self.lg_loss_scale -= 1 zero_master_grads(self.master_params) @@ -194,7 +201,8 @@ class MixedPrecisionTrainer: opt.step(grad_scale=2.0 ** self.lg_loss_scale) zero_master_grads(self.master_params) - master_params_to_model_params(self.param_groups_and_shapes, self.master_params) + master_params_to_model_params( + self.param_groups_and_shapes, self.master_params) self.lg_loss_scale += self.fp16_scale_growth return True @@ -210,7 +218,8 @@ class MixedPrecisionTrainer: with th.no_grad(): param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 if p.grad is not None: - grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 + grad_norm += th.norm(p.grad, p=2, + dtype=th.float32).item() ** 2 return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) def master_params_to_state_dict(self, master_params): diff --git a/dlas/models/diffusion/gaussian_diffusion.py b/dlas/models/diffusion/gaussian_diffusion.py index 5ffd0310..de6fd4d4 100644 --- a/dlas/models/diffusion/gaussian_diffusion.py +++ b/dlas/models/diffusion/gaussian_diffusion.py @@ -15,8 +15,9 @@ import torch as th from torch.distributions import Normal from tqdm import tqdm -from models.diffusion.nn import mean_flat -from models.diffusion.losses import normal_kl, discretized_gaussian_log_likelihood +from dlas.models.diffusion.losses import (discretized_gaussian_log_likelihood, + normal_kl) +from dlas.models.diffusion.nn import mean_flat def causal_timestep_adjustment(t, S, num_timesteps, causal_slope=1, add_jitter=True): @@ -1186,8 +1187,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): def test_causal_training_losses(): - from models.diffusion.respace import SpacedDiffusion - from models.diffusion.respace import space_timesteps + from models.diffusion.respace import SpacedDiffusion, space_timesteps diff = SpacedDiffusion(use_timesteps=space_timesteps(4000, [4000]), model_mean_type='epsilon', model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', 4000), conditioning_free=False, conditioning_free_k=1) diff --git a/dlas/models/diffusion/losses.py b/dlas/models/diffusion/losses.py index 251e42e4..e94d40c6 100644 --- a/dlas/models/diffusion/losses.py +++ b/dlas/models/diffusion/losses.py @@ -5,7 +5,6 @@ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0 """ import numpy as np - import torch as th @@ -71,7 +70,8 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): log_probs = th.where( x < -0.999, log_cdf_plus, - th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + th.where(x > 0.999, log_one_minus_cdf_min, + th.log(cdf_delta.clamp(min=1e-12))), ) assert log_probs.shape == x.shape return log_probs diff --git a/dlas/models/diffusion/nn.py b/dlas/models/diffusion/nn.py index 201fadb9..607dc6ff 100644 --- a/dlas/models/diffusion/nn.py +++ b/dlas/models/diffusion/nn.py @@ -6,7 +6,8 @@ import math import torch as th import torch.nn as nn -import torch_intermediary as ml + +import dlas.torch_intermediary as ml # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
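# The comment above motivates a hand-rolled SiLU for older PyTorch releases. The actual shim in
# nn.py sits outside this hunk; a typical fallback looks like the sketch below (SiLU / "swish"
# is x * sigmoid(x)).
import torch as th
import torch.nn as nn

class SiLU(nn.Module):
    """Drop-in stand-in for nn.SiLU on PyTorch versions that lack it."""

    def forward(self, x):
        return x * th.sigmoid(x)

# SiLU()(th.randn(2, 8)) behaves like nn.SiLU() where the latter is available.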
@@ -121,15 +122,17 @@ def timestep_embedding(timesteps, dim, max_period=10000): """ half = dim // 2 freqs = th.exp( - -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half + -math.log(max_period) * th.arange(start=0, + end=half, dtype=th.float32) / half ).to(device=timesteps.device) if len(timesteps.shape) == 1: args = timesteps[:, None].float() * freqs[None] else: - args = (timesteps.float() * freqs.view(1,half,1)).permute(0,2,1) + args = (timesteps.float() * freqs.view(1, half, 1)).permute(0, 2, 1) embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) if dim % 2: - embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + embedding = th.cat( + [embedding, th.zeros_like(embedding[:, :1])], dim=-1) return embedding @@ -163,7 +166,8 @@ class CheckpointFunction(th.autograd.Function): @staticmethod def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + ctx.input_tensors = [x.detach().requires_grad_(True) + for x in ctx.input_tensors] with th.enable_grad(): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d diff --git a/dlas/models/diffusion/resample.py b/dlas/models/diffusion/resample.py index 8690bb1d..ea2c14b2 100644 --- a/dlas/models/diffusion/resample.py +++ b/dlas/models/diffusion/resample.py @@ -73,6 +73,7 @@ class DeterministicSampler: Returns the same equally spread-out sampling schedule every time it is called. Automatically handles distributed cases by sharing the load across all entities. reset() must be called once a full batch is completed. """ + def __init__(self, diffusion, sampling_range, env): super().__init__() self.timesteps = diffusion.num_timesteps @@ -82,7 +83,8 @@ class DeterministicSampler: else: self.world_size = 1 # The sampling range gets spread out across multiple distributed entities. - rnge = th.arange(self.rank, sampling_range, step=self.world_size).float() / sampling_range + rnge = th.arange(self.rank, sampling_range, + step=self.world_size).float() / sampling_range self.indices = (rnge * self.timesteps).long() def sample(self, batch_size, device): @@ -91,7 +93,8 @@ class DeterministicSampler: """ assert batch_size < self.indices.shape[0] if self.counter+batch_size > self.indices.shape[0]: - print(f"Diffusion DeterministicSampler; Likely error. {self.counter}, {batch_size}, {self.indices.shape[0]}. Did you forget to set the sampling range to your batch size for the deterministic sampler?") + print( + f"Diffusion DeterministicSampler; Likely error. {self.counter}, {batch_size}, {self.indices.shape[0]}. Did you forget to set the sampling range to your batch size for the deterministic sampler?") self.counter = 0 # Recover by setting to 0. 
indices = self.indices[self.counter:self.counter+batch_size].to(device) self.counter = self.counter + batch_size @@ -128,14 +131,17 @@ class LossAwareSampler(ScheduleSampler): batch_sizes = [x.item() for x in batch_sizes] max_bs = max(batch_sizes) - timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] - loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + timestep_batches = [th.zeros(max_bs).to(local_ts) + for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) + for bs in batch_sizes] dist.all_gather(timestep_batches, local_ts) dist.all_gather(loss_batches, local_losses) timesteps = [ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] ] - losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) + for x in y[:bs]] self.update_with_all_losses(timesteps, losses) @abstractmethod diff --git a/dlas/models/diffusion/respace.py b/dlas/models/diffusion/respace.py index 03aa9234..3f7742c3 100644 --- a/dlas/models/diffusion/respace.py +++ b/dlas/models/diffusion/respace.py @@ -28,7 +28,7 @@ def space_timesteps(num_timesteps, section_counts): """ if isinstance(section_counts, str): if section_counts.startswith("ddim"): - desired_count = int(section_counts[len("ddim") :]) + desired_count = int(section_counts[len("ddim"):]) for i in range(1, num_timesteps): if len(range(0, num_timesteps, i)) == desired_count: return set(range(0, num_timesteps, i)) @@ -74,7 +74,8 @@ class SpacedDiffusion(GaussianDiffusion): self.timestep_map = [] self.original_num_steps = len(kwargs["betas"]) - base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + base_diffusion = GaussianDiffusion( + **kwargs) # pylint: disable=missing-kwoa last_alpha_cumprod = 1.0 new_betas = [] for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): @@ -125,7 +126,8 @@ class _WrappedModel: self.original_num_steps = original_num_steps def __call__(self, x, ts, **kwargs): - map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + map_tensor = th.tensor( + self.timestep_map, device=ts.device, dtype=ts.dtype) new_ts = map_tensor[ts] if self.rescale_timesteps: new_ts = new_ts.float() * (1000.0 / self.original_num_steps) @@ -140,8 +142,9 @@ class _WrappedAutoregressiveModel: self.original_num_steps = original_num_steps def __call__(self, x, x0, ts, **kwargs): - map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + map_tensor = th.tensor( + self.timestep_map, device=ts.device, dtype=ts.dtype) new_ts = map_tensor[ts] if self.rescale_timesteps: new_ts = new_ts.float() * (1000.0 / self.original_num_steps) - return self.model(x, x0, new_ts, **kwargs) \ No newline at end of file + return self.model(x, x0, new_ts, **kwargs) diff --git a/dlas/models/diffusion/rrdb_diffusion.py b/dlas/models/diffusion/rrdb_diffusion.py index 933fe467..b13d7617 100644 --- a/dlas/models/diffusion/rrdb_diffusion.py +++ b/dlas/models/diffusion/rrdb_diffusion.py @@ -2,12 +2,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.arch_util import ConvGnLelu, default_init_weights, make_layer -from models.diffusion.nn import timestep_embedding -from trainer.networks import register_model -from utils.util import checkpoint -import torch_intermediary as ml - +import dlas.torch_intermediary as ml +from dlas.models.arch_util import ConvGnLelu, default_init_weights, make_layer +from dlas.models.diffusion.nn import 
timestep_embedding +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint # Conditionally uses torch's checkpoint functionality if it is enabled in the opt file. @@ -26,7 +25,8 @@ class ResidualDenseBlock(nn.Module): super(ResidualDenseBlock, self).__init__() self.embedding = embedding if embedding: - self.first_conv = ConvGnLelu(mid_channels, mid_channels, activation=True, norm=False, bias=True) + self.first_conv = ConvGnLelu( + mid_channels, mid_channels, activation=True, norm=False, bias=True) self.emb_layers = nn.Sequential( nn.SiLU(), ml.Linear( @@ -85,7 +85,8 @@ class RRDB(nn.Module): def __init__(self, mid_channels, growth_channels=32): super(RRDB, self).__init__() - self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True) + self.rdb1 = ResidualDenseBlock( + mid_channels, growth_channels, embedding=True) self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels) @@ -137,9 +138,12 @@ class RRDBNet(nn.Module): self.mid_channels = mid_channels # The diffusion RRDB starts with a full resolution image and downsamples into a .25 working space - self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=False, bias=True) - self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True) - self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True) + self.input_block = ConvGnLelu( + in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=False, bias=True) + self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, + stride=2, activation=True, norm=False, bias=True) + self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, + stride=2, activation=True, norm=False, bias=True) # Guided diffusion uses a time embedding. 
time_embed_dim = mid_channels * 4 @@ -155,16 +159,21 @@ class RRDBNet(nn.Module): mid_channels=mid_channels, growth_channels=growth_channels) - self.conv_body = nn.Conv2d(self.mid_channels, self.mid_channels, 3, 1, 1) + self.conv_body = nn.Conv2d( + self.mid_channels, self.mid_channels, 3, 1, 1) # upsample - self.conv_up1 = nn.Conv2d(self.mid_channels, self.mid_channels, 3, 1, 1) - self.conv_up2 = nn.Conv2d(self.mid_channels*2, self.mid_channels, 3, 1, 1) + self.conv_up1 = nn.Conv2d( + self.mid_channels, self.mid_channels, 3, 1, 1) + self.conv_up2 = nn.Conv2d( + self.mid_channels*2, self.mid_channels, 3, 1, 1) self.conv_up3 = None - self.conv_hr = nn.Conv2d(self.mid_channels*2, self.mid_channels, 3, 1, 1) + self.conv_hr = nn.Conv2d( + self.mid_channels*2, self.mid_channels, 3, 1, 1) self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels) + self.normalize = nn.GroupNorm( + num_groups=8, num_channels=self.mid_channels) for m in [ self.conv_body, self.conv_up1, @@ -178,13 +187,16 @@ class RRDBNet(nn.Module): emb = self.time_embed(timestep_embedding(timesteps, self.mid_channels)) _, _, new_height, new_width = x.shape - upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") + upsampled = F.interpolate( + low_res, (new_height, new_width), mode="bilinear") x = torch.cat([x, upsampled], dim=1) if correction_factors is not None: - correction_factors = correction_factors.view(x.shape[0], -1, 1, 1).repeat(1, 1, new_height, new_width) + correction_factors = correction_factors.view( + x.shape[0], -1, 1, 1).repeat(1, 1, new_height, new_width) else: - correction_factors = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device) + correction_factors = torch.zeros( + (b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device) x = torch.cat([x, correction_factors], dim=1) d1 = self.input_block(x) @@ -213,9 +225,9 @@ def register_rrdb_diffusion(opt_net, opt): if __name__ == '__main__': - model = RRDBNet(6,6) - x = torch.randn(1,3,128,128) - l = torch.randn(1,3,32,32) + model = RRDBNet(6, 6) + x = torch.randn(1, 3, 128, 128) + l = torch.randn(1, 3, 32, 32) t = torch.LongTensor([555]) y = model(x, t, l) print(y.shape, y.mean(), y.std(), y.min(), y.max()) diff --git a/dlas/models/diffusion/unet_diffusion.py b/dlas/models/diffusion/unet_diffusion.py index 6ebbae02..817f297b 100644 --- a/dlas/models/diffusion/unet_diffusion.py +++ b/dlas/models/diffusion/unet_diffusion.py @@ -1,26 +1,19 @@ +import math from abc import abstractmethod -import math - +import dlas.torch_intermediary as ml import numpy as np import torch import torch as th import torch.nn as nn import torch.nn.functional as F -import torchvision # For debugging, not actually used. 
- -from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32 -from models.diffusion.nn import ( - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from trainer.networks import register_model -from utils.util import checkpoint -import torch_intermediary as ml +from dlas.models.diffusion.fp16_util import (convert_module_to_f16, + convert_module_to_f32) +from dlas.models.diffusion.nn import (avg_pool_nd, conv_nd, linear, + normalization, timestep_embedding, + zero_module) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint class AttentionPool2d(nn.Module): @@ -48,7 +41,8 @@ class AttentionPool2d(nn.Module): b, c, *_spatial = x.shape x = x.reshape(b, c, -1) # NC(HW) x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :x.shape[-1]].to(x.dtype) # NC(HW+1) + x = x + self.positional_embedding[None, + :, :x.shape[-1]].to(x.dtype) # NC(HW+1) x = self.qkv_proj(x) x = self.attention(x) x = self.c_proj(x) @@ -109,7 +103,8 @@ class Upsample(nn.Module): if dims == 1: ksize = 5 pad = 2 - self.conv = conv_nd(dims, self.channels, self.out_channels, ksize, padding=pad) + self.conv = conv_nd(dims, self.channels, + self.out_channels, ksize, padding=pad) def forward(self, x): assert x.shape[1] == self.channels @@ -152,7 +147,7 @@ class Downsample(nn.Module): elif dims == 2: stride = 2 else: - stride = (1,2,2) + stride = (1, 2, 2) if factor is not None: stride = factor if use_conv: @@ -209,7 +204,8 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + conv_nd(dims, channels, self.out_channels, + kernel_size, padding=padding), ) self.updown = up or down @@ -235,7 +231,8 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + conv_nd(dims, self.out_channels, self.out_channels, + kernel_size, padding=padding) ), ) @@ -246,7 +243,8 @@ class ResBlock(TimestepBlock): dims, channels, self.out_channels, kernel_size, padding=padding ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 1) def forward(self, x, emb): """ @@ -374,13 +372,15 @@ class QKVAttentionLegacy(nn.Module): bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, + length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum( "bct,bcs->bts", q * scale, k * scale ) # More stable with f16 than dividing afterwards if rel_pos is not None: - weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1]) + weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape( + bs * self.n_heads, weight.shape[-2], weight.shape[-1]) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) if mask is not None: # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. 
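# A small sketch contrasting the two masking strategies the comment above mentions, using
# illustrative tensor names rather than the diff's own variables: filling masked scores with
# -inf before the softmax removes those positions exactly, while multiplying the pre-softmax
# scores by a 0/1 mask (the workaround visible in the surrounding hunks) only damps them,
# because a zeroed score still contributes exp(0) = 1 of softmax mass.
import torch

scores = torch.randn(1, 4, 4)          # (batch, query, key) attention scores
mask = torch.tensor([1., 1., 0., 0.])  # keep only the first two key positions

# Exact masking: -inf becomes probability 0 after the softmax.
attn_exact = torch.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)

# Multiplicative workaround: masked scores are zeroed, not removed.
attn_approx = torch.softmax(scores * mask, dim=-1)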
@@ -428,7 +428,8 @@ class QKVAttention(nn.Module): mask = mask.repeat(self.n_heads, 1).unsqueeze(1) weight = weight * mask weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + a = th.einsum("bts,bcs->bct", weight, + v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) @staticmethod @@ -519,7 +520,8 @@ class UNetModel(nn.Module): # nn.Embedding self.label_emb = ml.Embedding(num_classes, time_embed_dim) self.use_raw_y_as_embedding = use_raw_y_as_embedding - assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive. + # These are mutually-exclusive. + assert not ((self.num_classes is not None) and use_raw_y_as_embedding) self.input_blocks = nn.ModuleList( [ @@ -651,7 +653,8 @@ class UNetModel(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(dims, model_channels, + out_channels, 3, padding=1)), ) def forward(self, x, timesteps, y=None): @@ -664,7 +667,8 @@ class UNetModel(nn.Module): :return: an [N x C x ...] Tensor of outputs. """ hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(timestep_embedding( + timesteps, self.model_channels)) if self.num_classes is not None: assert y.shape == (x.shape[0],) @@ -697,16 +701,20 @@ class SuperResModel(UNetModel): def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs): b, _, new_height, new_width = x.shape - upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") + upsampled = F.interpolate( + low_res, (new_height, new_width), mode="bilinear") if corruption_factor is not None: if corruption_factor.shape[1] != self.num_corruptions: if not hasattr(self, '_corruption_factor_warning'): - print(f"Warning! Dataloader gave us {corruption_factor.shape[1]} dim but we are only processing {self.num_corruptions}. The last n corruptions will be truncated.") + print( + f"Warning! Dataloader gave us {corruption_factor.shape[1]} dim but we are only processing {self.num_corruptions}. The last n corruptions will be truncated.") self._corruption_factor_warning = True corruption_factor = corruption_factor[:, :self.num_corruptions] - corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(1, 1, new_height, new_width) + corruption_factor = corruption_factor.view( + b, -1, 1, 1).repeat(1, 1, new_height, new_width) else: - corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device) + corruption_factor = torch.zeros( + (b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device) upsampled = torch.cat([upsampled, corruption_factor], dim=1) x = th.cat([x, upsampled], dim=1) res = super().forward(x, timesteps, **kwargs) @@ -891,7 +899,8 @@ class EncoderUNetModel(nn.Module): :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs. 
""" - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(timestep_embedding( + timesteps, self.model_channels)) results = [] h = x.type(self.dtype) @@ -908,18 +917,20 @@ class EncoderUNetModel(nn.Module): h = h.type(x.dtype) return self.out(h) + @register_model def register_unet_diffusion(opt_net, opt): return SuperResModel(**opt_net['args']) + if __name__ == '__main__': attention_ds = [] for res in "16,8".split(","): attention_ds.append(128 // int(res)) srm = SuperResModel(image_size=128, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4, num_heads_upsample=-1, use_scale_shift_norm=True) - x = torch.randn(1,3,128,128) - l = torch.randn(1,3,32,32) + x = torch.randn(1, 3, 128, 128) + l = torch.randn(1, 3, 32, 32) ts = torch.LongTensor([555]) y = srm(x, ts, low_res=l) - print(y.shape, y.mean(), y.std(), y.min(), y.max()) \ No newline at end of file + print(y.shape, y.mean(), y.std(), y.min(), y.max()) diff --git a/dlas/models/diffusion/unet_latent_guide.py b/dlas/models/diffusion/unet_latent_guide.py index e298a900..88305614 100644 --- a/dlas/models/diffusion/unet_latent_guide.py +++ b/dlas/models/diffusion/unet_latent_guide.py @@ -1,8 +1,7 @@ import functools -from abc import abstractmethod - import math -from typing import Union, Type, Callable, Optional, List +from abc import abstractmethod +from typing import Callable, List, Optional, Type, Union import numpy as np import torch @@ -15,18 +14,14 @@ from torch import Tensor from torchvision.models import resnet50 from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 -from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32 -from models.diffusion.nn import ( - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from trainer.networks import register_model -from utils.util import checkpoint -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.diffusion.fp16_util import (convert_module_to_f16, + convert_module_to_f32) +from dlas.models.diffusion.nn import (avg_pool_nd, conv_nd, linear, + normalization, timestep_embedding, + zero_module) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint class AttentionPool2d(nn.Module): @@ -105,7 +100,8 @@ class Upsample(nn.Module): self.use_conv = use_conv self.dims = dims if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + self.conv = conv_nd(dims, self.channels, + self.out_channels, 3, padding=1) def forward(self, x): assert x.shape[1] == self.channels @@ -215,7 +211,8 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + conv_nd(dims, self.out_channels, + self.out_channels, 3, padding=1) ), ) @@ -226,7 +223,8 @@ class ResBlock(TimestepBlock): dims, channels, self.out_channels, 3, padding=1 ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 1) def forward(self, x, emb): """ @@ -349,7 +347,8 @@ class QKVAttentionLegacy(nn.Module): bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, + length).split(ch, dim=1) scale = 1 / 
math.sqrt(math.sqrt(ch)) weight = th.einsum( "bct,bcs->bts", q * scale, k * scale @@ -390,7 +389,8 @@ class QKVAttention(nn.Module): (k * scale).view(bs * self.n_heads, ch, length), ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + a = th.einsum("bts,bcs->bct", weight, + v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) @staticmethod @@ -540,7 +540,8 @@ class UNetModel(nn.Module): ds *= 2 self._feature_size += ch - self.latent_join_reduce = ResBlock(ch*2, time_embed_dim, dropout, out_channels=ch, dims=dims, use_scale_shift_norm=use_scale_shift_norm) + self.latent_join_reduce = ResBlock( + ch*2, time_embed_dim, dropout, out_channels=ch, dims=dims, use_scale_shift_norm=use_scale_shift_norm) self.middle_block = TimestepEmbedSequential( ResBlock( ch, @@ -611,7 +612,8 @@ class UNetModel(nn.Module): self.out = nn.Sequential( normalization(ch), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + zero_module(conv_nd(dims, model_channels, + out_channels, 3, padding=1)), ) def convert_to_fp16(self): @@ -644,7 +646,8 @@ class UNetModel(nn.Module): ), "must specify y if and only if the model is class-conditional" hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed(timestep_embedding( + timesteps, self.model_channels)) if self.num_classes is not None: assert y.shape == (x.shape[0],) @@ -655,7 +658,8 @@ class UNetModel(nn.Module): h = module(h, emb) hs.append(h) b, c = latent.shape - h = torch.cat([h, latent.view(b,c,1,1).repeat(1,1,h.shape[-2],h.shape[-1])], dim=1) + h = torch.cat([h, latent.view(b, c, 1, 1).repeat( + 1, 1, h.shape[-2], h.shape[-1])], dim=1) h = self.latent_join_reduce(h, emb) h = self.middle_block(h, emb) for module in self.output_blocks: @@ -678,11 +682,14 @@ class SuperResModel(UNetModel): def forward(self, x, timesteps, latent, low_res=None, corruption_factor=None, **kwargs): b, _, new_height, new_width = x.shape - upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") + upsampled = F.interpolate( + low_res, (new_height, new_width), mode="bilinear") if corruption_factor is not None: - corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(1, 1, new_height, new_width) + corruption_factor = corruption_factor.view( + b, -1, 1, 1).repeat(1, 1, new_height, new_width) else: - corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device) + corruption_factor = torch.zeros( + (b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device) upsampled = torch.cat([upsampled, corruption_factor], dim=1) x = th.cat([x, upsampled], dim=1) res = super().forward(x, latent, timesteps, **kwargs) @@ -728,21 +735,22 @@ class ResNetEncoder(nn.Module): self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) - f=128 + f = 128 if self.depth > 2: self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) - f=256 + f = 256 if self.depth > 3: self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) - f=512 + f = 512 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = ml.Linear(f * block.expansion, output_dim) for m in self.modules(): if isinstance(m, 
nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -753,9 +761,11 @@ class ResNetEncoder(nn.Module): if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + # type: ignore[arg-type] + nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + # type: ignore[arg-type] + nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, stride: int = 1, dilate: bool = False) -> nn.Sequential: @@ -832,10 +842,10 @@ if __name__ == '__main__': for res in "16,8".split(","): attention_ds.append(128 // int(res)) srm = UnetWithBuiltInLatentEncoder(image_size=64, in_channels=3, model_channels=64, out_channels=3, num_res_blocks=1, attention_resolutions=attention_ds, num_heads=4, - num_heads_upsample=-1, use_scale_shift_norm=True) - x = torch.randn(1,3,64,64) - alt_x = torch.randn(1,3,64,64) - l = torch.randn(1,3,32,32) + num_heads_upsample=-1, use_scale_shift_norm=True) + x = torch.randn(1, 3, 64, 64) + alt_x = torch.randn(1, 3, 64, 64) + l = torch.randn(1, 3, 32, 32) ts = torch.LongTensor([555]) y = srm(x, ts, alt_x, low_res=l) print(y.shape, y.mean(), y.std(), y.min(), y.max()) diff --git a/dlas/models/image_generation/RRDBNet_arch.py b/dlas/models/image_generation/RRDBNet_arch.py index 1e5335c6..9a69c4bd 100644 --- a/dlas/models/image_generation/RRDBNet_arch.py +++ b/dlas/models/image_generation/RRDBNet_arch.py @@ -8,9 +8,10 @@ import torch.nn.functional as F import torchvision from torchvision.models.resnet import Bottleneck -from models.arch_util import make_layer, default_init_weights, ConvGnSilu, ConvGnLelu -from trainer.networks import register_model -from utils.util import checkpoint, sequential_checkpoint, opt_get +from dlas.models.arch_util import (ConvGnLelu, ConvGnSilu, + default_init_weights, make_layer) +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, opt_get, sequential_checkpoint class ResidualDenseBlock(nn.Module): @@ -35,7 +36,6 @@ class ResidualDenseBlock(nn.Module): for i in range(5): default_init_weights(getattr(self, f'conv{i+1}'), init_weight) - def forward(self, x): """Forward function. 
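# The ResidualDenseBlock/RRDB classes in this file follow the ESRGAN pattern: five convs
# that each see the concatenation of all earlier feature maps, with the block output
# residually scaled by 0.2. A minimal generic sketch of that pattern (an illustration
# only, assuming the standard ESRGAN layout; the DLAS classes here add channel
# reduction, bypass maps and checkpointing on top of it):
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyResidualDenseBlock(nn.Module):
    def __init__(self, mid_channels=64, growth_channels=32):
        super().__init__()
        # conv i consumes the input plus the i feature maps produced so far
        self.convs = nn.ModuleList([
            nn.Conv2d(mid_channels + i * growth_channels,
                      growth_channels if i < 4 else mid_channels, 3, 1, 1)
            for i in range(5)
        ])

    def forward(self, x):
        feats = [x]
        for i, conv in enumerate(self.convs):
            out = conv(torch.cat(feats, dim=1))
            if i < 4:
                out = F.leaky_relu(out, 0.2)
            feats.append(out)
        return x + 0.2 * feats[-1]   # residual scaling, as in ESRGAN-style blocks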
@@ -70,7 +70,8 @@ class RRDB(nn.Module): self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) if reduce_to is not None: - self.reducer = ConvGnLelu(mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True) + self.reducer = ConvGnLelu( + mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True) self.recover_ch = mid_channels - reduce_to else: self.reducer = None @@ -90,7 +91,8 @@ class RRDB(nn.Module): if self.reducer is not None: out = self.reducer(out) b, f, h, w = out.shape - out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1) + out = torch.cat( + [out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1) if return_residual: return 0.2 * out @@ -115,15 +117,18 @@ class RRDBWithBypass(nn.Module): self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) if reduce_to is not None: - self.reducer = ConvGnLelu(mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True) + self.reducer = ConvGnLelu( + mid_channels, reduce_to, kernel_size=3, activation=False, norm=False, bias=True) self.recover_ch = mid_channels - reduce_to bypass_channels = mid_channels + reduce_to else: self.reducer = None bypass_channels = mid_channels * 2 self.bypass = nn.Sequential(ConvGnSilu(bypass_channels, mid_channels, kernel_size=3, bias=True, activation=True, norm=True), - ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False), - ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), + ConvGnSilu( + mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False), + ConvGnSilu( + mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), nn.Sigmoid()) self.randomly_add_bypass_noise = randomly_add_noise_to_bypass @@ -143,7 +148,8 @@ class RRDBWithBypass(nn.Module): if self.reducer is not None: out = self.reducer(out) b, f, h, w = out.shape - out = torch.cat([out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1) + out = torch.cat( + [out, torch.zeros((b, self.recover_ch, h, w), device=out.device)], dim=1) bypass = self.bypass(torch.cat([x, out], dim=1)) # The purpose of random noise is to induce usage of bypass maps that would otherwise be "dead". Theoretically @@ -184,10 +190,12 @@ class RRDBNet(nn.Module): scale=4, additive_mode="not", # Options: "not", "additive", "additive_enforced" headless=False, - feature_channels=64, # Only applicable when headless=True. How many channels are used at the trunk level. + # Only applicable when headless=True. How many channels are used at the trunk level. + feature_channels=64, output_mode="hq_only", # Options: "hq_only", "hq+features", "features_only" initial_stride=1, - use_ref=False, # When set, a reference image is expected as input and synthesized if not found. Useful for video SR. + # When set, a reference image is expected as input and synthesized if not found. Useful for video SR. 
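# A usage sketch for the RRDBNet options referenced here (additive_mode, output_mode,
# use_ref, scale). Keyword names follow the register_RRDBNet helper further down in this
# file; the example values and the plain net(lq) call are assumptions (remaining
# constructor arguments are left at their defaults), so treat this as illustrative only.
import torch
from dlas.models.image_generation.RRDBNet_arch import RRDB, RRDBNet

net = RRDBNet(in_channels=3, out_channels=3, mid_channels=64, num_blocks=23,
              growth_channels=32, body_block=RRDB, scale=4,
              additive_mode="additive",   # adds a bilinear upsample of the input to the output
              output_mode="hq_only")      # alternatives: "hq+features", "features_only"
lq = torch.randn(1, 3, 32, 32)
hq = net(lq)                              # expected shape (1, 3, 128, 128) for scale=4
print(hq.shape)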
+ use_ref=False, ): super(RRDBNet, self).__init__() assert output_mode in ['hq_only', 'hq+features', 'features_only'] @@ -205,9 +213,11 @@ class RRDBNet(nn.Module): self.conv_first = None self.reduce_ch = feature_channels reduce_to = feature_channels - self.conv_ref_first = ConvGnLelu(3, feature_channels, 7, stride=2, norm=False, activation=False, bias=True) + self.conv_ref_first = ConvGnLelu( + 3, feature_channels, 7, stride=2, norm=False, activation=False, bias=True) else: - self.conv_first = nn.Conv2d(in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding) + self.conv_first = nn.Conv2d( + in_channels, mid_channels, first_conv_ksize, first_conv_stride, first_conv_padding) self.reduce_ch = mid_channels reduce_to = None self.body = make_layer( @@ -229,7 +239,8 @@ class RRDBNet(nn.Module): self.additive_mode = additive_mode if additive_mode == "additive_enforced": - self.add_enforced_pool = nn.AvgPool2d(kernel_size=scale, stride=scale) + self.add_enforced_pool = nn.AvgPool2d( + kernel_size=scale, stride=scale) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) @@ -259,14 +270,16 @@ class RRDBNet(nn.Module): else: # "Normal" mode -> image input. if self.use_ref: - x_lg = F.interpolate(x, scale_factor=self.scale, mode="bicubic") + x_lg = F.interpolate( + x, scale_factor=self.scale, mode="bicubic") if ref is None: ref = torch.zeros_like(x_lg) x_lg = torch.cat([x_lg, ref], dim=1) else: x_lg = x feat = self.conv_first(x_lg) - feat = sequential_checkpoint(self.body, self.num_blocks // self.blocks_per_checkpoint, feat) + feat = sequential_checkpoint( + self.body, self.num_blocks // self.blocks_per_checkpoint, feat) feat = feat[:, :self.reduce_ch] body_feat = self.conv_body(feat) feat = feat + body_feat @@ -286,12 +299,14 @@ class RRDBNet(nn.Module): out = self.lrelu(self.conv_up2(out)) out = self.conv_last(self.lrelu(self.conv_hr(out))) if "additive" in self.additive_mode: - x_interp = F.interpolate(x, scale_factor=self.scale, mode='bilinear') + x_interp = F.interpolate( + x, scale_factor=self.scale, mode='bilinear') if self.additive_mode == 'additive': out = out + x_interp elif self.additive_mode == 'additive_enforced': out_pooled = self.add_enforced_pool(out) - out = out - F.interpolate(out_pooled, scale_factor=self.scale, mode='nearest') + out = out - F.interpolate(out_pooled, + scale_factor=self.scale, mode='nearest') out = out + x_interp if self.output_mode == "hq+features": @@ -301,30 +316,40 @@ class RRDBNet(nn.Module): def visual_dbg(self, step, path): for i, bm in enumerate(self.body): if hasattr(bm, 'bypass_map'): - torchvision.utils.save_image(bm.bypass_map.cpu().float(), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) + torchvision.utils.save_image(bm.bypass_map.cpu().float( + ), os.path.join(path, "%i_bypass_%i.png" % (step, i+1))) + @register_model def register_RRDBNetBypass(opt_net, opt): - additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not' - output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' + additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys( + ) else 'not' + output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys( + ) else 'hq_only' gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32 - initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1 + initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys( + ) else 1 bypass_noise = opt_get(opt_net, ['bypass_noise'], False) - 
block = functools.partial(RRDBWithBypass, randomly_add_noise_to_bypass=bypass_noise) + block = functools.partial( + RRDBWithBypass, randomly_add_noise_to_bypass=bypass_noise) return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], - mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, - output_mode=output_mode, body_block=block, scale=opt_net['scale'], growth_channels=gc, - initial_stride=initial_stride) + mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, + output_mode=output_mode, body_block=block, scale=opt_net[ + 'scale'], growth_channels=gc, + initial_stride=initial_stride) @register_model def register_RRDBNet(opt_net, opt): - additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not' - output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys() else 'hq_only' + additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys( + ) else 'not' + output_mode = opt_net['output_mode'] if 'output_mode' in opt_net.keys( + ) else 'hq_only' gc = opt_net['gc'] if 'gc' in opt_net.keys() else 32 - initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys() else 1 + initial_stride = opt_net['initial_stride'] if 'initial_stride' in opt_net.keys( + ) else 1 return RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], - mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, - output_mode=output_mode, body_block=RRDB, scale=opt_net['scale'], growth_channels=gc, - initial_stride=initial_stride) - + mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode, + output_mode=output_mode, body_block=RRDB, scale=opt_net[ + 'scale'], growth_channels=gc, + initial_stride=initial_stride) diff --git a/dlas/models/image_generation/ResGen_arch.py b/dlas/models/image_generation/ResGen_arch.py index bf9d892a..0f0214f0 100644 --- a/dlas/models/image_generation/ResGen_arch.py +++ b/dlas/models/image_generation/ResGen_arch.py @@ -1,10 +1,10 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np import torch.nn.functional as F - -__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] +__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', + 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] def conv3x3(in_planes, out_planes, stride=1): @@ -12,16 +12,19 @@ def conv3x3(in_planes, out_planes, stride=1): return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + def conv5x5(in_planes, out_planes, stride=1): """5x5 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride, padding=2, bias=False) + def conv7x7(in_planes, out_planes, stride=1): """7x7 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, padding=3, bias=False) + def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) @@ -67,7 +70,8 @@ class FixupResNet(nn.Module): def __init__(self, block, layers, upscale_applications=2, num_filters=64, inject_noise=False): super(FixupResNet, self).__init__() self.inject_noise = inject_noise - self.num_layers = sum(layers) + layers[-1] * (upscale_applications - 1) # The last layer is applied repeatedly to achieve high level SR. + # The last layer is applied repeatedly to achieve high level SR. 
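# The register_RRDBNet/register_RRDBNetBypass helpers above read optional config keys in
# two equivalent ways: `opt_net['k'] if 'k' in opt_net.keys() else default` and
# `opt_get(opt_net, ['k'], default)`. The real helper lives in dlas.utils.util; the
# sketch below only mimics the behaviour implied by those call sites (a nested dict
# lookup with a fallback default) and is not the project's implementation.
def opt_get_sketch(opt, keys, default=None):
    cur = opt
    for k in keys:
        if not isinstance(cur, dict) or k not in cur:
            return default
        cur = cur[k]
    return cur

opt_net = {'in_nc': 3, 'out_nc': 3, 'nf': 64, 'nb': 23, 'scale': 4}
print(opt_get_sketch(opt_net, ['gc'], 32))   # missing key -> default 32
print(opt_get_sketch(opt_net, ['scale']))    # present key -> 4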
+ self.num_layers = sum(layers) + layers[-1] * (upscale_applications - 1) self.inplanes = num_filters self.upscale_applications = upscale_applications # Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated @@ -77,7 +81,8 @@ class FixupResNet(nn.Module): self.bias1 = nn.Parameter(torch.zeros(1)) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1) - self.skip1 = nn.Conv2d(num_filters, 3, kernel_size=5, stride=1, padding=2, bias=False) + self.skip1 = nn.Conv2d(num_filters, 3, kernel_size=5, + stride=1, padding=2, bias=False) self.skip1_bias = nn.Parameter(torch.zeros(1)) # Part 2 - This is the upsampler core. It consists of a normal multiplicative conv followed by several residual @@ -87,26 +92,31 @@ class FixupResNet(nn.Module): self.nf2 = int(num_filters/4) # This part isn't repeated. It de-filters the output from the previous step to fit the filter size used in the # upsampler-conv. - self.upsampler_conv = nn.Conv2d(num_filters, self.nf2, kernel_size=3, stride=1, padding=1, bias=False) + self.upsampler_conv = nn.Conv2d( + num_filters, self.nf2, kernel_size=3, stride=1, padding=1, bias=False) self.uc_bias = nn.Parameter(torch.zeros(1)) self.inplanes = self.nf2 if layers[1] > 0: # This is the repeated part. - self.layer2 = self._make_layer(block, int(self.nf2), layers[1], stride=1, conv_type=conv5x5) - self.skip2 = nn.Conv2d(self.nf2, 3, kernel_size=5, stride=1, padding=2, bias=False) + self.layer2 = self._make_layer(block, int( + self.nf2), layers[1], stride=1, conv_type=conv5x5) + self.skip2 = nn.Conv2d( + self.nf2, 3, kernel_size=5, stride=1, padding=2, bias=False) self.skip2_bias = nn.Parameter(torch.zeros(1)) - self.final_defilter = nn.Conv2d(self.nf2, 3, kernel_size=5, stride=1, padding=2, bias=True) + self.final_defilter = nn.Conv2d( + self.nf2, 3, kernel_size=5, stride=1, padding=2, bias=True) self.bias2 = nn.Parameter(torch.zeros(1)) for m in self.modules(): if isinstance(m, FixupBasicBlock): - nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5)) + nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt( + 2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5)) nn.init.constant_(m.conv2.weight, 0) if m.downsample is not None: - nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) - + nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt( + 2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) def _make_layer(self, block, planes, blocks, stride=1, conv_type=conv3x3): defilter = None @@ -145,6 +155,7 @@ class FixupResNet(nn.Module): x = self.final_defilter(x) + self.bias2 return x, skip_med, skip_lo + class FixupResNetV2(FixupResNet): def __init__(self, **kwargs): super(FixupResNetV2, self).__init__(**kwargs) @@ -154,10 +165,12 @@ class FixupResNetV2(FixupResNet): self.skip2 = None self.skip2_bias = None # The new filter-to-image stack will be 2 conv layers deep, not 1. 
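# The init loop a few lines up applies the Fixup rule: the first conv of each residual
# block is drawn from a normal distribution whose std is shrunk by num_layers ** -0.5,
# and the second conv is zeroed so every block starts out as an identity mapping. The
# same rule, restated as a standalone helper (argument names are illustrative):
import numpy as np
import torch.nn as nn

def fixup_init_block(conv1: nn.Conv2d, conv2: nn.Conv2d, num_layers: int):
    fan = conv1.weight.shape[0] * int(np.prod(conv1.weight.shape[2:]))
    nn.init.normal_(conv1.weight, mean=0.0,
                    std=float(np.sqrt(2.0 / fan)) * num_layers ** (-0.5))
    nn.init.constant_(conv2.weight, 0.0)   # zeroed second conv -> identity residual at start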
- self.final_process = nn.Conv2d(self.nf2, self.nf2, kernel_size=5, stride=1, padding=2, bias=True) + self.final_process = nn.Conv2d( + self.nf2, self.nf2, kernel_size=5, stride=1, padding=2, bias=True) self.bias2 = nn.Parameter(torch.zeros(1)) self.fp_bn = nn.BatchNorm2d(self.nf2) - self.final_defilter = nn.Conv2d(self.nf2, 3, kernel_size=3, stride=1, padding=1, bias=True) + self.final_defilter = nn.Conv2d( + self.nf2, 3, kernel_size=3, stride=1, padding=1, bias=True) self.bias3 = nn.Parameter(torch.zeros(1)) def filter_to_image(self, filter): @@ -198,12 +211,14 @@ class FixupResNetV2(FixupResNet): return x, skip_med, skip_lo + def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs): """Constructs a Fixup-ResNet-34 model. """ model = FixupResNet(FixupBasicBlock, [nb_denoiser, nb_upsampler], **kwargs) return model + def fixup_resnet34_v2(nb_denoiser=20, nb_upsampler=10, **kwargs): """Constructs a Fixup-ResNet-34 model. """ @@ -213,4 +228,4 @@ def fixup_resnet34_v2(nb_denoiser=20, nb_upsampler=10, **kwargs): return model -__all__ = ['FixupResNet', 'fixup_resnet34', 'fixup_resnet34_v2'] \ No newline at end of file +__all__ = ['FixupResNet', 'fixup_resnet34', 'fixup_resnet34_v2'] diff --git a/dlas/models/image_generation/discriminator_vgg_arch.py b/dlas/models/image_generation/discriminator_vgg_arch.py index af44eca2..6c57557d 100644 --- a/dlas/models/image_generation/discriminator_vgg_arch.py +++ b/dlas/models/image_generation/discriminator_vgg_arch.py @@ -1,11 +1,10 @@ import torch import torch.nn as nn - import torch.nn.functional as F -from trainer.networks import register_model -from utils.util import checkpoint, opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, opt_get class Discriminator_VGG_128(nn.Module): @@ -47,7 +46,8 @@ class Discriminator_VGG_128(nn.Module): input_img_factor = input_img_factor // 2 final_nf = nf * 16 - self.linear1 = ml.Linear(final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) + self.linear1 = ml.Linear( + final_nf * 4 * input_img_factor * 4 * input_img_factor, 100) self.linear2 = ml.Linear(100, 1) # activation function @@ -57,11 +57,11 @@ class Discriminator_VGG_128(nn.Module): fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) - #fea = torch.cat([fea, skip_med], dim=1) + # fea = torch.cat([fea, skip_med], dim=1) fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) - #fea = torch.cat([fea, skip_lo], dim=1) + # fea = torch.cat([fea, skip_lo], dim=1) fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) @@ -130,18 +130,19 @@ class Discriminator_VGG_128_GN(nn.Module): # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.linear1 = ml.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) + self.linear1 = ml.Linear( + int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 100) self.linear2 = ml.Linear(100, 1) def compute_body(self, x): fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) - #fea = torch.cat([fea, skip_med], dim=1) + # fea = torch.cat([fea, skip_med], dim=1) fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) - #fea = torch.cat([fea, skip_lo], dim=1) + # fea = torch.cat([fea, skip_lo], dim=1) fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) fea = 
self.lrelu(self.bn2_1(self.conv2_1(fea))) @@ -171,7 +172,8 @@ class Discriminator_VGG_128_GN(nn.Module): def register_discriminator_vgg_128_gn(opt_net, opt): return Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=opt_net['image_size'] / 128, - extra_conv=opt_get(opt_net, ['extra_conv'], False), + extra_conv=opt_get( + opt_net, ['extra_conv'], False), do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False)) diff --git a/dlas/models/image_generation/glean/glean.py b/dlas/models/image_generation/glean/glean.py index 89ed0b16..d677306c 100644 --- a/dlas/models/image_generation/glean/glean.py +++ b/dlas/models/image_generation/glean/glean.py @@ -1,27 +1,30 @@ import math -import torch.nn as nn import torch +import torch.nn as nn -from models.image_generation.RRDBNet_arch import RRDB -from models.arch_util import ConvGnLelu - - +from dlas.models.arch_util import ConvGnLelu # Produces a convolutional feature (`f`) and a reduced feature map with double the filters. -from models.image_generation.glean.stylegan2_latent_bank import Stylegan2LatentBank -from models.image_generation.stylegan.stylegan2_rosinality import EqualLinear -from trainer.networks import register_model -from utils.util import checkpoint, sequential_checkpoint +from dlas.models.image_generation.glean.stylegan2_latent_bank import \ + Stylegan2LatentBank +from dlas.models.image_generation.RRDBNet_arch import RRDB +from dlas.models.image_generation.stylegan.stylegan2_rosinality import \ + EqualLinear +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, sequential_checkpoint class GleanEncoderBlock(nn.Module): def __init__(self, nf, max_nf): super().__init__() - self.structural_latent_conv = ConvGnLelu(nf, nf, kernel_size=1, activation=False, norm=False, bias=True) + self.structural_latent_conv = ConvGnLelu( + nf, nf, kernel_size=1, activation=False, norm=False, bias=True) top_nf = min(nf*2, max_nf) self.process = nn.Sequential( - ConvGnLelu(nf, top_nf, kernel_size=3, stride=2, activation=True, norm=False, bias=False), - ConvGnLelu(top_nf, top_nf, kernel_size=3, activation=True, norm=False, bias=False) + ConvGnLelu(nf, top_nf, kernel_size=3, stride=2, + activation=True, norm=False, bias=False), + ConvGnLelu(top_nf, top_nf, kernel_size=3, + activation=True, norm=False, bias=False) ) def forward(self, x): @@ -36,13 +39,16 @@ class GleanEncoderBlock(nn.Module): class GleanEncoder(nn.Module): def __init__(self, nf, nb, max_nf=512, reductions=4, latent_bank_blocks=7, latent_bank_latent_dim=512, input_dim=32, initial_stride=1): super().__init__() - self.initial_conv = ConvGnLelu(3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride) + self.initial_conv = ConvGnLelu( + 3, nf, kernel_size=7, activation=False, norm=False, bias=True, stride=initial_stride) self.rrdb_blocks = nn.Sequential(*[RRDB(nf) for _ in range(nb)]) - self.reducers = nn.ModuleList([GleanEncoderBlock(min(nf * 2 ** i, max_nf), max_nf) for i in range(reductions)]) + self.reducers = nn.ModuleList([GleanEncoderBlock( + min(nf * 2 ** i, max_nf), max_nf) for i in range(reductions)]) reducer_output_dim = (input_dim // (2 ** (reductions + 1))) ** 2 reducer_output_nf = min(nf * 2 ** reductions, max_nf) - self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, stride=2, kernel_size=3, activation=True, norm=False, bias=True) + self.latent_conv = ConvGnLelu(reducer_output_nf, reducer_output_nf, + stride=2, kernel_size=3, activation=True, norm=False, bias=True) 
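# A quick shape check of the GleanEncoder latent head, using the formulas from __init__
# above with the defaults visible in the signature (nf=64 is an example value):
nf, max_nf, reductions = 64, 512, 4
input_dim, latent_bank_blocks, latent_bank_latent_dim = 32, 7, 512

reducer_output_dim = (input_dim // (2 ** (reductions + 1))) ** 2   # 32 // 32 = 1, squared -> 1
reducer_output_nf = min(nf * 2 ** reductions, max_nf)              # min(1024, 512) -> 512
flat_in = reducer_output_dim * reducer_output_nf                   # 512 features into EqualLinear
flat_out = latent_bank_latent_dim * latent_bank_blocks             # 3584, later viewed as (B, 7, 512)
print(flat_in, flat_out)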
self.latent_linear = EqualLinear(reducer_output_dim * reducer_output_nf, latent_bank_latent_dim * latent_bank_blocks, activation="fused_lrelu") @@ -50,7 +56,8 @@ class GleanEncoder(nn.Module): def forward(self, x): fea = self.initial_conv(x) - fea = sequential_checkpoint(self.rrdb_blocks, len(self.rrdb_blocks), fea) + fea = sequential_checkpoint( + self.rrdb_blocks, len(self.rrdb_blocks), fea) rrdb_fea = fea convolutional_features = [] for reducer in self.reducers: @@ -58,7 +65,8 @@ class GleanEncoder(nn.Module): convolutional_features.append(f) latents = self.latent_conv(fea) - latents = self.latent_linear(latents.flatten(1, -1)).view(fea.shape[0], self.latent_bank_blocks, -1) + latents = self.latent_linear(latents.flatten( + 1, -1)).view(fea.shape[0], self.latent_bank_blocks, -1) return rrdb_fea, convolutional_features, latents @@ -68,7 +76,8 @@ class GleanDecoder(nn.Module): # To determine latent_bank_filters, use the `self.channels` map for the desired input dimensions from stylegan2_rosinality.py def __init__(self, nf, latent_bank_filters=[512, 256, 128]): super().__init__() - self.initial_conv = ConvGnLelu(nf, nf, kernel_size=3, activation=True, norm=False, bias=True, weight_init_factor=.1) + self.initial_conv = ConvGnLelu( + nf, nf, kernel_size=3, activation=True, norm=False, bias=True, weight_init_factor=.1) decoder_block_shuffled_dims = [nf] + latent_bank_filters self.decoder_blocks = nn.ModuleList([ConvGnLelu(decoder_block_shuffled_dims[i] + latent_bank_filters[i], @@ -78,13 +87,15 @@ class GleanDecoder(nn.Module): for i in range(len(latent_bank_filters))]) final_dim = latent_bank_filters[-1] - self.final_decode = ConvGnLelu(final_dim, 3, kernel_size=3, activation=False, bias=True, norm=False, weight_init_factor=.1) + self.final_decode = ConvGnLelu( + final_dim, 3, kernel_size=3, activation=False, bias=True, norm=False, weight_init_factor=.1) def forward(self, rrdb_fea, latent_bank_fea): fea = self.initial_conv(rrdb_fea) for i, block in enumerate(self.decoder_blocks): # The paper calls for PixelShuffle here, but I don't have good experience with that. It also doesn't align with the way the underlying StyleGAN works. 
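# The comment above swaps the paper's PixelShuffle for nearest-neighbour interpolation.
# Both double the spatial resolution: PixelShuffle folds channels into space,
# (B, C*r*r, H, W) -> (B, C, H*r, W*r), while interpolate keeps the channel count and
# repeats pixels. A shape-only illustration of the two options:
import torch
import torch.nn as nn
import torch.nn.functional as F

fea = torch.randn(1, 256, 16, 16)
via_shuffle = nn.PixelShuffle(2)(torch.randn(1, 256 * 4, 16, 16))   # -> (1, 256, 32, 32)
via_interp = F.interpolate(fea, scale_factor=2, mode="nearest")     # -> (1, 256, 32, 32)
print(via_shuffle.shape, via_interp.shape)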
- fea = nn.functional.interpolate(fea, scale_factor=2, mode="nearest") + fea = nn.functional.interpolate( + fea, scale_factor=2, mode="nearest") fea = torch.cat([fea, latent_bank_fea[i]], dim=1) fea = checkpoint(block, fea) return self.final_decode(fea) @@ -96,7 +107,8 @@ class GleanGenerator(nn.Module): super().__init__() self.input_dim = input_dim after_stride_dim = input_dim // initial_stride - latent_blocks = int(math.log(gen_output_dim, 2)) # From 4x4->gen_output_dim x gen_output_dim + initial styled conv + # From 4x4->gen_output_dim x gen_output_dim + initial styled conv + latent_blocks = int(math.log(gen_output_dim, 2)) encoder_reductions = int(math.log(after_stride_dim / 4, 2)) + 1 self.encoder = GleanEncoder(nf, encoder_rrdb_nb, reductions=encoder_reductions, latent_bank_blocks=latent_blocks, latent_bank_latent_dim=latent_bank_latent_dim, input_dim=after_stride_dim, initial_stride=initial_stride) diff --git a/dlas/models/image_generation/glean/stylegan2_latent_bank.py b/dlas/models/image_generation/glean/stylegan2_latent_bank.py index e51d48a7..7d496f1a 100644 --- a/dlas/models/image_generation/glean/stylegan2_latent_bank.py +++ b/dlas/models/image_generation/glean/stylegan2_latent_bank.py @@ -1,8 +1,9 @@ import torch import torch.nn as nn -from models.arch_util import ConvGnLelu -from models.image_generation.stylegan.stylegan2_rosinality import Generator +from dlas.models.arch_util import ConvGnLelu +from dlas.models.image_generation.stylegan.stylegan2_rosinality import \ + Generator class Stylegan2LatentBank(nn.Module): @@ -10,7 +11,9 @@ class Stylegan2LatentBank(nn.Module): super().__init__() # Initialize the bank. - self.bank = Generator(size=max_dim, style_dim=latent_dim, n_mlp=8, channel_multiplier=2) # Assumed using 'f' generators with mult=2. + # Assumed using 'f' generators with mult=2. + self.bank = Generator( + size=max_dim, style_dim=latent_dim, n_mlp=8, channel_multiplier=2) state_dict = torch.load(pretrained_model_file) self.bank.load_state_dict(state_dict, strict=True) @@ -23,8 +26,10 @@ class Stylegan2LatentBank(nn.Module): stylegan_encoder_dims = [512, 512, 512, 512, 512, 256, 128, 64, 32] # Initialize the fusion blocks. TODO: Try using the StyledConvs instead of regular ones. - encoder_output_dims = reversed([min(encoder_nf * 2 ** i, encoder_max_nf) for i in range(encoder_levels)]) - input_dims_by_layer = [eod + sed for eod, sed in zip(encoder_output_dims, stylegan_encoder_dims)] + encoder_output_dims = reversed( + [min(encoder_nf * 2 ** i, encoder_max_nf) for i in range(encoder_levels)]) + input_dims_by_layer = [ + eod + sed for eod, sed in zip(encoder_output_dims, stylegan_encoder_dims)] self.fusion_blocks = nn.ModuleList([ConvGnLelu(in_filters, out_filters, kernel_size=3, activation=True, norm=False, bias=True) for in_filters, out_filters in zip(input_dims_by_layer, stylegan_encoder_dims)]) @@ -42,7 +47,8 @@ class Stylegan2LatentBank(nn.Module): # - Later layers -> GLEAN terminates at 256 resolution. def forward(self, convolutional_features, latent_vectors): - out = self.bank.input(latent_vectors[:, 0]) # The input here is only used to fetch the batch size. + # The input here is only used to fetch the batch size. 
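# Worked numbers for the block-count formulas in GleanGenerator.__init__ above (the
# concrete values are examples, not project defaults):
import math

gen_output_dim, input_dim, initial_stride = 256, 32, 1
after_stride_dim = input_dim // initial_stride
latent_blocks = int(math.log(gen_output_dim, 2))                  # 8: 4x4 up to 256x256 plus the initial styled conv
encoder_reductions = int(math.log(after_stride_dim / 4, 2)) + 1   # 4 stride-2 encoder reductions
print(latent_blocks, encoder_reductions)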
+ out = self.bank.input(latent_vectors[:, 0]) out = self.bank.conv1(out, latent_vectors[:, 0], noise=None) k = 0 diff --git a/dlas/models/image_generation/srflow/FlowActNorms.py b/dlas/models/image_generation/srflow/FlowActNorms.py index 6ca7b498..c15ea738 100644 --- a/dlas/models/image_generation/srflow/FlowActNorms.py +++ b/dlas/models/image_generation/srflow/FlowActNorms.py @@ -1,7 +1,7 @@ import torch from torch import nn as nn -from models.image_generation.srflow import thops +from dlas.models.image_generation.srflow import thops class _ActNorm(nn.Module): @@ -33,10 +33,13 @@ class _ActNorm(nn.Module): if (self.bias != 0).any(): self.inited = True return - assert input.device == self.bias.device, (input.device, self.bias.device) + assert input.device == self.bias.device, ( + input.device, self.bias.device) with torch.no_grad(): - bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0 - vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) + bias = thops.mean(input.clone(), dim=[ + 0, 2, 3], keepdim=True) * -1.0 + vars = thops.mean((input.clone() + bias) ** 2, + dim=[0, 2, 3], keepdim=True) logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) self.bias.data.copy_(bias.data) self.logs.data.copy_(logs.data) @@ -60,7 +63,8 @@ class _ActNorm(nn.Module): logs = logs + offset if not reverse: - input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1 + # should have shape batchsize, n_channels, 1, 1 + input = input * torch.exp(logs) # input = input * torch.exp(logs+logs_offset) else: input = input * torch.exp(-logs) @@ -120,4 +124,3 @@ class MaskedActNorm2d(ActNorm2d): logdet[mask] = logdet_out[mask] return input, logdet - diff --git a/dlas/models/image_generation/srflow/FlowAffineCouplingsAblation.py b/dlas/models/image_generation/srflow/FlowAffineCouplingsAblation.py index f8d85d9c..0d106b6c 100644 --- a/dlas/models/image_generation/srflow/FlowAffineCouplingsAblation.py +++ b/dlas/models/image_generation/srflow/FlowAffineCouplingsAblation.py @@ -1,9 +1,9 @@ import torch from torch import nn as nn -from models.image_generation.srflow import thops -from models.image_generation.srflow.flow import Conv2d, Conv2dZeros -from utils.util import opt_get +from dlas.models.image_generation.srflow import thops +from dlas.models.image_generation.srflow.flow import Conv2d, Conv2dZeros +from dlas.utils.util import opt_get class CondAffineSeparatedAndCond(nn.Module): @@ -15,10 +15,12 @@ class CondAffineSeparatedAndCond(nn.Module): self.kernel_hidden = 1 self.affine_eps = 0.0001 self.n_hidden_layers = 1 - hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) + hidden_channels = opt_get( + opt, ['networks', 'generator', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) self.hidden_channels = 64 if hidden_channels is None else hidden_channels - self.affine_eps = opt_get(opt, ['networks', 'generator','flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001) + self.affine_eps = opt_get( + opt, ['networks', 'generator', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001) self.channels_for_nn = self.in_channels // 2 self.channels_for_co = self.in_channels - self.channels_for_nn @@ -41,7 +43,8 @@ class CondAffineSeparatedAndCond(nn.Module): def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None): if not reverse: z = input - assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels) + assert z.shape[1] == self.in_channels, ( + z.shape[1], self.in_channels) # 
Feature Conditional scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures) @@ -81,10 +84,14 @@ class CondAffineSeparatedAndCond(nn.Module): return output, logdet def asserts(self, scale, shift, z1, z2): - assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn) - assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co) - assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1]) - assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1]) + assert z1.shape[1] == self.channels_for_nn, ( + z1.shape[1], self.channels_for_nn) + assert z2.shape[1] == self.channels_for_co, ( + z2.shape[1], self.channels_for_co) + assert scale.shape[1] == shift.shape[1], ( + scale.shape[1], shift.shape[1]) + assert scale.shape[1] == z2.shape[1], ( + scale.shape[1], z1.shape[1], z2.shape[1]) def get_logdet(self, scale): return thops.sum(torch.log(scale), dim=[1, 2, 3]) @@ -105,14 +112,16 @@ class CondAffineSeparatedAndCond(nn.Module): def split(self, z): z1 = z[:, :self.channels_for_nn] z2 = z[:, self.channels_for_nn:] - assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1]) + assert z1.shape[1] + \ + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1]) return z1, z2 def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1): layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)] for _ in range(n_hidden_layers): - layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden])) + layers.append(Conv2d(hidden_channels, hidden_channels, + kernel_size=[kernel_hidden, kernel_hidden])) layers.append(nn.ReLU(inplace=False)) layers.append(Conv2dZeros(hidden_channels, out_channels)) diff --git a/dlas/models/image_generation/srflow/FlowStep.py b/dlas/models/image_generation/srflow/FlowStep.py index 7d3f0724..81f41052 100644 --- a/dlas/models/image_generation/srflow/FlowStep.py +++ b/dlas/models/image_generation/srflow/FlowStep.py @@ -1,13 +1,14 @@ import torch from torch import nn as nn -import models.image_generation.srflow.Permutations -import models.image_generation.srflow.FlowAffineCouplingsAblation -import models.image_generation.srflow.FlowActNorms +import dlas.models.image_generation.srflow.FlowActNorms +import dlas.models.image_generation.srflow.FlowAffineCouplingsAblation +import dlas.models.image_generation.srflow.Permutations def getConditional(rrdbResults, position): - img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position] + img_ft = rrdbResults if isinstance( + rrdbResults, torch.Tensor) else rrdbResults[position] return img_ft @@ -46,7 +47,8 @@ class FlowStep(nn.Module): self.acOpt = acOpt # 1. actnorm - self.actnorm = models.image_generation.srflow.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) + self.actnorm = models.image_generation.srflow.FlowActNorms.ActNorm2d( + in_channels, actnorm_scale) # 2. permute if flow_permutation == "invconv": @@ -75,7 +77,8 @@ class FlowStep(nn.Module): # 1. actnorm if self.norm_type == "ConditionalActNormImageInjector": img_ft = getConditional(rrdbResults, self.position) - z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False) + z, logdet = self.actnorm( + z, img_ft=img_ft, logdet=logdet, reverse=False) elif self.norm_type == "noNorm": pass else: @@ -90,7 +93,8 @@ class FlowStep(nn.Module): # 3. 
coupling if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]: img_ft = getConditional(rrdbResults, self.position) - z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft) + z, logdet = self.affine( + input=z, logdet=logdet, reverse=False, ft=img_ft) return z, logdet def reverse_flow(self, z, logdet, rrdbResults=None): @@ -100,7 +104,8 @@ class FlowStep(nn.Module): # 1.coupling if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]: img_ft = getConditional(rrdbResults, self.position) - z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft) + z, logdet = self.affine( + input=z, logdet=logdet, reverse=True, ft=img_ft) # 2. permute z, logdet = FlowStep.FlowPermutation[self.flow_permutation]( diff --git a/dlas/models/image_generation/srflow/FlowUpsamplerNet.py b/dlas/models/image_generation/srflow/FlowUpsamplerNet.py index 025ddd26..7fc818ef 100644 --- a/dlas/models/image_generation/srflow/FlowUpsamplerNet.py +++ b/dlas/models/image_generation/srflow/FlowUpsamplerNet.py @@ -2,12 +2,12 @@ import numpy as np import torch from torch import nn as nn -import models.image_generation.srflow.Split -from models.image_generation.srflow import flow -from models.image_generation.srflow.Split import Split2d -from models.image_generation.srflow.glow_arch import f_conv2d_bias -from models.image_generation.srflow.FlowStep import FlowStep -from utils.util import opt_get, checkpoint +import dlas.models.image_generation.srflow.Split +from dlas.models.image_generation.srflow import flow +from dlas.models.image_generation.srflow.FlowStep import FlowStep +from dlas.models.image_generation.srflow.glow_arch import f_conv2d_bias +from dlas.models.image_generation.srflow.Split import Split2d +from dlas.utils.util import checkpoint, opt_get class FlowUpsamplerNet(nn.Module): @@ -21,9 +21,10 @@ class FlowUpsamplerNet(nn.Module): self.layers = nn.ModuleList() self.output_shapes = [] - self.L = opt_get(opt, ['networks', 'generator','flow', 'L']) - self.K = opt_get(opt, ['networks', 'generator','flow', 'K']) - self.patch_sz = opt_get(opt, ['networks', 'generator', 'flow', 'patch_size'], 160) + self.L = opt_get(opt, ['networks', 'generator', 'flow', 'L']) + self.K = opt_get(opt, ['networks', 'generator', 'flow', 'K']) + self.patch_sz = opt_get( + opt, ['networks', 'generator', 'flow', 'patch_size'], 160) if isinstance(self.K, int): self.K = [K for K in [K, ] * (self.L + 1)] @@ -61,18 +62,20 @@ class FlowUpsamplerNet(nn.Module): affineInCh = self.get_affineInCh(opt_get) flow_permutation = self.get_flow_permutation(flow_permutation, opt) - normOpt = opt_get(opt, ['networks', 'generator','flow', 'norm']) + normOpt = opt_get(opt, ['networks', 'generator', 'flow', 'norm']) conditional_channels = {} n_rrdb = self.get_n_rrdb_channels(opt, opt_get) - n_bypass_channels = opt_get(opt, ['networks', 'generator','flow', 'levelConditional', 'n_channels']) + n_bypass_channels = opt_get( + opt, ['networks', 'generator', 'flow', 'levelConditional', 'n_channels']) conditional_channels[0] = n_rrdb for level in range(1, self.L + 1): # Level 1 gets conditionals from 2, 3, 4 => L - level # Level 2 gets conditionals from 3, 4 # Level 3 gets conditionals from 4 # Level 4 gets conditionals from None - n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels + n_bypass = 0 if n_bypass_channels is None else ( + self.L - level) * n_bypass_channels conditional_channels[level] = n_rrdb + n_bypass # Upsampler @@ -81,7 
+84,8 @@ class FlowUpsamplerNet(nn.Module): H, W = self.arch_squeeze(H, W) # 2. K FlowStep - self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt) + self.arch_additionalFlowAffine( + H, LU_decomposed, W, actnorm_scale, hidden_channels, opt) self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation, hidden_channels, normOpt, opt, opt_get, @@ -89,7 +93,7 @@ class FlowUpsamplerNet(nn.Module): # Split self.arch_split(H, W, level, self.L, opt, opt_get) - if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']): + if opt_get(opt, ['networks', 'generator', 'flow', 'split', 'enable']): self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2) else: self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64) @@ -100,7 +104,8 @@ class FlowUpsamplerNet(nn.Module): self.scaleW = self.patch_sz / W def get_n_rrdb_channels(self, opt, opt_get): - blocks = opt_get(opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) + blocks = opt_get(opt, ['networks', 'generator', + 'flow', 'stackRRDB', 'blocks']) n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64 return n_rrdb @@ -111,8 +116,10 @@ class FlowUpsamplerNet(nn.Module): condAff['in_channels_rrdb'] = n_conditinal_channels for k in range(K): - position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale'])) - if normOpt: normOpt['position'] = position_name + position_name = self.get_position_name(H, opt_get( + self.opt, ['networks', 'generator', 'flow_scale'])) + if normOpt: + normOpt['position'] = position_name self.layers.append( FlowStep(in_channels=self.C, @@ -127,22 +134,31 @@ class FlowUpsamplerNet(nn.Module): [-1, self.C, H, W]) def get_condAffSetting(self, opt, opt_get): - condAff = opt_get(opt, ['networks', 'generator','flow', 'condAff']) or None - condAff = opt_get(opt, ['networks', 'generator','flow', 'condFtAffine']) or condAff + condAff = opt_get( + opt, ['networks', 'generator', 'flow', 'condAff']) or None + condAff = opt_get(opt, ['networks', 'generator', + 'flow', 'condFtAffine']) or condAff return condAff def arch_split(self, H, W, L, levels, opt, opt_get): - correct_splits = opt_get(opt, ['networks', 'generator','flow', 'split', 'correct_splits'], False) + correct_splits = opt_get( + opt, ['networks', 'generator', 'flow', 'split', 'correct_splits'], False) correction = 0 if correct_splits else 1 - if opt_get(opt, ['networks', 'generator','flow', 'split', 'enable']) and L < levels - correction: - logs_eps = opt_get(opt, ['networks', 'generator','flow', 'split', 'logs_eps']) or 0 - consume_ratio = opt_get(opt, ['networks', 'generator','flow', 'split', 'consume_ratio']) or 0.5 - position_name = self.get_position_name(H, opt_get(self.opt, ['networks', 'generator', 'flow_scale'])) - position = position_name if opt_get(opt, ['networks', 'generator','flow', 'split', 'conditional']) else None - cond_channels = opt_get(opt, ['networks', 'generator','flow', 'split', 'cond_channels']) + if opt_get(opt, ['networks', 'generator', 'flow', 'split', 'enable']) and L < levels - correction: + logs_eps = opt_get( + opt, ['networks', 'generator', 'flow', 'split', 'logs_eps']) or 0 + consume_ratio = opt_get( + opt, ['networks', 'generator', 'flow', 'split', 'consume_ratio']) or 0.5 + position_name = self.get_position_name(H, opt_get( + self.opt, ['networks', 'generator', 'flow_scale'])) + position = position_name if opt_get( + opt, ['networks', 'generator', 'flow', 'split', 'conditional']) else None + cond_channels = 
opt_get( + opt, ['networks', 'generator', 'flow', 'split', 'cond_channels']) cond_channels = 0 if cond_channels is None else cond_channels - t = opt_get(opt, ['networks', 'generator','flow', 'split', 'type'], 'Split2d') + t = opt_get(opt, ['networks', 'generator', + 'flow', 'split', 'type'], 'Split2d') if t == 'Split2d': split = models.image_generation.srflow.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position, @@ -153,7 +169,8 @@ class FlowUpsamplerNet(nn.Module): def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt): if 'additionalFlowNoAffine' in opt['networks']['generator']['flow']: - n_additionalFlowNoAffine = int(opt['networks']['generator']['flow']['additionalFlowNoAffine']) + n_additionalFlowNoAffine = int( + opt['networks']['generator']['flow']['additionalFlowNoAffine']) for _ in range(n_additionalFlowNoAffine): self.layers.append( FlowStep(in_channels=self.C, @@ -172,11 +189,13 @@ class FlowUpsamplerNet(nn.Module): return H, W def get_flow_permutation(self, flow_permutation, opt): - flow_permutation = opt['networks']['generator']['flow'].get('flow_permutation', 'invconv') + flow_permutation = opt['networks']['generator']['flow'].get( + 'flow_permutation', 'invconv') return flow_permutation def get_affineInCh(self, opt_get): - affineInCh = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or [] + affineInCh = opt_get( + self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or [] affineInCh = (len(affineInCh) + 1) * 64 return affineInCh @@ -188,14 +207,17 @@ class FlowUpsamplerNet(nn.Module): y_onehot=None): if reverse: - epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses + epses_copy = [eps for eps in epses] if isinstance( + epses, list) else epses - sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot) + sr, logdet = self.decode( + rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot) return sr, logdet else: assert gt is not None assert rrdbResults is not None - z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot) + z, logdet = self.encode( + gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot) return z, logdet def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None): @@ -204,10 +226,11 @@ class FlowUpsamplerNet(nn.Module): level_conditionals = {} bypasses = {} - L = opt_get(self.opt, ['networks', 'generator','flow', 'L']) + L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) for level in range(1, L + 1): - bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False) + bypasses[level] = torch.nn.functional.interpolate( + gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False) for layer, shape in zip(self.layers, self.output_shapes): size = shape[2] @@ -219,7 +242,8 @@ class FlowUpsamplerNet(nn.Module): level_conditionals[level] = rrdbResults[self.levelToName[level]] if isinstance(layer, FlowStep): - fl_fea, logdet = checkpoint(layer, fl_fea, logdet, level_conditionals[level]) + fl_fea, logdet = checkpoint( + layer, fl_fea, logdet, level_conditionals[level]) elif isinstance(layer, Split2d): fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level], y_onehot=y_onehot) @@ -242,7 +266,8 @@ class FlowUpsamplerNet(nn.Module): def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, 
y_onehot=None): ft = None if layer.position is None else rrdbResults[layer.position] - fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot) + fl_fea, logdet, eps = layer( + fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot) epses.append(eps) return fl_fea, logdet @@ -253,7 +278,7 @@ class FlowUpsamplerNet(nn.Module): # debug.imwrite("fl_fea", fl_fea) bypasses = {} level_conditionals = {} - if not opt_get(self.opt, ['networks', 'generator','flow', 'levelConditional', 'conditional']) == True: + if not opt_get(self.opt, ['networks', 'generator', 'flow', 'levelConditional', 'conditional']) == True: for level in range(self.L + 1): level_conditionals[level] = rrdbResults[self.levelToName[level]] @@ -265,10 +290,12 @@ class FlowUpsamplerNet(nn.Module): if isinstance(layer, Split2d): fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer, - rrdbResults[self.levelToName[level]], logdet=logdet, + rrdbResults[self.levelToName[level] + ], logdet=logdet, y_onehot=y_onehot) elif isinstance(layer, FlowStep): - fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level]) + fl_fea, logdet = layer( + fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level]) else: fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True) @@ -284,7 +311,6 @@ class FlowUpsamplerNet(nn.Module): eps_std=eps_std, ft=ft, y_onehot=y_onehot) return fl_fea, logdet - def get_position_name(self, H, scale): downscale_factor = self.patch_sz // H position_name = 'fea_up{}'.format(scale / downscale_factor) diff --git a/dlas/models/image_generation/srflow/Permutations.py b/dlas/models/image_generation/srflow/Permutations.py index 122ab3a9..aa3c64d1 100644 --- a/dlas/models/image_generation/srflow/Permutations.py +++ b/dlas/models/image_generation/srflow/Permutations.py @@ -3,7 +3,7 @@ import torch from torch import nn as nn from torch.nn import functional as F -from models.image_generation.srflow import thops +from dlas.models.image_generation.srflow import thops class InvertibleConv1x1(nn.Module): @@ -25,6 +25,7 @@ class InvertibleConv1x1(nn.Module): weight = torch.inverse(self.weight.double()).float() \ .view(w_shape[0], w_shape[1], 1, 1) return weight, dlogdet + def forward(self, input, logdet=None, reverse=False): """ log-det = log|abs(|W|)| * pixels diff --git a/dlas/models/image_generation/srflow/RRDBNet_arch.py b/dlas/models/image_generation/srflow/RRDBNet_arch.py index d34d7160..c36874be 100644 --- a/dlas/models/image_generation/srflow/RRDBNet_arch.py +++ b/dlas/models/image_generation/srflow/RRDBNet_arch.py @@ -1,11 +1,13 @@ import functools + import torch import torch.nn as nn import torch.nn.functional as F -import models.image_generation.srflow.module_util as mutil -from models.arch_util import default_init_weights, ConvGnSilu, ConvGnLelu -from trainer.networks import register_model -from utils.util import opt_get + +import dlas.models.image_generation.srflow.module_util as mutil +from dlas.models.arch_util import ConvGnLelu, ConvGnSilu, default_init_weights +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get class ResidualDenseBlock(nn.Module): @@ -30,7 +32,6 @@ class ResidualDenseBlock(nn.Module): for i in range(5): default_init_weights(getattr(self, f'conv{i+1}'), 0.1) - def forward(self, x): """Forward function. 
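# The InvertibleConv1x1 hunk above (srflow/Permutations.py) notes in its docstring that
# the flow's log-determinant is log|det W| times the number of pixels. A generic
# Glow-style sketch of that accounting (illustration only, not the project's module):
import torch
import torch.nn.functional as F

def invertible_1x1(x, weight, logdet, reverse=False):
    b, c, h, w = x.shape
    dlogdet = torch.slogdet(weight)[1] * h * w            # log|det W| * pixels
    if not reverse:
        z = F.conv2d(x, weight.view(c, c, 1, 1))
        return z, logdet + dlogdet
    z = F.conv2d(x, torch.inverse(weight).view(c, c, 1, 1))
    return z, logdet - dlogdet

x = torch.randn(2, 4, 8, 8)
w = torch.linalg.qr(torch.randn(4, 4))[0]                 # random rotation init, as in Glow
z, logdet = invertible_1x1(x, w, torch.zeros(2))
x_rec, _ = invertible_1x1(z, w, logdet, reverse=True)
print(torch.allclose(x, x_rec, atol=1e-5))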
@@ -97,8 +98,10 @@ class RRDBWithBypass(nn.Module): self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) self.bypass = nn.Sequential(ConvGnSilu(mid_channels*2, mid_channels, kernel_size=3, bias=True, activation=True, norm=True), - ConvGnSilu(mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False), - ConvGnSilu(mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), + ConvGnSilu( + mid_channels, mid_channels//2, kernel_size=3, bias=False, activation=True, norm=False), + ConvGnSilu( + mid_channels//2, 1, kernel_size=3, bias=False, activation=False, norm=False), nn.Sigmoid()) def forward(self, x): @@ -126,18 +129,21 @@ class RRDBNet(nn.Module): bypass = opt_get(self.opt, ['networks', 'generator', 'rrdb_bypass']) if bypass: - RRDB_block_f = functools.partial(RRDBWithBypass, mid_channels=nf, growth_channels=gc) + RRDB_block_f = functools.partial( + RRDBWithBypass, mid_channels=nf, growth_channels=gc) else: - RRDB_block_f = functools.partial(RRDB, mid_channels=nf, growth_channels=gc) + RRDB_block_f = functools.partial( + RRDB, mid_channels=nf, growth_channels=gc) self.scale = scale if initial_conv_stride == 1: self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) else: - self.conv_first = nn.Conv2d(in_nc, nf, 7, stride=initial_conv_stride, padding=3, bias=True) + self.conv_first = nn.Conv2d( + in_nc, nf, 7, stride=initial_conv_stride, padding=3, bias=True) self.body = mutil.make_layer(RRDB_block_f, nb) self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling + # upsampling self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if self.scale >= 2: self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) @@ -156,7 +162,8 @@ class RRDBNet(nn.Module): def forward(self, x, get_steps=False): fea = self.conv_first(x) - block_idxs = opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'blocks']) or [] + block_idxs = opt_get( + self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or [] block_results = {} for idx, m in enumerate(self.body.children()): @@ -169,7 +176,8 @@ class RRDBNet(nn.Module): last_lr_fea = fea + trunk - fea_up2 = self.conv_up1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) + fea_up2 = self.conv_up1(F.interpolate( + last_lr_fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up2) fea_up4 = None @@ -178,16 +186,20 @@ class RRDBNet(nn.Module): fea_up32 = None if self.scale >= 4: - fea_up4 = self.conv_up2(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea_up4 = self.conv_up2(F.interpolate( + fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up4) if self.scale >= 8: - fea_up8 = self.conv_up3(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea_up8 = self.conv_up3(F.interpolate( + fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up8) if self.scale >= 16: - fea_up16 = self.conv_up4(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea_up16 = self.conv_up4(F.interpolate( + fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up16) if self.scale >= 32: - fea_up32 = self.conv_up5(F.interpolate(fea, scale_factor=2, mode='nearest')) + fea_up32 = self.conv_up5(F.interpolate( + fea, scale_factor=2, mode='nearest')) fea = self.lrelu(fea_up32) out = self.conv_last(self.lrelu(self.conv_hr(fea))) @@ -202,12 +214,16 @@ class RRDBNet(nn.Module): 'fea_up32': fea_up32, 'out': out} - fea_up0_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up0']) or False + fea_up0_en = 
opt_get( + self.opt, ['networks', 'generator', 'flow', 'fea_up0']) or False if fea_up0_en: - results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) - fea_upn1_en = opt_get(self.opt, ['networks', 'generator','flow', 'fea_up-1']) or False + results['fea_up0'] = F.interpolate( + last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) + fea_upn1_en = opt_get( + self.opt, ['networks', 'generator', 'flow', 'fea_up-1']) or False if fea_upn1_en: - results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) + results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, + mode='bilinear', align_corners=False, recompute_scale_factor=True) else: raise NotImplementedError @@ -224,19 +240,22 @@ class RRDBLatentWrapper(nn.Module): super().__init__() self.with_bypass = with_bypass self.blocks = blocks - fake_opt = { 'networks': {'generator': {'flow': {'stackRRDB': {'blocks': blocks}}, 'rrdb_bypass': with_bypass}}} + fake_opt = {'networks': {'generator': { + 'flow': {'stackRRDB': {'blocks': blocks}}, 'rrdb_bypass': with_bypass}}} self.wrappedRRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, fake_opt) if pretrain_rrdb_path is not None: rrdb_state_dict = torch.load(pretrain_rrdb_path) self.wrappedRRDB.load_state_dict(rrdb_state_dict, strict=True) out_dim = nf * (len(blocks) + 1) self.postprocess = nn.Sequential(ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True), - ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True), + ConvGnLelu( + out_dim, out_dim, kernel_size=1, bias=True, activation=True, norm=True), ConvGnLelu(out_dim, out_dim, kernel_size=1, bias=True, activation=False, norm=False)) def forward(self, lr): rrdbResults = self.wrappedRRDB(lr, get_steps=True) - blocklist = [rrdbResults["block_{}".format(idx)] for idx in self.blocks] + blocklist = [ + rrdbResults["block_{}".format(idx)] for idx in self.blocks] blocklist.append(rrdbResults['last_lr_fea']) fea = torch.cat(blocklist, dim=1) fea = self.postprocess(fea) @@ -254,5 +273,5 @@ def register_rrdb_latent_wrapper(opt_net, opt): @register_model def register_rrdb_srflow(opt_net, opt): return RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], - nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], - initial_conv_stride=opt_net['initial_stride']) \ No newline at end of file + nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], + initial_conv_stride=opt_net['initial_stride']) diff --git a/dlas/models/image_generation/srflow/SRFlowNet_arch.py b/dlas/models/image_generation/srflow/SRFlowNet_arch.py index 82cc5432..44b57df0 100644 --- a/dlas/models/image_generation/srflow/SRFlowNet_arch.py +++ b/dlas/models/image_generation/srflow/SRFlowNet_arch.py @@ -1,15 +1,17 @@ import math +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -from models.image_generation.srflow.RRDBNet_arch import RRDBNet -from models.image_generation.srflow.FlowUpsamplerNet import FlowUpsamplerNet -import models.image_generation.srflow.thops as thops -import models.image_generation.srflow.flow as flow -from trainer.networks import register_model -from utils.util import opt_get + +import dlas.models.image_generation.srflow.flow as flow +import dlas.models.image_generation.srflow.thops as thops +from dlas.models.image_generation.srflow.FlowUpsamplerNet import \ 
+ FlowUpsamplerNet +from dlas.models.image_generation.srflow.RRDBNet_arch import RRDBNet +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get class SRFlowNet(nn.Module): @@ -18,19 +20,26 @@ class SRFlowNet(nn.Module): self.opt = opt self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ - None else opt_get(opt, ['datasets', 'train', 'quant']) - initial_stride = opt_get(opt, ['networks', 'generator', 'initial_stride'], 1) - self.RRDB = RRDBNet(in_nc, out_nc, nf=nf, nb=nb, gc=gc, scale=scale, opt=opt, initial_conv_stride=initial_stride) + None else opt_get(opt, ['datasets', 'train', 'quant']) + initial_stride = opt_get( + opt, ['networks', 'generator', 'initial_stride'], 1) + self.RRDB = RRDBNet(in_nc, out_nc, nf=nf, nb=nb, gc=gc, + scale=scale, opt=opt, initial_conv_stride=initial_stride) if 'pretrain_rrdb' in opt['networks']['generator'].keys(): - rrdb_state_dict = torch.load(opt['networks']['generator']['pretrain_rrdb']) + rrdb_state_dict = torch.load( + opt['networks']['generator']['pretrain_rrdb']) self.RRDB.load_state_dict(rrdb_state_dict, strict=True) - hidden_channels = opt_get(opt, ['networks', 'generator','flow', 'hidden_channels']) + hidden_channels = opt_get( + opt, ['networks', 'generator', 'flow', 'hidden_channels']) hidden_channels = hidden_channels or 64 - self.RRDB_training = opt_get(self.opt, ['networks', 'generator','train_RRDB'], default=False) - self.flow_scale = opt_get(self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) + self.RRDB_training = opt_get( + self.opt, ['networks', 'generator', 'train_RRDB'], default=False) + self.flow_scale = opt_get( + self.opt, ['networks', 'generator', 'flow_scale'], default=opt['scale']) - self.patch_sz = opt_get(self.opt, ['networks', 'generator', 'flow', 'patch_size'], 160) + self.patch_sz = opt_get( + self.opt, ['networks', 'generator', 'flow', 'patch_size'], 160) self.flowUpsamplerNet = \ FlowUpsamplerNet((self.patch_sz, self.patch_sz, 3), hidden_channels, K, flow_coupling=opt['networks']['generator']['flow']['coupling'], opt=opt) @@ -39,11 +48,14 @@ class SRFlowNet(nn.Module): self.dbg_logdet = 0 def get_random_z(self, heat, seed=None, batch_size=1, lr_shape=None, device='cuda'): - if seed: torch.manual_seed(seed) + if seed: + torch.manual_seed(seed) if opt_get(self.opt, ['networks', 'generator', 'flow', 'split', 'enable']): C = self.flowUpsamplerNet.C - H = int(self.flow_scale * lr_shape[0] // (self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale)) - W = int(self.flow_scale * lr_shape[1] // (self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale)) + H = int(self.flow_scale * lr_shape[0] // ( + self.flowUpsamplerNet.scaleH * self.flow_scale / self.RRDB.scale)) + W = int(self.flow_scale * lr_shape[1] // ( + self.flowUpsamplerNet.scaleW * self.flow_scale / self.RRDB.scale)) size = (batch_size, C, H, W) if heat == 0: @@ -54,7 +66,8 @@ class SRFlowNet(nn.Module): L = opt_get(self.opt, ['networks', 'generator', 'flow', 'L']) or 3 fac = 2 ** (L - 3) z_size = int(self.lr_size // (2 ** (L - 3))) - z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) + z = torch.normal(mean=0, std=heat, size=( + batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) return z.to(device) def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, @@ -67,8 +80,10 @@ class SRFlowNet(nn.Module): assert lr.shape[1] == 3 if z is None: # Synthesize it. 
Accommodate mismatches in LR scale and flow_scale, which are normally handled by the RRDB subnet. - lr_shape = [d * self.opt['scale'] / self.flow_scale for d in lr.shape[2:]] - z = self.get_random_z(eps_std, batch_size=lr.shape[0], lr_shape=lr_shape, device=lr.device) + lr_shape = [d * self.opt['scale'] / + self.flow_scale for d in lr.shape[2:]] + z = self.get_random_z( + eps_std, batch_size=lr.shape[0], lr_shape=lr_shape, device=lr.device) if reverse_with_grad: return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise) @@ -92,7 +107,8 @@ class SRFlowNet(nn.Module): if add_gt_noise: # Setup - noiseQuant = opt_get(self.opt, ['networks', 'generator','flow', 'augmentation', 'noiseQuant'], True) + noiseQuant = opt_get( + self.opt, ['networks', 'generator', 'flow', 'augmentation', 'noiseQuant'], True) if noiseQuant: z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant) logdet = logdet + float(-np.log(self.quant) * pixels) @@ -124,11 +140,13 @@ class SRFlowNet(nn.Module): def rrdbPreprocessing(self, lr): rrdbResults = self.RRDB(lr, get_steps=True) - block_idxs = opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or [] + block_idxs = opt_get( + self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'blocks']) or [] if len(block_idxs) > 0: - concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1) + concat = torch.cat([rrdbResults["block_{}".format(idx)] + for idx in block_idxs], dim=1) - if opt_get(self.opt, ['networks', 'generator','flow', 'stackRRDB', 'concat']) or False: + if opt_get(self.opt, ['networks', 'generator', 'flow', 'stackRRDB', 'concat']) or False: keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] if 'fea_up0' in rrdbResults.keys(): keys.append('fea_up0') @@ -141,12 +159,13 @@ class SRFlowNet(nn.Module): for k in keys: h = rrdbResults[k].shape[2] w = rrdbResults[k].shape[3] - rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1) + rrdbResults[k] = torch.cat( + [rrdbResults[k], F.interpolate(concat, (h, w))], dim=1) return rrdbResults def get_score(self, disc_loss_sigma, z): score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \ - z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma) + z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma) return -score_real def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True): @@ -172,4 +191,4 @@ class SRFlowNet(nn.Module): @register_model def register_srflow(opt_net, opt): return SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'], - K=opt_net['K'], opt=opt) + K=opt_net['K'], opt=opt) diff --git a/dlas/models/image_generation/srflow/Split.py b/dlas/models/image_generation/srflow/Split.py index 304c0e6c..61124db2 100644 --- a/dlas/models/image_generation/srflow/Split.py +++ b/dlas/models/image_generation/srflow/Split.py @@ -1,10 +1,9 @@ import torch +from dlas.models.image_generation.srflow import thops +from dlas.models.image_generation.srflow.flow import Conv2dZeros, GaussianDiag +from dlas.utils.util import opt_get from torch import nn as nn -from models.image_generation.srflow import thops -from models.image_generation.srflow.flow import Conv2dZeros, GaussianDiag -from utils.util import opt_get - class Split2d(nn.Module): def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None): @@ -17,7 +16,8 @@ 
class Split2d(nn.Module): out_channels=self.num_channels_consume * 2) self.logs_eps = logs_eps self.position = position - self.gaussian_nll_weight = opt_get(opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1) + self.gaussian_nll_weight = opt_get( + opt, ['networks', 'generator', 'flow', 'gaussian_loss_weight'], 1) def split2d_prior(self, z, ft): if ft is not None: @@ -33,7 +33,7 @@ class Split2d(nn.Module): # self.input = input z1, z2 = self.split_ratio(input) mean, logs = self.split2d_prior(z1, ft) - + eps = (z2 - mean) / self.exp_eps(logs) logdet = logdet + self.get_logdet(logs, mean, z2) @@ -47,9 +47,9 @@ class Split2d(nn.Module): mean, logs = self.split2d_prior(z1, ft) if eps is None: - #print("WARNING: eps is None, generating eps untested functionality!") + # print("WARNING: eps is None, generating eps untested functionality!") eps = GaussianDiag.sample(mean, logs, eps_std) - #eps = GaussianDiag.sample_eps(mean.shape, eps_std) + # eps = GaussianDiag.sample_eps(mean.shape, eps_std) eps = eps.to(mean.device) z2 = mean + self.exp_eps(logs) * eps @@ -65,5 +65,6 @@ class Split2d(nn.Module): return logdet_diff * self.gaussian_nll_weight def split_ratio(self, input): - z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...] - return z1, z2 \ No newline at end of file + z1, z2 = input[:, :self.num_channels_pass, ...], input[:, + self.num_channels_pass:, ...] + return z1, z2 diff --git a/dlas/models/image_generation/srflow/flow.py b/dlas/models/image_generation/srflow/flow.py index db9ad7c5..9a548e42 100644 --- a/dlas/models/image_generation/srflow/flow.py +++ b/dlas/models/image_generation/srflow/flow.py @@ -1,8 +1,9 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np -from models.image_generation.srflow.FlowActNorms import ActNorm2d +from dlas.models.image_generation.srflow.FlowActNorms import ActNorm2d + from . 
import thops @@ -56,7 +57,8 @@ class Conv2dZeros(nn.Conv2d): super().__init__(in_channels, out_channels, kernel_size, stride, padding) # logscale_factor self.logscale_factor = logscale_factor - self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1))) + self.register_parameter("logs", nn.Parameter( + torch.zeros(out_channels, 1, 1))) # init self.weight.data.zero_() self.bias.data.zero_() diff --git a/dlas/models/image_generation/srflow/module_util.py b/dlas/models/image_generation/srflow/module_util.py index f50198dd..d6d1333b 100644 --- a/dlas/models/image_generation/srflow/module_util.py +++ b/dlas/models/image_generation/srflow/module_util.py @@ -1,8 +1,9 @@ import torch import torch.nn as nn -import torch.nn.init as init import torch.nn.functional as F -import torch_intermediary as ml +import torch.nn.init as init + +import dlas.torch_intermediary as ml def initialize_weights(net_l, scale=1): @@ -76,5 +77,6 @@ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) - output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) + output = F.grid_sample( + x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output diff --git a/dlas/models/image_generation/srflow/thops.py b/dlas/models/image_generation/srflow/thops.py index 6cbc28b6..3add8514 100644 --- a/dlas/models/image_generation/srflow/thops.py +++ b/dlas/models/image_generation/srflow/thops.py @@ -49,4 +49,4 @@ def cat_feature(tensor_a, tensor_b): def pixels(tensor): - return int(tensor.size(2) * tensor.size(3)) \ No newline at end of file + return int(tensor.size(2) * tensor.size(3)) diff --git a/dlas/models/image_generation/stylegan/Discriminator_StyleGAN.py b/dlas/models/image_generation/stylegan/Discriminator_StyleGAN.py index 5722ecd8..7738a709 100644 --- a/dlas/models/image_generation/stylegan/Discriminator_StyleGAN.py +++ b/dlas/models/image_generation/stylegan/Discriminator_StyleGAN.py @@ -1,12 +1,12 @@ from collections import OrderedDict -import torch -from torch import nn -import torch.nn.functional as F import numpy as np +import torch +import torch.nn.functional as F +from torch import nn -from trainer.networks import register_model -from utils.util import opt_get +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get class BlurLayer(nn.Module): @@ -45,8 +45,10 @@ class Upscale2d(nn.Module): x = x * gain if factor != 1: shape = x.shape - x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor) - x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3]) + x = x.view(shape[0], shape[1], shape[2], 1, shape[3], + 1).expand(-1, -1, -1, factor, -1, factor) + x = x.contiguous().view( + shape[0], shape[1], factor * shape[2], factor * shape[3]) return x def __init__(self, factor=2, gain=1): @@ -104,7 +106,8 @@ class EqualizedConv2d(nn.Module): self.downscale = Downscale2d() else: self.downscale = None - he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init + he_std = gain * (input_channels * kernel_size ** + 2) ** (-0.5) # He init self.kernel_size = kernel_size if use_wscale: init_std = 1.0 / lrmul @@ -134,8 +137,10 @@ class EqualizedConv2d(nn.Module): w = w.permute(1, 0, 2, 3) # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?! 
w = F.pad(w, [1, 1, 1, 1]) - w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1] - x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2) + w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + \ + w[:, :, 1:, :-1] + w[:, :, :-1, :-1] + x = F.conv_transpose2d( + x, w, stride=2, padding=(w.size(-1) - 1) // 2) have_convolution = True elif self.upscale is not None: x = self.upscale(x) @@ -146,7 +151,8 @@ class EqualizedConv2d(nn.Module): w = self.weight * self.w_mul w = F.pad(w, [1, 1, 1, 1]) # in contrast to upscale, this is a mean... - w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25 # avg_pool? + w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25 # avg_pool? x = F.conv2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2) have_convolution = True downscale = None @@ -157,7 +163,8 @@ class EqualizedConv2d(nn.Module): if not have_convolution and intermediate is None: return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2) elif not have_convolution: - x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2) + x = F.conv2d(x, self.weight * self.w_mul, None, + padding=self.kernel_size // 2) if intermediate is not None: x = intermediate(x) @@ -180,7 +187,8 @@ class EqualizedLinear(nn.Module): else: init_std = he_std / lrmul self.w_mul = lrmul - self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std) + self.weight = torch.nn.Parameter( + torch.randn(output_size, input_size) * init_std) if bias: self.bias = torch.nn.Parameter(torch.zeros(output_size)) self.b_mul = lrmul @@ -199,7 +207,6 @@ class View(nn.Module): super().__init__() self.shape = shape - def forward(self, x): return x.view(x.size(0), *self.shape) @@ -218,8 +225,10 @@ class StddevLayer(nn.Module): y = y - y.mean(0, keepdim=True) y = (y ** 2).mean(0, keepdim=True) y = (y + 1e-8) ** 0.5 - y = y.mean([3, 4, 5], keepdim=True).squeeze(3) # don't keep the meaned-out channels - y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, self.num_new_features, h, w) + y = y.mean([3, 4, 5], keepdim=True).squeeze( + 3) # don't keep the meaned-out channels + y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, + self.num_new_features, h, w) z = torch.cat([x, y], dim=1) return z @@ -227,7 +236,8 @@ class StddevLayer(nn.Module): class DiscriminatorBlock(nn.Sequential): def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer, blur_kernel): super().__init__(OrderedDict([ - ('conv0', EqualizedConv2d(in_channels, in_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)), + ('conv0', EqualizedConv2d(in_channels, in_channels, + kernel_size=3, gain=gain, use_wscale=use_wscale)), # out channels nf(res-1) ('act0', activation_layer), ('blur', BlurLayer(kernel=blur_kernel)), @@ -236,7 +246,6 @@ class DiscriminatorBlock(nn.Sequential): ('act1', activation_layer)])) - class DiscriminatorTop(nn.Sequential): def __init__(self, mbstd_group_size, @@ -265,7 +274,8 @@ class DiscriminatorTop(nn.Sequential): layers = [] if mbstd_group_size > 1: - layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features))) + layers.append(('stddev_layer', StddevLayer( + mbstd_group_size, mbstd_num_features))) if in_channels2 is None: in_channels2 = in_channels @@ -362,8 +372,10 @@ class StyleGanDiscriminator(nn.Module): elif self.structure == 'linear': assert depth < self.depth, "Requested output depth cannot be produced" if depth > 0: - residual = 
self.from_rgb[self.depth - depth](self.temporaryDownsampler(images_in)) - straight = self.blocks[self.depth - depth - 1](self.from_rgb[self.depth - depth - 1](images_in)) + residual = self.from_rgb[self.depth - + depth](self.temporaryDownsampler(images_in)) + straight = self.blocks[self.depth - depth - + 1](self.from_rgb[self.depth - depth - 1](images_in)) x = (alpha * straight) + ((1 - alpha) * residual) for block in self.blocks[(self.depth - depth):]: @@ -380,4 +392,4 @@ class StyleGanDiscriminator(nn.Module): @register_model def register_stylegan_vgg(opt_net, opt): - return StyleGanDiscriminator(opt_get(opt_net, ['image_size'], 128)) \ No newline at end of file + return StyleGanDiscriminator(opt_get(opt_net, ['image_size'], 128)) diff --git a/dlas/models/image_generation/stylegan/__init__.py b/dlas/models/image_generation/stylegan/__init__.py index e219dc28..412350f9 100644 --- a/dlas/models/image_generation/stylegan/__init__.py +++ b/dlas/models/image_generation/stylegan/__init__.py @@ -2,10 +2,10 @@ def create_stylegan2_loss(opt_loss, env): type = opt_loss['type'] if type == 'stylegan2_divergence': - import models.image_generation.stylegan.stylegan2_lucidrains as stylegan2 + import dlas.models.image_generation.stylegan.stylegan2_lucidrains as stylegan2 return stylegan2.StyleGan2DivergenceLoss(opt_loss, env) elif type == 'stylegan2_pathlen': - import models.image_generation.stylegan.stylegan2_lucidrains as stylegan2 + import dlas.models.image_generation.stylegan.stylegan2_lucidrains as stylegan2 return stylegan2.StyleGan2PathLengthLoss(opt_loss, env) else: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/dlas/models/image_generation/stylegan/stylegan2_lucidrains.py b/dlas/models/image_generation/stylegan/stylegan2_lucidrains.py index 6016fb9e..32b3ecc6 100644 --- a/dlas/models/image_generation/stylegan/stylegan2_lucidrains.py +++ b/dlas/models/image_generation/stylegan/stylegan2_lucidrains.py @@ -1,24 +1,23 @@ -import functools import math import multiprocessing -from contextlib import contextmanager, ExitStack +from contextlib import ExitStack, contextmanager from functools import partial from math import log2 from random import random +import numpy as np import torch import torch.nn.functional as F -import trainer.losses as L -import numpy as np - from kornia.filters import filter2d from linear_attention_transformer import ImageLinearAttention from torch import nn from torch.autograd import grad as torch_grad from vector_quantize_pytorch import VectorQuantize -from trainer.networks import register_model -from utils.util import checkpoint, opt_get +import dlas.torch_intermediary as ml +import dlas.trainer.losses as L +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, opt_get try: from apex import amp @@ -28,7 +27,6 @@ except: APEX_AVAILABLE = False assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.' 
-import torch_intermediary as ml num_cores = multiprocessing.cpu_count() @@ -46,24 +44,33 @@ def DiffAugment(x, types=[]): x = f(x) return x.contiguous() + def rand_brightness(x): x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) return x + def rand_saturation(x): x_mean = x.mean(dim=1, keepdim=True) - x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, + dtype=x.dtype, device=x.device) * 2) + x_mean return x + def rand_contrast(x): x_mean = x.mean(dim=[1, 2, 3], keepdim=True) - x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, + dtype=x.dtype, device=x.device) + 0.5) + x_mean return x + def rand_translation(x, ratio=0.125): - shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) - translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) - translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) + shift_x, shift_y = int(x.size(2) * ratio + + 0.5), int(x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, + size=[x.size(0), 1, 1], device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, + size=[x.size(0), 1, 1], device=x.device) grid_batch, grid_x, grid_y = torch.meshgrid( torch.arange(x.size(0), dtype=torch.long, device=x.device), torch.arange(x.size(2), dtype=torch.long, device=x.device), @@ -72,31 +79,40 @@ def rand_translation(x, ratio=0.125): grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) - x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + x = x_pad.permute(0, 2, 3, 1).contiguous()[ + grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) return x + def rand_cutout(x, ratio=0.5): cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) - offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) - offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_x = torch.randint(0, x.size( + 2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size( + 3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) grid_batch, grid_x, grid_y = torch.meshgrid( torch.arange(x.size(0), dtype=torch.long, device=x.device), torch.arange(cutout_size[0], dtype=torch.long, device=x.device), torch.arange(cutout_size[1], dtype=torch.long, device=x.device), ) - grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) - grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) - mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) + grid_x = torch.clamp(grid_x + offset_x - + cutout_size[0] // 2, min=0, max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - + cutout_size[1] // 2, min=0, max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), + dtype=x.dtype, device=x.device) mask[grid_batch, grid_x, grid_y] = 0 x = x * mask.unsqueeze(1) return x + AUGMENT_FNS = { 'color': [rand_brightness, rand_saturation, rand_contrast], 'translation': [rand_translation], 'cutout': [rand_cutout], } + class 
NanException(Exception): pass @@ -162,9 +178,10 @@ class Blur(nn.Module): # one layer of self-attention and feedforward, for images -attn_and_ff = lambda chan: nn.Sequential(*[ +def attn_and_ff(chan): return nn.Sequential(*[ Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))), - Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1)))) + Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), + leaky_relu(), nn.Conv2d(chan * 2, chan, 1)))) ]) @@ -216,7 +233,8 @@ def raise_if_nan(t): def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps): if is_ddp: num_no_syncs = gradient_accumulate_every - 1 - head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs + head = [combine_contexts( + map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs tail = [null_context] contexts = head + tail else: @@ -238,7 +256,8 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs): def gradient_penalty(images, output, weight=10, return_structured_grads=False): batch_size = images.shape[0] gradients = torch_grad(outputs=output, inputs=images, - grad_outputs=torch.ones(output.size(), device=images.device), + grad_outputs=torch.ones( + output.size(), device=images.device), create_graph=True, retain_graph=True, only_inputs=True)[0] flat_grad = gradients.reshape(batch_size, -1) @@ -248,6 +267,7 @@ def gradient_penalty(images, output, weight=10, return_structured_grads=False): else: return penalty + def calc_pl_lengths(styles, images): num_pixels = images.shape[2] * images.shape[3] pl_noise = torch.randn_like(images) / math.sqrt(num_pixels) @@ -269,7 +289,8 @@ def leaky_relu(p=0.2): def evaluate_in_chunks(max_batch_size, model, *args): - split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) + split_args = list( + zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) chunked_outputs = [model(*i) for i in split_args] if len(chunked_outputs) == 1: return chunked_outputs[0] @@ -286,11 +307,13 @@ def slerp(val, low, high): high_norm = high / torch.norm(high, dim=1, keepdim=True) omega = torch.acos((low_norm * high_norm).sum(1)) so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * \ + low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res # augmentations + def random_hflip(tensor, prob): if prob > random(): return tensor @@ -380,8 +403,10 @@ class AdaptiveInstanceNorm(nn.Module): def __init__(self, in_channel, style_dim): super().__init__() from models.archs.arch_util import ConvGnLelu - self.style2scale = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True) - self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True, weight_init_factor=0) + self.style2scale = ConvGnLelu( + style_dim, in_channel, kernel_size=1, norm=False, activation=False, bias=True) + self.style2bias = ConvGnLelu(style_dim, in_channel, kernel_size=1, + norm=False, activation=False, bias=True, weight_init_factor=0) self.norm = nn.InstanceNorm2d(in_channel) def forward(self, input, style): @@ -454,8 +479,10 @@ class Conv2DMod(nn.Module): self.kernel = kernel self.stride = stride self.dilation = dilation - self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel))) - nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + 
self.weight = nn.Parameter(torch.randn( + (out_chan, in_chan, kernel, kernel))) + nn.init.kaiming_normal_( + self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') def _get_same_padding(self, size, kernel, dilation, stride): return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 @@ -468,7 +495,8 @@ class Conv2DMod(nn.Module): weights = w2 * (w1 + 1) if self.demod: - d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS) + d = torch.rsqrt( + (weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + EPS) weights = weights * d x = x.reshape(1, -1, h, w) @@ -476,7 +504,8 @@ class Conv2DMod(nn.Module): _, _, *ws = weights.shape weights = weights.reshape(b * self.filters, *ws) - padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride) + padding = self._get_same_padding( + h, self.kernel, self.dilation, self.stride) x = F.conv2d(x, weights, padding=padding, groups=b) x = x.reshape(-1, self.filters, h, w) @@ -486,7 +515,8 @@ class Conv2DMod(nn.Module): class GeneratorBlockWithStructure(nn.Module): def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False): super().__init__() - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None + self.upsample = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=False) if upsample else None # Uses stylegan1 style blocks for injecting structural latent. self.conv0 = EqualConv2d(input_channels, filters, 3, padding=1) @@ -514,7 +544,8 @@ class GeneratorBlockWithStructure(nn.Module): noise1 = self.to_noise1(inoise).permute((0, 3, 1, 2)) noise2 = self.to_noise2(inoise).permute((0, 3, 1, 2)) - structure = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest") + structure = torch.nn.functional.interpolate( + structure_input, size=x.shape[2:], mode="nearest") x = self.conv0(x) x = self.noise0(x, noise0) x = self.adain0(x, structure) @@ -534,7 +565,8 @@ class GeneratorBlockWithStructure(nn.Module): class GeneratorBlock(nn.Module): def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False, structure_input=False): super().__init__() - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None + self.upsample = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=False) if upsample else None self.structure_input = structure_input if self.structure_input: @@ -584,7 +616,8 @@ class Generator(nn.Module): self.latent_dim = latent_dim self.num_layers = int(log2(image_size) - 1) - filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1] + filters = [network_capacity * (2 ** (i + 1)) + for i in range(self.num_layers)][::-1] set_fmap_max = partial(min, fmap_max) filters = list(map(set_fmap_max, filters)) @@ -595,9 +628,11 @@ class Generator(nn.Module): self.no_const = no_const if no_const: - self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False) + self.to_initial_block = nn.ConvTranspose2d( + latent_dim, init_channels, 4, 1, 0, bias=False) else: - self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4))) + self.initial_block = nn.Parameter( + torch.randn((1, init_channels, 4, 4))) self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1) self.blocks = nn.ModuleList([]) @@ -608,7 +643,8 @@ class Generator(nn.Module): not_last = ind != (self.num_layers - 1) num_layer = self.num_layers - ind - attn_fn = attn_and_ff(in_chan) if 
num_layer in attn_layers else None + attn_fn = attn_and_ff( + in_chan) if num_layer in attn_layers else None self.attns.append(attn_fn) @@ -644,7 +680,8 @@ class Generator(nn.Module): x = self.initial_conv(x) if structure_input is not None: - s = torch.nn.functional.interpolate(structure_input, size=x.shape[2:], mode="nearest") + s = torch.nn.functional.interpolate( + structure_input, size=x.shape[2:], mode="nearest") for style, block, attn in zip(styles, self.blocks, self.attns): if exists(attn): x = checkpoint(attn, x) @@ -652,8 +689,10 @@ class Generator(nn.Module): if exists(block.upsample): # In this case, the structural guidance is given by the extra information over the previous layer. twoX = (x.shape[2]*2, x.shape[3]*2) - sn = torch.nn.functional.interpolate(structure_input, size=twoX, mode="nearest") - s_int = torch.nn.functional.interpolate(s, size=twoX, mode="bilinear") + sn = torch.nn.functional.interpolate( + structure_input, size=twoX, mode="nearest") + s_int = torch.nn.functional.interpolate( + s, size=twoX, mode="bilinear") s_diff = sn - s_int else: # This is the initial case - just feed in the base structure. @@ -670,7 +709,8 @@ class StyleGan2GeneratorWithLatent(nn.Module): def __init__(self, image_size, latent_dim=512, style_depth=8, lr_mlp=.1, network_capacity=16, transparent=False, attn_layers=[], no_const=False, fmap_max=512, structure_input=False): super().__init__() - self.vectorizer = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp) + self.vectorizer = StyleVectorizer( + latent_dim, style_depth, lr_mul=lr_mlp) self.gen = Generator(image_size, latent_dim, network_capacity, transparent, attn_layers, no_const, fmap_max, structure_input=structure_input) self.mixed_prob = .9 @@ -702,7 +742,8 @@ class StyleGan2GeneratorWithLatent(nn.Module): style = self.noise(b*2, self.gen.latent_dim, x.device) w = self.vectorizer(style) # Randomly distribute styles across layers - w_styles = w[:,None,:].expand(-1, self.gen.num_layers, -1).clone() + w_styles = w[:, None, + :].expand(-1, self.gen.num_layers, -1).clone() for j in range(b): cutoff = int(torch.rand(()).numpy() * self.gen.num_layers) if cutoff == self.gen.num_layers or random() > self.mixed_prob: @@ -713,7 +754,8 @@ class StyleGan2GeneratorWithLatent(nn.Module): w_styles = w_styles[:b] else: get_latents_fn = self.mixed_list if random() < self.mixed_prob else self.noise_list - style = get_latents_fn(b, self.gen.num_layers, self.gen.latent_dim, device=x.device) + style = get_latents_fn( + b, self.gen.num_layers, self.gen.latent_dim, device=x.device) w_space = self.latent_to_w(self.vectorizer, style) w_styles = self.styles_def_to_tensor(w_space) @@ -721,12 +763,13 @@ class StyleGan2GeneratorWithLatent(nn.Module): if fit_starting_shape_to_structure: starting_shape = (x.shape[2] // 32, x.shape[3] // 32) # The underlying model expects the noise as b,h,w,1. Make it so. 
- return self.gen(w_styles, x[:,0,:,:].unsqueeze(dim=3), structure_input, starting_shape), w_styles + return self.gen(w_styles, x[:, 0, :, :].unsqueeze(dim=3), structure_input, starting_shape), w_styles def _init_weights(self): for m in self.modules(): if type(m) in {nn.Conv2d, ml.Linear} and hasattr(m, 'weight'): - nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + nn.init.kaiming_normal_( + m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') for block in self.gen.blocks: nn.init.zeros_(block.to_noise1.weight) @@ -738,7 +781,8 @@ class StyleGan2GeneratorWithLatent(nn.Module): class DiscriminatorBlock(nn.Module): def __init__(self, input_channels, filters, downsample=True): super().__init__() - self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1)) + self.conv_res = nn.Conv2d( + input_channels, filters, 1, stride=(2 if downsample else 1)) self.net = nn.Sequential( nn.Conv2d(input_channels, filters, 3, padding=1), @@ -768,7 +812,8 @@ class StyleGan2Discriminator(nn.Module): num_layers = int(log2(image_size) - 1) blocks = [] - filters = [input_filters] + [(64) * (2 ** i) for i in range(num_layers + 1)] + filters = [input_filters] + [(64) * (2 ** i) + for i in range(num_layers + 1)] set_fmap_max = partial(min, fmap_max) filters = list(map(set_fmap_max, filters)) @@ -782,15 +827,18 @@ class StyleGan2Discriminator(nn.Module): num_layer = ind + 1 is_not_last = ind != (len(chan_in_out) - 1) - block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last) + block = DiscriminatorBlock( + in_chan, out_chan, downsample=is_not_last) blocks.append(block) - attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None + attn_fn = attn_and_ff( + out_chan) if num_layer in attn_layers else None attn_blocks.append(attn_fn) if quantize: - quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None + quantize_fn = PermuteToFrom(VectorQuantize( + out_chan, fq_dict_size)) if num_layer in fq_layers else None quantize_blocks.append(quantize_fn) else: quantize_blocks.append(None) @@ -838,7 +886,8 @@ class StyleGan2Discriminator(nn.Module): def _init_weights(self): for m in self.modules(): if type(m) in {nn.Conv2d, ml.Linear}: - nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') + nn.init.kaiming_normal_( + m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') class StyleGan2DivergenceLoss(L.ConfigurableLoss): @@ -850,7 +899,8 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss): self.for_gen = opt['gen_loss'] self.gp_frequency = opt['gradient_penalty_frequency'] self.noise = opt['noise'] if 'noise' in opt.keys() else 0 - self.logistic = opt_get(opt, ['logistic'], False) # Applies a logistic curve to the output logits, which is what the StyleGAN2 authors used. + # Applies a logistic curve to the output logits, which is what the StyleGAN2 authors used. + self.logistic = opt_get(opt, ['logistic'], False) def forward(self, net, state): real_input = state[self.real] @@ -867,7 +917,8 @@ class StyleGan2DivergenceLoss(L.ConfigurableLoss): else: return fake.mean() else: - real_input.requires_grad_() # <-- Needed to compute gradients on the input. + # <-- Needed to compute gradients on the input. 
+ real_input.requires_grad_() real = D(real_input) if self.logistic: rl = F.softplus(-real).mean() @@ -907,13 +958,15 @@ class StyleGan2PathLengthLoss(L.ConfigurableLoss): else: print("Path length loss returned NaN!") - self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length) + self.pl_mean = self.pl_length_ma.update_average( + self.pl_mean, avg_pl_length) return 0 @register_model def register_stylegan2_lucidrains(opt_net, opt): - is_structured = opt_net['structured'] if 'structured' in opt_net.keys() else False + is_structured = opt_net['structured'] if 'structured' in opt_net.keys( + ) else False attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] return StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'], latent_dim=opt_net['latent_dim'], style_depth=opt_net['style_depth'], structure_input=is_structured, @@ -924,6 +977,7 @@ def register_stylegan2_lucidrains(opt_net, opt): def register_stylegan2_discriminator(opt_net, opt): attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else [] disc = StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn, - do_checkpointing=opt_get(opt_net, ['do_checkpointing'], False), + do_checkpointing=opt_get( + opt_net, ['do_checkpointing'], False), quantize=opt_get(opt_net, ['quantize'], False)) return StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability']) diff --git a/dlas/models/image_generation/stylegan/stylegan2_rosinality.py b/dlas/models/image_generation/stylegan/stylegan2_rosinality.py index 90a14863..2d52b28e 100644 --- a/dlas/models/image_generation/stylegan/stylegan2_rosinality.py +++ b/dlas/models/image_generation/stylegan/stylegan2_rosinality.py @@ -1,18 +1,14 @@ import math import random -import functools -import operator import torch from torch import nn from torch.nn import functional as F -from torch.autograd import Function - # Ops -> The rosinality repo uses native cuda kernels for fused LeakyReLUs and upsamplers. This version extracts the # "cpu" alternative code and uses that instead for compatibility reasons. 
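Note (illustrative sketch, not part of the patch): the "cpu" alternative mentioned in the comment above amounts to a bias-add followed by a leaky ReLU and a rescale. A minimal sketch under that assumption, with the usual 0.2 negative slope and sqrt(2) gain; the function name here is hypothetical, not this module's API:

import torch
import torch.nn.functional as F

def fused_leaky_relu_cpu(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # bias-add along the channel dimension, then leaky ReLU, then rescale
    return F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope) * scale

x = torch.randn(1, 8, 4, 4)
bias = torch.zeros(8)
out = fused_leaky_relu_cpu(x, bias)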
-from trainer.networks import register_model -from utils.util import opt_get +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get class FusedLeakyReLU(nn.Module): @@ -67,12 +63,13 @@ def upfirdn2d_native( out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad( - out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), + max(pad_y0, 0), max(pad_y1, 0)] ) out = out[ :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), :, ] @@ -133,7 +130,8 @@ class Upsample(nn.Module): self.pad = (pad0, pad1) def forward(self, input): - out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + out = upfirdn2d(input, self.kernel, up=self.factor, + down=1, pad=self.pad) return out @@ -154,7 +152,8 @@ class Downsample(nn.Module): self.pad = (pad0, pad1) def forward(self, input): - out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + out = upfirdn2d(input, self.kernel, up=1, + down=self.factor, pad=self.pad) return out @@ -280,7 +279,8 @@ class ModulatedConv2d(nn.Module): pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 + 1 - self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + self.blur = Blur(blur_kernel, pad=( + pad0, pad1), upsample_factor=factor) if downsample: factor = 2 @@ -330,7 +330,8 @@ class ModulatedConv2d(nn.Module): weight = weight.transpose(1, 2).reshape( batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size ) - out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + out = F.conv_transpose2d( + input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) out = self.blur(out) @@ -423,7 +424,8 @@ class ToRGB(nn.Module): if upsample: self.upsample = Upsample(blur_kernel) - self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.conv = ModulatedConv2d( + in_channel, 3, 1, style_dim, demodulate=False) self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) def forward(self, input, style, skip=None): @@ -496,7 +498,8 @@ class Generator(nn.Module): for layer_idx in range(self.num_layers): res = (layer_idx + 5) // 2 shape = [1, 1, 2 ** res, 2 ** res] - self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) + self.noises.register_buffer( + f"noise_{layer_idx}", torch.randn(*shape)) for i in range(3, self.log_size + 1): out_channel = self.channels[2 ** i] @@ -578,7 +581,8 @@ class Generator(nn.Module): for style in styles: style_t.append( - truncation_latent + truncation * (style - truncation_latent) + truncation_latent + truncation * + (style - truncation_latent) ) styles = style_t @@ -597,7 +601,8 @@ class Generator(nn.Module): inject_index = random.randint(1, self.n_latent - 1) latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) - latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat( + 1, self.n_latent - inject_index, 1) latent = torch.cat([latent, latent2], 1) @@ -723,7 +728,8 @@ class Discriminator(nn.Module): self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) self.final_linear = nn.Sequential( - EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), + EqualLinear(channels[4] * 4 * 4, channels[4], + activation="fused_lrelu"), 
EqualLinear(channels[4], 1), ) @@ -753,6 +759,7 @@ def register_stylegan2_rosinality_gen(opt_net, opt): kw = opt_get(opt_net, ['kwargs'], {}) return Generator(**kw) + @register_model def register_stylegan2_rosinality_disc(opt_net, opt): kw = opt_get(opt_net, ['kwargs'], {}) diff --git a/dlas/models/image_latents/byol/byol_model_wrapper.py b/dlas/models/image_latents/byol/byol_model_wrapper.py index 6c8bb3e3..c0877ee5 100644 --- a/dlas/models/image_latents/byol/byol_model_wrapper.py +++ b/dlas/models/image_latents/byol/byol_model_wrapper.py @@ -1,18 +1,18 @@ import copy import os from functools import wraps -import kornia.augmentation as augs +import kornia.augmentation as augs import torch import torch.nn.functional as F import torchvision from kornia import filters from torch import nn -from data.images.byol_attachment import RandomApply -from trainer.networks import register_model, create_model -from utils.util import checkpoint, opt_get -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.data.images.byol_attachment import RandomApply +from dlas.trainer.networks import create_model, register_model +from dlas.utils.util import checkpoint, opt_get def default(val, def_val): @@ -152,10 +152,12 @@ class NetWrapper(nn.Module): @singleton('projector') def _get_projector(self, hidden): if self.structural_mlp: - projector = StructuralMLP(hidden.shape, self.projection_size, self.projection_hidden_size) + projector = StructuralMLP( + hidden.shape, self.projection_size, self.projection_hidden_size) else: - _, dim = hidden.flatten(1,-1).shape - projector = MLP(dim, self.projection_size, self.projection_hidden_size) + _, dim = hidden.flatten(1, -1).shape + projector = MLP(dim, self.projection_size, + self.projection_hidden_size) return projector.to(hidden) def get_representation(self, x): @@ -189,7 +191,8 @@ class BYOL(nn.Module): moving_average_decay=0.99, use_momentum=True, structural_mlp=False, - positional_dimension=2, # 2 for images, 1 for audio, everything else isn't supported. + # 2 for images, 1 for audio, everything else isn't supported. + positional_dimension=2, perform_augmentation=True, ): super().__init__() @@ -199,7 +202,7 @@ class BYOL(nn.Module): self.perform_augmentation = perform_augmentation if self.perform_augmentation: - augmentations = [ \ + augmentations = [ RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), augs.RandomGrayscale(p=0.2), augs.RandomHorizontalFlip(), @@ -210,7 +213,8 @@ class BYOL(nn.Module): self.target_encoder = None self.target_ema_updater = EMA(moving_average_decay) - self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size) + self.online_predictor = MLP( + projection_size, projection_size, projection_hidden_size) # get device of network and make wrapper same device device = get_module_device(net) @@ -240,7 +244,8 @@ class BYOL(nn.Module): def update_for_step(self, step, __): assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder' assert self.target_encoder is not None, 'target encoder has not been created yet' - update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) + update_moving_average(self.target_ema_updater, + self.target_encoder, self.online_encoder) def get_debug_values(self, step, __): # In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value. 
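Note (illustrative sketch, not part of the patch): `update_for_step` above applies a BYOL-style exponential-moving-average update of the target encoder toward the online encoder. A minimal sketch of that technique assuming a fixed decay; the names below are hypothetical and not this wrapper's API:

import copy
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, decay: float = 0.99):
    # target <- decay * target + (1 - decay) * online, parameter by parameter
    for t, o in zip(target.parameters(), online.parameters()):
        t.mul_(decay).add_(o, alpha=1 - decay)

online = nn.Linear(8, 8)
target = copy.deepcopy(online)      # the target starts as a frozen copy
for p in target.parameters():
    p.requires_grad_(False)
ema_update(target, online, decay=0.99)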
@@ -248,8 +253,10 @@ class BYOL(nn.Module): def visual_dbg(self, step, path): if self.perform_augmentation and self.positional_dimension == 2: - torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,))) - torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,))) + torchvision.utils.save_image(self.im1.cpu().float( + ), os.path.join(path, "%i_image1.png" % (step,))) + torchvision.utils.save_image(self.im2.cpu().float( + ), os.path.join(path, "%i_image2.png" % (step,))) def forward(self, image_one, image_two): if self.perform_augmentation: @@ -266,7 +273,8 @@ class BYOL(nn.Module): online_pred_two = self.online_predictor(online_proj_two) with torch.no_grad(): - target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder + target_encoder = self._get_target_encoder( + ) if self.use_momentum else self.online_encoder target_proj_one = target_encoder(image_one).detach() target_proj_two = target_encoder(image_two).detach() diff --git a/dlas/models/image_latents/byol/byol_structural.py b/dlas/models/image_latents/byol/byol_structural.py index dbb8fe5f..c4db9076 100644 --- a/dlas/models/image_latents/byol/byol_structural.py +++ b/dlas/models/image_latents/byol/byol_structural.py @@ -4,17 +4,19 @@ import torch import torch.nn.functional as F from torch import nn -from data.images.byol_attachment import reconstructed_shared_regions -from models.image_latents.byol.byol_model_wrapper import singleton, EMA, get_module_device, set_requires_grad, \ - update_moving_average -from trainer.networks import create_model, register_model -from utils.util import checkpoint +from dlas.data.images.byol_attachment import reconstructed_shared_regions +from dlas.models.image_latents.byol.byol_model_wrapper import ( + EMA, get_module_device, set_requires_grad, singleton, + update_moving_average) +from dlas.trainer.networks import create_model, register_model +from dlas.utils.util import checkpoint + # loss function def structural_loss_fn(x, y): # Combine the structural dimensions into the batch dimension, then compute the "normal" BYOL loss. 
- x = x.permute(0,2,3,1).flatten(0,2) - y = y.permute(0,2,3,1).flatten(0,2) + x = x.permute(0, 2, 3, 1).flatten(0, 2) + y = y.permute(0, 2, 3, 1).flatten(0, 2) x = F.normalize(x, dim=-1, p=2) y = F.normalize(y, dim=-1, p=2) return 2 - 2 * (x * y).sum(dim=-1) @@ -70,7 +72,8 @@ class NetWrapper(nn.Module): @singleton('projector') def _get_projector(self, hidden): - projector = StructuralTail(hidden.shape[1], self.projection_size, self.projection_hidden_size) + projector = StructuralTail( + hidden.shape[1], self.projection_size, self.projection_hidden_size) return projector.to(hidden) def get_representation(self, x): @@ -116,13 +119,15 @@ class StructuralBYOL(nn.Module): for p in net.parameters(): p.DO_NOT_TRAIN = True self.frozen = True - self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer) + self.online_encoder = NetWrapper( + net, projection_size, projection_hidden_size, layer=hidden_layer) self.use_momentum = use_momentum self.target_encoder = None self.target_ema_updater = EMA(moving_average_decay) - self.online_predictor = StructuralTail(projection_size, projection_size, projection_hidden_size) + self.online_predictor = StructuralTail( + projection_size, projection_size, projection_hidden_size) # get device of network and make wrapper same device device = get_module_device(net) @@ -145,7 +150,8 @@ class StructuralBYOL(nn.Module): def update_for_step(self, step, __): assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder' assert self.target_encoder is not None, 'target encoder has not been created yet' - update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) + update_moving_average(self.target_ema_updater, + self.target_encoder, self.online_encoder) if self.frozen and self.freeze_until < step: print("Unfreezing model weights. Let the latent training commence..") for p in self.online_encoder.net.parameters(): @@ -160,18 +166,23 @@ class StructuralBYOL(nn.Module): online_pred_two = self.online_predictor(online_proj_two) with torch.no_grad(): - target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder + target_encoder = self._get_target_encoder( + ) if self.use_momentum else self.online_encoder target_proj_one = target_encoder(image_one).detach() target_proj_two = target_encoder(image_two).detach() # In the structural BYOL, only the regions of the source image that are shared between the two augments are # compared. These regions can be extracted from the latents using `reconstruct_shared_regions`. 
if similar_region_params is not None: - online_pred_one, target_proj_two = reconstructed_shared_regions(online_pred_one, target_proj_two, similar_region_params) - loss_one = structural_loss_fn(online_pred_one, target_proj_two.detach()) + online_pred_one, target_proj_two = reconstructed_shared_regions( + online_pred_one, target_proj_two, similar_region_params) + loss_one = structural_loss_fn( + online_pred_one, target_proj_two.detach()) if similar_region_params is not None: - online_pred_two, target_proj_one = reconstructed_shared_regions(online_pred_two, target_proj_one, similar_region_params) - loss_two = structural_loss_fn(online_pred_two, target_proj_one.detach()) + online_pred_two, target_proj_one = reconstructed_shared_regions( + online_pred_two, target_proj_one, similar_region_params) + loss_two = structural_loss_fn( + online_pred_two, target_proj_one.detach()) loss = loss_one + loss_two return loss.mean() @@ -181,9 +192,11 @@ class StructuralBYOL(nn.Module): proj = self.online_predictor(enc) return enc, proj + @register_model def register_structural_byol(opt_net, opt): subnet = create_model(opt, opt_net['subnet']) return StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], - pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]), + pretrained_state_dict=opt_get( + opt_net, ["pretrained_path"]), freeze_until=opt_get(opt_net, ['freeze_until'], 0)) diff --git a/dlas/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py b/dlas/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py index 2c2a8fcd..d7f16cff 100644 --- a/dlas/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py +++ b/dlas/models/image_latents/fixup_resnet/DiscriminatorResnet_arch.py @@ -1,10 +1,11 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np -import torch_intermediary as ml +import dlas.torch_intermediary as ml -__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] +__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', + 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] def conv3x3(in_planes, out_planes, stride=1): @@ -52,6 +53,7 @@ class FixupBasicBlock(nn.Module): return out + class FixupBottleneck(nn.Module): expansion = 4 @@ -104,26 +106,35 @@ class FixupResNet(nn.Module): self.bias1 = nn.Parameter(torch.zeros(1)) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.layer1 = self._make_layer(block, num_filters, layers[0], stride=2) - self.layer2 = self._make_layer(block, num_filters*2, layers[1], stride=2) - self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2) - self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2) + self.layer2 = self._make_layer( + block, num_filters*2, layers[1], stride=2) + self.layer3 = self._make_layer( + block, num_filters*4, layers[2], stride=2) + self.layer4 = self._make_layer( + block, num_filters*8, layers[3], stride=2) self.bias2 = nn.Parameter(torch.zeros(1)) reduced_img_sz = int(input_img_size / 32) - self.fc1 = ml.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) + self.fc1 = ml.Linear( + num_filters * 8 * reduced_img_sz * reduced_img_sz, 100) self.fc2 = ml.Linear(100, num_classes) for m in self.modules(): if isinstance(m, FixupBasicBlock): - nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5)) + nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt( + 2 / 
(m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5)) nn.init.constant_(m.conv2.weight, 0) if m.downsample is not None: - nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) + nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt( + 2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) elif isinstance(m, FixupBottleneck): - nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25)) - nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt(2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25)) + nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt( + 2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.25)) + nn.init.normal_(m.conv2.weight, mean=0, std=np.sqrt( + 2 / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:]))) * self.num_layers ** (-0.25)) nn.init.constant_(m.conv3.weight, 0) if m.downsample is not None: - nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) + nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt( + 2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:])))) ''' elif isinstance(m, ml.Linear): nn.init.constant_(m.weight, 0) @@ -132,7 +143,8 @@ class FixupResNet(nn.Module): def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: - downsample = conv1x1(self.inplanes, planes * block.expansion, stride) + downsample = conv1x1(self.inplanes, planes * + block.expansion, stride) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) @@ -193,4 +205,5 @@ def fixup_resnet152(**kwargs): return model -__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] \ No newline at end of file +__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', + 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152'] diff --git a/dlas/models/image_latents/spinenet_arch.py b/dlas/models/image_latents/spinenet_arch.py index 8c887fb3..6d689726 100644 --- a/dlas/models/image_latents/spinenet_arch.py +++ b/dlas/models/image_latents/spinenet_arch.py @@ -4,10 +4,10 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.init import kaiming_normal - from torchvision.models.resnet import BasicBlock, Bottleneck -from models.arch_util import ConvGnSilu, ConvBnSilu, ConvBnRelu -from trainer.networks import register_model + +from dlas.models.arch_util import ConvBnRelu, ConvBnSilu, ConvGnSilu +from dlas.trainer.networks import register_model def constant_init(module, val, bias=0): @@ -16,6 +16,7 @@ def constant_init(module, val, bias=0): if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) + def kaiming_init(module, a=0, mode='fan_out', @@ -32,6 +33,7 @@ def kaiming_init(module, if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) + FILTER_SIZE_MAP = { 1: 32, 2: 64, @@ -42,6 +44,7 @@ FILTER_SIZE_MAP = { 7: 256, } + def make_res_layer(block, inplanes, planes, @@ -79,6 +82,7 @@ def make_res_layer(block, return nn.Sequential(*layers) + # The fixed SpineNet architecture discovered by NAS. 
# Each element represents a specification of a building block: # (block_level, block_fn, (input_offset0, input_offset1), is_output). @@ -133,20 +137,20 @@ SCALING_MAP = { class BlockSpec(object): - """A container class that specifies the block configuration for SpineNet.""" + """A container class that specifies the block configuration for SpineNet.""" - def __init__(self, level, block_fn, input_offsets, is_output): - self.level = level - self.block_fn = block_fn - self.input_offsets = input_offsets - self.is_output = is_output + def __init__(self, level, block_fn, input_offsets, is_output): + self.level = level + self.block_fn = block_fn + self.input_offsets = input_offsets + self.is_output = is_output def build_block_specs(block_specs=None): - """Builds the list of BlockSpec objects for SpineNet.""" - if not block_specs: - block_specs = SPINENET_BLOCK_SPECS - return [BlockSpec(*b) for b in block_specs] + """Builds the list of BlockSpec objects for SpineNet.""" + if not block_specs: + block_specs = SPINENET_BLOCK_SPECS + return [BlockSpec(*b) for b in block_specs] class Resample(nn.Module): @@ -156,10 +160,13 @@ class Resample(nn.Module): new_in_channels = int(in_channels * alpha) if block_type == Bottleneck: in_channels *= 4 - self.squeeze_conv = ConvGnSilu(in_channels, new_in_channels, kernel_size=1) + self.squeeze_conv = ConvGnSilu( + in_channels, new_in_channels, kernel_size=1) if scale < 1: - self.downsample_conv = ConvGnSilu(new_in_channels, new_in_channels, kernel_size=3, stride=2) - self.expand_conv = ConvGnSilu(new_in_channels, out_channels, kernel_size=1, activation=False) + self.downsample_conv = ConvGnSilu( + new_in_channels, new_in_channels, kernel_size=3, stride=2) + self.expand_conv = ConvGnSilu( + new_in_channels, out_channels, kernel_size=1, activation=False) def _resize(self, x): if self.scale == 1: @@ -170,7 +177,8 @@ class Resample(nn.Module): x = self.downsample_conv(x) if self.scale < 0.5: new_kernel_size = 3 if self.scale >= 0.25 else 5 - x = F.max_pool2d(x, kernel_size=new_kernel_size, stride=int(0.5/self.scale), padding=new_kernel_size//2) + x = F.max_pool2d(x, kernel_size=new_kernel_size, stride=int( + 0.5/self.scale), padding=new_kernel_size//2) return x def forward(self, inputs): @@ -182,9 +190,11 @@ class Resample(nn.Module): class Merge(nn.Module): """Merge two input tensors""" + def __init__(self, block_spec, alpha, filter_size_scale): super(Merge, self).__init__() - out_channels = int(FILTER_SIZE_MAP[block_spec.level] * filter_size_scale) + out_channels = int( + FILTER_SIZE_MAP[block_spec.level] * filter_size_scale) if block_spec.block_fn == Bottleneck: out_channels *= 4 self.block = block_spec.block_fn @@ -194,7 +204,8 @@ class Merge(nn.Module): in_channels = int(FILTER_SIZE_MAP[spec.level] * filter_size_scale) scale = 2**(spec.level - block_spec.level) self.resample_ops.append( - Resample(in_channels, out_channels, scale, spec.block_fn, alpha) + Resample(in_channels, out_channels, + scale, spec.block_fn, alpha) ) def forward(self, inputs): @@ -207,6 +218,7 @@ class Merge(nn.Module): class SpineNet(nn.Module): """Class to build SpineNet backbone""" + def __init__(self, arch, in_channels=3, @@ -227,7 +239,8 @@ class SpineNet(nn.Module): self._num_init_blocks = 2 self._early_double_reduce = double_reduce_early self.zero_init_residual = zero_init_residual - assert min(output_level) > 2 and max(output_level) < 8, "Output level out of range" + assert min(output_level) > 2 and max( + output_level) < 8, "Output level out of range" self.output_level = output_level 
self.use_input_norm = use_input_norm @@ -264,11 +277,12 @@ class SpineNet(nn.Module): self.endpoint_convs = nn.ModuleDict() for block_spec in self._block_specs: if block_spec.is_output: - in_channels = int(FILTER_SIZE_MAP[block_spec.level]*self._filter_size_scale) * 4 + in_channels = int( + FILTER_SIZE_MAP[block_spec.level]*self._filter_size_scale) * 4 self.endpoint_convs[str(block_spec.level)] = ConvGnSilu(in_channels, - self._endpoints_num_filters, - kernel_size=1, - activation=False) + self._endpoints_num_filters, + kernel_size=1, + activation=False) def _make_scale_permuted_network(self): self.merge_ops = nn.ModuleList() @@ -277,7 +291,8 @@ class SpineNet(nn.Module): self.merge_ops.append( Merge(spec, self._resample_alpha, self._filter_size_scale) ) - channels = int(FILTER_SIZE_MAP[spec.level] * self._filter_size_scale) + channels = int( + FILTER_SIZE_MAP[spec.level] * self._filter_size_scale) in_channels = channels * 4 if spec.block_fn == Bottleneck else channels self.scale_permuted_blocks.append( make_res_layer(spec.block_fn, @@ -302,8 +317,10 @@ class SpineNet(nn.Module): def forward(self, input): if self.conv1 is not None: if self.use_input_norm: - mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input.device) - std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device) + mean = torch.Tensor([0.485, 0.456, 0.406]).view( + 1, 3, 1, 1).to(input.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view( + 1, 3, 1, 1).to(input.device) input = (input - mean) / std feat = self.conv1(input) feat = self.maxpool(feat) @@ -316,7 +333,8 @@ class SpineNet(nn.Module): num_outgoing_connections = [0, 0] for i, spec in enumerate(self._block_specs): - target_feat = self.merge_ops[i]([block_feats[feat_idx] for feat_idx in spec.input_offsets]) + target_feat = self.merge_ops[i]( + [block_feats[feat_idx] for feat_idx in spec.input_offsets]) # Connect intermediate blocks with outdegree 0 to the output block. 
if spec.is_output: for j, (j_feat, j_connections) in enumerate( @@ -350,17 +368,21 @@ class SpinenetWithLogits(SpineNet): activation='relu', use_input_norm=False, double_reduce_early=True): - super().__init__(arch, in_channels, output_level, conv_cfg, norm_cfg, zero_init_residual, activation, use_input_norm, double_reduce_early) + super().__init__(arch, in_channels, output_level, conv_cfg, norm_cfg, + zero_init_residual, activation, use_input_norm, double_reduce_early) self.output_to_attach = output_to_attach self.tail = nn.Sequential(ConvBnRelu(256, 128, kernel_size=1, activation=True, norm=True, bias=False), - ConvBnRelu(128, 64, kernel_size=1, activation=True, norm=True, bias=False), - ConvBnRelu(64, num_labels, kernel_size=1, activation=False, norm=False, bias=True), + ConvBnRelu( + 128, 64, kernel_size=1, activation=True, norm=True, bias=False), + ConvBnRelu(64, num_labels, kernel_size=1, + activation=False, norm=False, bias=True), nn.Softmax(dim=1)) def forward(self, x): fea = super().forward(x)[self.output_to_attach] return self.tail(fea) + @register_model def register_spinenet(opt_net, opt): return SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm']) diff --git a/dlas/models/image_latents/vit_latent.py b/dlas/models/image_latents/vit_latent.py index b38671be..090520b3 100644 --- a/dlas/models/image_latents/vit_latent.py +++ b/dlas/models/image_latents/vit_latent.py @@ -2,35 +2,41 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.arch_util import ResBlock -from models.lucidrains.x_transformers import Encoder -from trainer.networks import register_model -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.arch_util import ResBlock +from dlas.models.lucidrains.x_transformers import Encoder +from dlas.trainer.networks import register_model class VitLatent(nn.Module): def __init__(self, top_dim, hidden_dim, depth, dropout=.1): super().__init__() self.upper = nn.Sequential(nn.Conv2d(3, top_dim, kernel_size=7, padding=3, stride=2), - ResBlock(top_dim, use_conv=True, dropout=dropout), - ResBlock(top_dim, out_channels=top_dim*2, down=True, use_conv=True, dropout=dropout), - ResBlock(top_dim*2, use_conv=True, dropout=dropout), - ResBlock(top_dim*2, out_channels=top_dim*4, down=True, use_conv=True, dropout=dropout), - ResBlock(top_dim*4, use_conv=True, dropout=dropout), - ResBlock(top_dim*4, out_channels=hidden_dim, down=True, use_conv=True, dropout=dropout), + ResBlock(top_dim, use_conv=True, + dropout=dropout), + ResBlock(top_dim, out_channels=top_dim*2, + down=True, use_conv=True, dropout=dropout), + ResBlock(top_dim*2, use_conv=True, + dropout=dropout), + ResBlock(top_dim*2, out_channels=top_dim*4, + down=True, use_conv=True, dropout=dropout), + ResBlock(top_dim*4, use_conv=True, + dropout=dropout), + ResBlock(top_dim*4, out_channels=hidden_dim, + down=True, use_conv=True, dropout=dropout), nn.GroupNorm(8, hidden_dim)) self.encoder = Encoder( - dim=hidden_dim, - depth=depth, - heads=hidden_dim//64, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ff_mult=2, - do_checkpointing=True - ) + dim=hidden_dim, + depth=depth, + heads=hidden_dim//64, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ff_mult=2, + do_checkpointing=True + ) self.mlp = nn.Sequential(ml.Linear(hidden_dim, hidden_dim*2), nn.BatchNorm1d(hidden_dim*2), @@ -42,23 +48,24 @@ class VitLatent(nn.Module): def 
project(self, x): h = self.upper(x) - h = torch.flatten(h, 2).permute(0,2,1) - h = self.encoder(h)[:,0] + h = torch.flatten(h, 2).permute(0, 2, 1) + h = self.encoder(h)[:, 0] h_norm = F.normalize(h) return h_norm def forward(self, x1, x2): h1 = self.project(x1) - #p1 = self.mlp(h1) + # p1 = self.mlp(h1) h2 = self.project(x2) - #p2 = self.mlp(h2) + # p2 = self.mlp(h2) with torch.no_grad(): he1 = self.ema.project(x1) he2 = self.ema.project(x2) def csim(h1, h2): b = x1.shape[0] - sim = F.cosine_similarity(h1.unsqueeze(0), h2.unsqueeze(1).detach(), 2) + sim = F.cosine_similarity( + h1.unsqueeze(0), h2.unsqueeze(1).detach(), 2) eye = torch.eye(b, device=x1.device) neye = eye != 1 return -(sim*eye).sum()/b, (sim*neye).sum()/(b**2-b) @@ -83,6 +90,6 @@ def register_vit_latent(opt_net, opt): if __name__ == '__main__': net = VitLatent(128, 1024, 8) net.provide_ema(net) - x1 = torch.randn(2,3,244,244) - x2 = torch.randn(2,3,244,244) - net(x1,x2) + x1 = torch.randn(2, 3, 244, 244) + x2 = torch.randn(2, 3, 244, 244) + net(x1, x2) diff --git a/dlas/models/lucidrains/dalle/__init__.py b/dlas/models/lucidrains/dalle/__init__.py index d8f37633..07d1aee7 100644 --- a/dlas/models/lucidrains/dalle/__init__.py +++ b/dlas/models/lucidrains/dalle/__init__.py @@ -1 +1 @@ -# This directory contains some useful code from https://github.com/lucidrains/DALLE-pytorch/tree/main/dalle_pytorch \ No newline at end of file +# This directory contains some useful code from https://github.com/lucidrains/DALLE-pytorch/tree/main/dalle_pytorch diff --git a/dlas/models/lucidrains/dalle/attention.py b/dlas/models/lucidrains/dalle/attention.py index 09ec3978..c6657393 100644 --- a/dlas/models/lucidrains/dalle/attention.py +++ b/dlas/models/lucidrains/dalle/attention.py @@ -2,33 +2,39 @@ from inspect import isfunction from math import ceil import torch -from torch import nn, einsum import torch.nn.functional as F from einops import rearrange, repeat - from rotary_embedding_torch import apply_rotary_emb -import torch_intermediary as ml +from torch import einsum, nn + +import dlas.torch_intermediary as ml # helpers + def exists(val): return val is not None + def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() + def default(val, d): if exists(val): return val return d() if isfunction(d) else d + def max_neg_value(t): return -torch.finfo(t.dtype).max -def stable_softmax(t, dim = -1, alpha = 32 ** 2): + +def stable_softmax(t, dim=-1, alpha=32 ** 2): t = t / alpha - t = t - torch.amax(t, dim = dim, keepdim = True).detach() - return (t * alpha).softmax(dim = dim) + t = t - torch.amax(t, dim=dim, keepdim=True).detach() + return (t * alpha).softmax(dim=dim) + def apply_pos_emb(pos_emb, qkv): n = qkv[0].shape[-2] @@ -37,10 +43,11 @@ def apply_pos_emb(pos_emb, qkv): # classes + class Attention(nn.Module): - def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False): + def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0., stable=False): super().__init__() - inner_dim = dim_head * heads + inner_dim = dim_head * heads self.heads = heads self.seq_len = seq_len self.scale = dim_head ** -0.5 @@ -48,18 +55,18 @@ class Attention(nn.Module): self.stable = stable self.causal = causal - self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False) + self.to_qkv = ml.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( ml.Linear(inner_dim, dim), nn.Dropout(dropout) ) - def forward(self, x, mask = None, rotary_pos_emb = None): + def 
forward(self, x, mask=None, rotary_pos_emb=None): b, n, _, h, device = *x.shape, self.heads, x.device softmax = torch.softmax if not self.stable else stable_softmax - qkv = self.to_qkv(x).chunk(3, dim = -1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) if exists(rotary_pos_emb): q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) @@ -76,24 +83,25 @@ class Attention(nn.Module): if self.causal: i, j = dots.shape[-2:] - mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool() dots.masked_fill_(mask, mask_value) attn = softmax(dots, dim=-1) out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') - out = self.to_out(out) + out = self.to_out(out) return out # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation + class SparseConvCausalAttention(nn.Module): - def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs): + def __init__(self, dim, seq_len, image_size=32, kernel_size=5, dilation=1, heads=8, dim_head=64, dropout=0., stable=False, **kwargs): super().__init__() assert kernel_size % 2 == 1, 'kernel size must be odd' - inner_dim = dim_head * heads + inner_dim = dim_head * heads self.seq_len = seq_len self.heads = heads self.scale = dim_head ** -0.5 @@ -103,15 +111,16 @@ class SparseConvCausalAttention(nn.Module): self.stable = stable - self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False) + self.to_qkv = ml.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( ml.Linear(inner_dim, dim), nn.Dropout(dropout) ) - def forward(self, x, mask = None, rotary_pos_emb = None): - b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device + def forward(self, x, mask=None, rotary_pos_emb=None): + b, n, _, h, img_size, kernel_size, dilation, seq_len, device = * \ + x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device softmax = torch.softmax if not self.stable else stable_softmax img_seq_len = img_size ** 2 @@ -120,22 +129,25 @@ class SparseConvCausalAttention(nn.Module): # padding padding = seq_len - n + 1 - mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool()) + mask = default(mask, lambda: torch.ones( + b, text_len, device=device).bool()) - x = F.pad(x, (0, 0, 0, padding), value = 0) + x = F.pad(x, (0, 0, 0, padding), value=0) mask = mask[:, :text_len] # derive query / keys / values - qkv = self.to_qkv(x).chunk(3, dim = -1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv) + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange( + t, 'b n (h d) -> (b h) n d', h=h), qkv) if exists(rotary_pos_emb): q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) q *= self.scale - ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)) + ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map( + lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)) # text attention @@ -143,10 +155,11 @@ class SparseConvCausalAttention(nn.Module): mask_value = max_neg_value(dots_text) i, j = dots_text.shape[-2:] - 
text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + text_causal_mask = torch.ones( + i, j, device=device).triu_(j - i + 1).bool() dots_text.masked_fill_(text_causal_mask, mask_value) - attn_text = softmax(dots_text, dim = -1) + attn_text = softmax(dots_text, dim=-1) out_text = einsum('b i j, b j d -> b i d', attn_text, v_text) # image attention @@ -154,9 +167,12 @@ class SparseConvCausalAttention(nn.Module): effective_kernel_size = (kernel_size - 1) * dilation + 1 padding = effective_kernel_size // 2 - k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img)) - k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img)) - k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img)) + k_img, v_img = map(lambda t: rearrange( + t, 'b (h w) c -> b c h w', h=img_size), (k_img, v_img)) + k_img, v_img = map(lambda t: F.unfold( + t, kernel_size, padding=padding, dilation=dilation), (k_img, v_img)) + k_img, v_img = map(lambda t: rearrange( + t, 'b (d j) i -> b i j d', j=kernel_size ** 2), (k_img, v_img)) # let image attend to all of text @@ -166,56 +182,63 @@ class SparseConvCausalAttention(nn.Module): # calculate causal attention for local convolution i, j = dots_image.shape[-2:] - img_seq = torch.arange(img_seq_len, device = device) - k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size) - k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len) # padding set to be max, so it is never attended to - k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation) + img_seq = torch.arange(img_seq_len, device=device) + k_img_indices = rearrange( + img_seq.float(), '(h w) -> () () h w', h=img_size) + # padding set to be max, so it is never attended to + k_img_indices = F.pad(k_img_indices, (padding,) * 4, value=img_seq_len) + k_img_indices = F.unfold(k_img_indices, kernel_size, dilation=dilation) k_img_indices = rearrange(k_img_indices, 'b j i -> b i j') # mask image attention q_img_indices = rearrange(img_seq, 'i -> () i ()') - causal_mask = q_img_indices < k_img_indices + causal_mask = q_img_indices < k_img_indices # concat text mask with image causal mask - causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h) - mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h) - mask = torch.cat((~mask, causal_mask), dim = -1) + causal_mask = repeat(causal_mask, '() i j -> b i j', b=b * h) + mask = repeat(mask, 'b j -> (b h) i j', i=i, h=h) + mask = torch.cat((~mask, causal_mask), dim=-1) # image can attend to all of text - dots = torch.cat((dots_image_to_text, dots_image), dim = -1) + dots = torch.cat((dots_image_to_text, dots_image), dim=-1) dots.masked_fill_(mask, mask_value) - attn = softmax(dots, dim = -1) + attn = softmax(dots, dim=-1) # aggregate - attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:] + attn_image_to_text, attn_image = attn[..., + :text_len], attn[..., text_len:] - out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img) - out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text) + out_image_to_image = einsum( + 'b i j, b i j d -> b i d', attn_image, v_img) + out_image_to_text = einsum( + 'b i j, b j d -> b i d', attn_image_to_text, v_text) out_image = out_image_to_image + out_image_to_text # combine attended values for both text and image - out = torch.cat((out_text, out_image), dim = 1) + out = 
torch.cat((out_text, out_image), dim=1) - out = rearrange(out, '(b h) n d -> b n (h d)', h = h) - out = self.to_out(out) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = self.to_out(out) return out[:, :n] # sparse axial causal attention + class SparseAxialCausalAttention(nn.Module): - def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs): + def __init__(self, dim, seq_len, image_size=32, axis=0, heads=8, dim_head=64, dropout=0., stable=False, **kwargs): super().__init__() - assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)' + assert axis in { + 0, 1}, 'axis must be either 0 (along height) or 1 (along width)' self.axis = axis - inner_dim = dim_head * heads + inner_dim = dim_head * heads self.seq_len = seq_len self.heads = heads self.scale = dim_head ** -0.5 @@ -223,15 +246,16 @@ class SparseAxialCausalAttention(nn.Module): self.stable = stable - self.to_qkv = ml.Linear(dim, inner_dim * 3, bias = False) + self.to_qkv = ml.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( ml.Linear(inner_dim, dim), nn.Dropout(dropout) ) - def forward(self, x, mask = None, rotary_pos_emb = None): - b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device + def forward(self, x, mask=None, rotary_pos_emb=None): + b, n, _, h, img_size, axis, seq_len, device = * \ + x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device softmax = torch.softmax if not self.stable else stable_softmax img_seq_len = img_size ** 2 @@ -240,22 +264,25 @@ class SparseAxialCausalAttention(nn.Module): # padding padding = seq_len - n + 1 - mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool()) + mask = default(mask, lambda: torch.ones( + b, text_len, device=device).bool()) - x = F.pad(x, (0, 0, 0, padding), value = 0) + x = F.pad(x, (0, 0, 0, padding), value=0) mask = mask[:, :text_len] # derive queries / keys / values - qkv = self.to_qkv(x).chunk(3, dim = -1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv) + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange( + t, 'b n (h d) -> (b h) n d', h=h), qkv) if exists(rotary_pos_emb): q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) q *= self.scale - ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)) + ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map( + lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)) # text attention @@ -263,10 +290,11 @@ class SparseAxialCausalAttention(nn.Module): mask_value = max_neg_value(dots_text) i, j = dots_text.shape[-2:] - text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + text_causal_mask = torch.ones( + i, j, device=device).triu_(j - i + 1).bool() dots_text.masked_fill_(text_causal_mask, mask_value) - attn_text = softmax(dots_text, dim = -1) + attn_text = softmax(dots_text, dim=-1) out_text = einsum('b i j, b j d -> b i d', attn_text, v_text) # image attention @@ -276,93 +304,102 @@ class SparseAxialCausalAttention(nn.Module): # split out axis - q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img)) + q_img, k_img, v_img = map(lambda t: rearrange( + t, split_axis_einops, h=img_size), (q_img, k_img, v_img)) # similarity - dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img) + 
dots_image_to_image = einsum( + 'b x i d, b x j d -> b x i j', q_img, k_img) dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text) - dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1) + dots = torch.cat((dots_image_to_text, dots_image_to_image), dim=-1) # mask so image has full attention to text, but causal along axis bh, x, i, j = dots.shape - causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool() - causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x) + causal_mask = torch.ones(i, img_size, device=device).triu_( + img_size - i + 1).bool() + causal_mask = repeat(causal_mask, 'i j -> b x i j', b=bh, x=x) - mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i) - mask = torch.cat((~mask, causal_mask), dim = -1) + mask = repeat(mask, 'b j -> (b h) x i j', h=h, x=x, i=i) + mask = torch.cat((~mask, causal_mask), dim=-1) dots.masked_fill_(mask, mask_value) # attention. - attn = softmax(dots, dim = -1) + attn = softmax(dots, dim=-1) # aggregate - attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:] + attn_image_to_text, attn_image_to_image = attn[..., + :text_len], attn[..., text_len:] - out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img) - out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text) + out_image_to_image = einsum( + 'b x i j, b x j d -> b x i d', attn_image_to_image, v_img) + out_image_to_text = einsum( + 'b x i j, b j d -> b x i d', attn_image_to_text, v_text) out_image = out_image_to_image + out_image_to_text # merge back axis - out_image = rearrange(out_image, merge_axis_einops, x = img_size) + out_image = rearrange(out_image, merge_axis_einops, x=img_size) # combine attended values for both text and image - out = torch.cat((out_text, out_image), dim = 1) + out = torch.cat((out_text, out_image), dim=1) - out = rearrange(out, '(b h) n d -> b n (h d)', h = h) - out = self.to_out(out) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = self.to_out(out) return out[:, :n] # microsoft sparse attention CUDA kernel + class SparseAttention(Attention): def __init__( self, *args, - block_size = 16, - text_seq_len = 256, - num_random_blocks = None, + block_size=16, + text_seq_len=256, + num_random_blocks=None, **kwargs ): super().__init__(*args, **kwargs) - from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig + from deepspeed.ops.sparse_attention import (SparseSelfAttention, + VariableSparsityConfig) self.block_size = block_size - num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4) + num_random_blocks = default( + num_random_blocks, self.seq_len // block_size // 4) global_block_indices = list(range(ceil(text_seq_len / block_size))) self.attn_fn = SparseSelfAttention( - sparsity_config = VariableSparsityConfig( - num_heads = self.heads, - block = self.block_size, - num_random_blocks = num_random_blocks, - global_block_indices = global_block_indices, - attention = 'unidirectional' if self.causal else 'bidirectional' + sparsity_config=VariableSparsityConfig( + num_heads=self.heads, + block=self.block_size, + num_random_blocks=num_random_blocks, + global_block_indices=global_block_indices, + attention='unidirectional' if self.causal else 'bidirectional' ), - max_seq_length = self.seq_len, - attn_mask_mode = 'add' + max_seq_length=self.seq_len, + attn_mask_mode='add' ) - def forward(self, x, mask = None, rotary_pos_emb = None): + def 
forward(self, x, mask=None, rotary_pos_emb=None): b, n, _, h, device = *x.shape, self.heads, x.device remainder = n % self.block_size - mask = default(mask, lambda: torch.ones(b, n, device = device).bool()) + mask = default(mask, lambda: torch.ones(b, n, device=device).bool()) if remainder > 0: padding = self.block_size - remainder - x = F.pad(x, (0, 0, 0, padding), value = 0) - mask = F.pad(mask, (0, padding), value = False) + x = F.pad(x, (0, 0, 0, padding), value=0) + mask = F.pad(mask, (0, padding), value=False) - qkv = self.to_qkv(x).chunk(3, dim = -1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) if exists(rotary_pos_emb): q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) @@ -374,12 +411,13 @@ class SparseAttention(Attention): attn_mask = None if self.causal: i, j = q.shape[-2], k.shape[-2] - mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() - attn_mask = torch.zeros(i, j, device = device).to(q) + mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool() + attn_mask = torch.zeros(i, j, device=device).to(q) mask_value = max_neg_value(q) / 2 attn_mask.masked_fill_(mask, mask_value) - out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask) + out = self.attn_fn(q, k, v, attn_mask=attn_mask, + key_padding_mask=key_pad_mask) out = rearrange(out, 'b h n d -> b n (h d)') out = self.to_out(out) - return out[:, :n] \ No newline at end of file + return out[:, :n] diff --git a/dlas/models/lucidrains/dalle/reversible.py b/dlas/models/lucidrains/dalle/reversible.py index c55416ae..265cdc32 100644 --- a/dlas/models/lucidrains/dalle/reversible.py +++ b/dlas/models/lucidrains/dalle/reversible.py @@ -1,13 +1,13 @@ import functools +from operator import itemgetter import torch import torch.nn as nn -from operator import itemgetter from torch.autograd.function import Function from torch.utils.checkpoint import get_device_states, set_device_states # for routing arguments into the functions of the reversible layer -from utils.util import checkpoint +from dlas.utils.util import checkpoint def route_args(router, args, depth): @@ -17,11 +17,15 @@ def route_args(router, args, depth): for key in matched_keys: val = args[key] for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): - new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) - routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) + new_f_args, new_g_args = map(lambda route: ( + {key: val} if route else {}), routes) + routed_args[depth] = ({**f_args, **new_f_args}, + {**g_args, **new_g_args}) return routed_args # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html + + class Deterministic(nn.Module): def __init__(self, net): super().__init__() @@ -37,7 +41,7 @@ class Deterministic(nn.Module): self.cuda_in_fwd = True self.gpu_devices, self.gpu_states = get_device_states(*args) - def forward(self, *args, record_rng = False, set_rng = False, **kwargs): + def forward(self, *args, record_rng=False, set_rng=False, **kwargs): if record_rng: self.record_rng(*args) @@ -56,13 +60,15 @@ class Deterministic(nn.Module): # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py # once multi-GPU is confirmed working, refactor and send PR back to source + + class ReversibleBlock(nn.Module): def 
__init__(self, f, g): super().__init__() self.f = Deterministic(f) self.g = Deterministic(g) - def forward(self, x, f_args = {}, g_args = {}): + def forward(self, x, f_args={}, g_args={}): x1, x2 = torch.chunk(x, 2, dim=2) y1, y2 = None, None @@ -72,7 +78,7 @@ class ReversibleBlock(nn.Module): return torch.cat([y1, y2], dim=2) - def backward_pass(self, y, dy, f_args = {}, g_args = {}): + def backward_pass(self, y, dy, f_args={}, g_args={}): y1, y2 = torch.chunk(y, 2, dim=2) del y @@ -110,6 +116,7 @@ class ReversibleBlock(nn.Module): return x, dx + class _ReversibleFunction(Function): @staticmethod def forward(ctx, x, blocks, args): @@ -128,10 +135,12 @@ class _ReversibleFunction(Function): y, dy = block.backward_pass(y, dy, **kwargs) return dy, None, None + class SequentialSequence(nn.Module): - def __init__(self, layers, args_route = {}, layer_dropout = 0.): + def __init__(self, layers, args_route={}, layer_dropout=0.): super().__init__() - assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' + assert all(len(route) == len(layers) for route in args_route.values( + )), 'each argument route map must have the same depth as the number of sequential layers' self.layers = layers self.args_route = args_route self.layer_dropout = layer_dropout @@ -145,11 +154,13 @@ class SequentialSequence(nn.Module): x = x + checkpoint(functools.partial(g, **g_args), x) return x + class ReversibleSequence(nn.Module): - def __init__(self, blocks, args_route = {}): + def __init__(self, blocks, args_route={}): super().__init__() self.args_route = args_route - self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) + self.blocks = nn.ModuleList( + [ReversibleBlock(f=f, g=g) for f, g in blocks]) def forward(self, x, **kwargs): x = torch.cat([x, x], dim=-1) @@ -158,5 +169,5 @@ class ReversibleSequence(nn.Module): args = route_args(self.args_route, kwargs, len(blocks)) args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) - out = _ReversibleFunction.apply(x, blocks, args) - return torch.stack(out.chunk(2, dim=-1)).mean(dim=0) \ No newline at end of file + out = _ReversibleFunction.apply(x, blocks, args) + return torch.stack(out.chunk(2, dim=-1)).mean(dim=0) diff --git a/dlas/models/lucidrains/dalle/transformer.py b/dlas/models/lucidrains/dalle/transformer.py index 357ce1ad..084240e9 100644 --- a/dlas/models/lucidrains/dalle/transformer.py +++ b/dlas/models/lucidrains/dalle/transformer.py @@ -1,43 +1,51 @@ from functools import partial -from itertools import islice, cycle +from itertools import cycle, islice import torch -from torch import nn, einsum import torch.nn.functional as F from einops import rearrange - -from models.lucidrains.dalle.reversible import ReversibleSequence, SequentialSequence -from models.lucidrains.dalle.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention - -from rotary_embedding_torch import RotaryEmbedding, broadcat from g_mlp_pytorch import gMLPBlock -import torch_intermediary as ml +from rotary_embedding_torch import RotaryEmbedding, broadcat +from torch import einsum, nn + +import dlas.torch_intermediary as ml +from dlas.models.lucidrains.dalle.attention import (Attention, SparseAttention, + SparseAxialCausalAttention, + SparseConvCausalAttention) +from dlas.models.lucidrains.dalle.reversible import (ReversibleSequence, + SequentialSequence) # helpers + def exists(val): return val is not None + def default(val, 
d): return val if exists(val) else d -def cast_tuple(val, depth = 1): + +def cast_tuple(val, depth=1): if isinstance(val, list): val = tuple(val) return val if isinstance(val, tuple) else (val,) * depth # classes + class DivideMax(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): - maxes = x.amax(dim = self.dim, keepdim = True).detach() + maxes = x.amax(dim=self.dim, keepdim=True).detach() return x / maxes # https://arxiv.org/abs/2103.17239 + + class LayerScale(nn.Module): def __init__(self, dim, depth, fn): super().__init__() @@ -51,13 +59,15 @@ class LayerScale(nn.Module): scale = torch.zeros(1, 1, dim).fill_(init_eps) self.scale = nn.Parameter(scale) self.fn = fn + def forward(self, x, **kwargs): return self.fn(x, **kwargs) * self.scale # layer norm + class PreNorm(nn.Module): - def __init__(self, dim, fn, sandwich = False): + def __init__(self, dim, fn, sandwich=False): super().__init__() self.norm = nn.LayerNorm(dim) self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() @@ -70,13 +80,15 @@ class PreNorm(nn.Module): # feed forward + class GEGLU(nn.Module): def forward(self, x): - x, gates = x.chunk(2, dim = -1) + x, gates = x.chunk(2, dim=-1) return x * F.gelu(gates) + class FeedForward(nn.Module): - def __init__(self, dim, dropout = 0., mult = 4.): + def __init__(self, dim, dropout=0., mult=4.): super().__init__() self.net = nn.Sequential( ml.Linear(dim, dim * mult * 2), @@ -90,6 +102,7 @@ class FeedForward(nn.Module): # token shift classes + class PreShiftToken(nn.Module): def __init__(self, fn, image_size, seq_len): super().__init__() @@ -108,29 +121,31 @@ class PreShiftToken(nn.Module): x_text, x_img = x[:, :text_len], x[:, text_len:] x_img = F.pad(x_img, (0, 0, 0, padding)) - x_img = rearrange(x_img, 'b (h w) d -> b h w d', h = image_size) + x_img = rearrange(x_img, 'b (h w) d -> b h w d', h=image_size) # shift 1 from the left for text tokens - x_text_shift, x_text_pass = x_text.chunk(2, dim = -1) + x_text_shift, x_text_pass = x_text.chunk(2, dim=-1) x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1)) - x_text = torch.cat((x_text_shift, x_text_pass), dim = -1) + x_text = torch.cat((x_text_shift, x_text_pass), dim=-1) # shift from top, left for image tokens - x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim = -1) + x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim=-1) x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1)) x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1)) - x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim = -1) + x_img = torch.cat( + (x_img_shift_top, x_img_shift_left, *x_img_pass), dim=-1) # merge text and image sequence back together x_img = rearrange(x_img, 'b h w d -> b (h w) d') - x = torch.cat((x_text, x_img[:, :-padding]), dim = 1) + x = torch.cat((x_text, x_img[:, :-padding]), dim=1) return self.fn(x, **kwargs) # main transformer class + class Transformer(nn.Module): def __init__( self, @@ -138,21 +153,21 @@ class Transformer(nn.Module): dim, depth, seq_len, - reversible = False, - causal = True, - heads = 8, - dim_head = 64, - ff_mult = 4, - attn_dropout = 0., - ff_dropout = 0., - attn_types = None, - image_fmap_size = None, - oned_fmap_size = None, - sparse_attn = False, - stable = False, - sandwich_norm = False, - shift_tokens = False, - rotary_emb = True + reversible=False, + causal=True, + heads=8, + dim_head=64, + ff_mult=4, + attn_dropout=0., + ff_dropout=0., + attn_types=None, + image_fmap_size=None, + oned_fmap_size=None, + 
sparse_attn=False, + stable=False, + sandwich_norm=False, + shift_tokens=False, + rotary_emb=True ): super().__init__() layers = nn.ModuleList([]) @@ -164,40 +179,47 @@ class Transformer(nn.Module): for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer): if attn_type == 'full': - attn_class = partial(Attention, stable = stable) + attn_class = partial(Attention, stable=stable) elif attn_type == 'sparse': attn_class = SparseAttention elif attn_type == 'axial_row': - attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable) + attn_class = partial(SparseAxialCausalAttention, seq_len=seq_len, + axis=0, image_size=image_fmap_size, stable=stable) elif attn_type == 'axial_col': - attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable) + attn_class = partial(SparseAxialCausalAttention, seq_len=seq_len, + axis=1, image_size=image_fmap_size, stable=stable) elif attn_type == 'conv_like': - attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable) + attn_class = partial( + SparseConvCausalAttention, seq_len=seq_len, image_size=image_fmap_size, stable=stable) elif attn_type == 'mlp': - attn_class = partial(gMLPBlock, seq_len = seq_len) + attn_class = partial(gMLPBlock, seq_len=seq_len) else: raise ValueError(f'attention type "{attn_type}" is not valid') if attn_type != 'mlp': - attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout) + attn = attn_class(dim, causal=causal, seq_len=seq_len, + heads=heads, dim_head=dim_head, dropout=attn_dropout) else: - attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4) + attn = attn_class(dim=dim, causal=causal, dim_ff=dim * 4) - ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout) + ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout) if shift_tokens: - attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff)) + attn, ff = map(lambda t: PreShiftToken( + t, image_size=image_fmap_size, seq_len=seq_len), (attn, ff)) layers.append(nn.ModuleList([ - LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)), - LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm)) + LayerScale(dim, ind + 1, PreNorm(dim, + attn, sandwich=sandwich_norm)), + LayerScale(dim, ind + 1, PreNorm(dim, + ff, sandwich=sandwich_norm)) ])) execute_type = ReversibleSequence if reversible else SequentialSequence route_attn = ((True, False),) * depth attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn} - self.layers = execute_type(layers, args_route = attn_route_map) + self.layers = execute_type(layers, args_route=attn_route_map) # generate positional embeddings for rotary @@ -206,28 +228,34 @@ class Transformer(nn.Module): assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on' rot_dim = dim_head // 3 - img_seq_len = (image_fmap_size ** 2) if image_fmap_size is not None else oned_fmap_size + img_seq_len = (image_fmap_size ** + 2) if image_fmap_size is not None else oned_fmap_size text_len = seq_len - img_seq_len + 1 - text_pos_emb = RotaryEmbedding(dim = rot_dim) - img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel') + text_pos_emb = RotaryEmbedding(dim=rot_dim) + img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for='pixel') text_freqs = 
text_pos_emb(torch.arange(text_len)) - img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192)) # image is given a position far away from text - text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim = 0) + # image is given a position far away from text + img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192)) + text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0) - img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size if image_fmap_size is not None else oned_fmap_size)) - img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), rearrange(img_freqs_axial, 'j d -> () j d')), dim = -1) + img_freqs_axial = img_axial_pos_emb( + torch.linspace(-1, 1, steps=image_fmap_size if image_fmap_size is not None else oned_fmap_size)) + img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), + rearrange(img_freqs_axial, 'j d -> () j d')), dim=-1) img_freqs = rearrange(img_freqs, 'h w d -> (h w) d') - text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.)) # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1] - text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim = -1) - img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0) + # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1] + text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.)) + text_axial_freqs = torch.cat( + (text_axial_freqs, text_axial_freqs), dim=-1) + img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0) - pos_emb = torch.cat((text_freqs, img_freqs), dim = -1) + pos_emb = torch.cat((text_freqs, img_freqs), dim=-1) pos_emb = rearrange(pos_emb, 'n d -> () n d') self.register_buffer('pos_emb', pos_emb) def forward(self, x, **kwargs): - return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs) \ No newline at end of file + return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs) diff --git a/dlas/models/lucidrains/performer/__init__.py b/dlas/models/lucidrains/performer/__init__.py index 8b137891..e69de29b 100644 --- a/dlas/models/lucidrains/performer/__init__.py +++ b/dlas/models/lucidrains/performer/__init__.py @@ -1 +0,0 @@ - diff --git a/dlas/models/lucidrains/performer/autoregressive_wrapper.py b/dlas/models/lucidrains/performer/autoregressive_wrapper.py index fa28b771..f02a66fe 100644 --- a/dlas/models/lucidrains/performer/autoregressive_wrapper.py +++ b/dlas/models/lucidrains/performer/autoregressive_wrapper.py @@ -1,13 +1,16 @@ from functools import partial + import torch -from torch import nn import torch.nn.functional as F +from torch import nn from torch.nn.utils.rnn import pad_sequence + def exists(val): return val is not None -def top_p(logits, thres = 0.9): + +def top_p(logits, thres=0.9): sorted_logits, sorted_indices = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) @@ -18,21 +21,24 @@ def top_p(logits, thres = 0.9): sorted_logits[sorted_indices_to_remove] = float('-inf') return sorted_logits.scatter(1, sorted_indices, sorted_logits) -def top_k(logits, thres = 0.9): + +def top_k(logits, thres=0.9): k = int((1 - thres) * logits.shape[-1]) val, ind = torch.topk(logits, k) probs = torch.full_like(logits, float('-inf')) probs.scatter_(1, ind, val) return probs + def repetition_penalty_fn(logits, ctx, theta=1.2): w = torch.ones(logits.shape[-1], dtype=torch.float, device=logits.device) for i in torch.unique(ctx): w[i] = theta return logits/w + 
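# --- Rough usage sketch (illustrative only; tensor shapes are assumptions) ---
# Inside generate() below, each step crops the running context to max_seq_len,
# takes the logits at the last position, optionally applies the repetition
# penalty defined above, filters with top_k/top_p, and samples; roughly:
#   logits = net(ctx)[:, -1, :]                         # (batch, vocab)
#   logits = repetition_penalty_fn(logits, ctx, theta=1.2)
#   probs = F.softmax(top_k(logits, thres=0.9) / temperature, dim=-1)
#   next_token = torch.multinomial(probs, 1)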
class AutoregressiveWrapper(nn.Module): - def __init__(self, net, ignore_index = 0, pad_value = 0): + def __init__(self, net, ignore_index=0, pad_value=0): super().__init__() self.pad_value = pad_value self.ignore_index = ignore_index @@ -41,7 +47,7 @@ class AutoregressiveWrapper(nn.Module): self.max_seq_len = net.max_seq_len @torch.no_grad() - def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, repetition_penalty=1.0, repetition_penalty_ctx=32, **kwargs): + def generate(self, start_tokens, seq_len, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, repetition_penalty=1.0, repetition_penalty_ctx=32, **kwargs): was_training = self.net.training num_dims = len(start_tokens.shape) @@ -55,24 +61,27 @@ class AutoregressiveWrapper(nn.Module): input_mask = kwargs.pop('mask', None) if input_mask is None: - input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device) - + input_mask = torch.full_like( + out, True, dtype=torch.bool, device=out.device) + # in case of conditional generation, if enc_mask is not provided use the correct context_mask context_mask = kwargs.pop('context_mask', None) if 'context' in kwargs and not exists(context_mask): context = kwargs['context'] - context_mask = torch.full(context.shape[:2], True, dtype=torch.bool, device=out.device) + context_mask = torch.full( + context.shape[:2], True, dtype=torch.bool, device=out.device) - kwargs.update(context_mask = context_mask) + kwargs.update(context_mask=context_mask) for _ in range(seq_len): x = out[:, -self.max_seq_len:] input_mask = input_mask[:, -self.max_seq_len:] logits = self.net(x, mask=input_mask, **kwargs)[:, -1, :] if repetition_penalty > 1.0: - logits = repetition_penalty_fn(logits, out[-repetition_penalty_ctx:], theta=repetition_penalty) - filtered_logits = filter_logits_fn(logits, thres = filter_thres) + logits = repetition_penalty_fn( + logits, out[-repetition_penalty_ctx:], theta=repetition_penalty) + filtered_logits = filter_logits_fn(logits, thres=filter_thres) probs = F.softmax(filtered_logits / temperature, dim=-1) sample = torch.multinomial(probs, 1) @@ -99,9 +108,10 @@ class AutoregressiveWrapper(nn.Module): mask = kwargs.pop('mask', None) if mask is not None and mask.shape[1] == x.shape[1]: mask = mask[:, :-1] - kwargs.update(mask = mask) + kwargs.update(mask=mask) out = self.net(xi, **kwargs) - loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index) + loss = F.cross_entropy(out.transpose(1, 2), xo, + ignore_index=self.ignore_index) return loss diff --git a/dlas/models/lucidrains/performer/performer_enc_dec.py b/dlas/models/lucidrains/performer/performer_enc_dec.py index e3691507..971b4514 100644 --- a/dlas/models/lucidrains/performer/performer_enc_dec.py +++ b/dlas/models/lucidrains/performer/performer_enc_dec.py @@ -1,62 +1,76 @@ -import re -import torch -from torch import nn - # for god knows why it cannot "see" performer_pytorch import os +import re import sys + +import torch +from torch import nn + +from dlas.models.lucidrains.performer.autoregressive_wrapper import \ + AutoregressiveWrapper +from dlas.models.lucidrains.performer.performer_pytorch import PerformerLM + sys.path.insert(0, os.path.dirname(os.path.realpath(__file__))) -from performer_pytorch import PerformerLM -from autoregressive_wrapper import AutoregressiveWrapper ENC_PREFIX = 'enc_' DEC_PREFIX = 'dec_' + def group_dict_by_key(cond, d): - return_val = [dict(),dict()] + return_val = [dict(), dict()] for 
key in d.keys(): match = bool(cond(key)) ind = int(not match) return_val[ind][key] = d[key] return (*return_val,) + def string_begins_with(prefix, str): return bool(re.match(f'^{prefix}', str)) + def group_by_key_prefix(prefix, d): return group_dict_by_key(lambda x: string_begins_with(prefix, x), d) + def group_by_key_prefix_and_remove_prefix(prefix, d): - kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: string_begins_with(prefix, x), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_with_prefix, kwargs = group_dict_by_key( + lambda x: string_begins_with(prefix, x), d) + kwargs_without_prefix = dict( + map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) return kwargs_without_prefix, kwargs + def extract_enc_dec_kwargs(kwargs): - enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs) - dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, kwargs) + enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix( + ENC_PREFIX, kwargs) + dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix( + DEC_PREFIX, kwargs) return enc_kwargs, dec_kwargs, kwargs + def extract_and_set_enc_dec_kwargs(kwargs): enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs) if 'mask' in enc_kwargs: dec_kwargs.setdefault('context_mask', enc_kwargs['mask']) return enc_kwargs, dec_kwargs, kwargs + class PerformerEncDec(nn.Module): def __init__( self, dim, - ignore_index = 0, - pad_value = 0, - tie_token_embeds = False, - no_projection = False, + ignore_index=0, + pad_value=0, + tie_token_embeds=False, + no_projection=False, **kwargs ): super().__init__() enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs) - + assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder' enc_kwargs['dim'] = dec_kwargs['dim'] = dim @@ -72,15 +86,17 @@ class PerformerEncDec(nn.Module): enc.token_emb = dec.token_emb self.enc = enc - self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value) + self.dec = AutoregressiveWrapper( + dec, ignore_index=ignore_index, pad_value=pad_value) @torch.no_grad() def generate(self, seq_in, seq_out_start, seq_len, **kwargs): enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs) - encodings = self.enc(seq_in, return_encodings = True, **enc_kwargs) - return self.dec.generate(seq_out_start, seq_len, context = encodings, **{**dec_kwargs, **kwargs}) + encodings = self.enc(seq_in, return_encodings=True, **enc_kwargs) + return self.dec.generate(seq_out_start, seq_len, context=encodings, **{**dec_kwargs, **kwargs}) - def forward(self, seq_in, seq_out, enc_mask = None, **kwargs): + def forward(self, seq_in, seq_out, enc_mask=None, **kwargs): enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs) - encodings = self.enc(seq_in, mask = enc_mask, return_encodings = True, **enc_kwargs) - return self.dec(seq_out, context = encodings, context_mask = enc_mask, **dec_kwargs) \ No newline at end of file + encodings = self.enc(seq_in, mask=enc_mask, + return_encodings=True, **enc_kwargs) + return self.dec(seq_out, context=encodings, context_mask=enc_mask, **dec_kwargs) diff --git a/dlas/models/lucidrains/performer/performer_pytorch.py b/dlas/models/lucidrains/performer/performer_pytorch.py index 98ce769f..07f32e0b 100644 --- a/dlas/models/lucidrains/performer/performer_pytorch.py +++ b/dlas/models/lucidrains/performer/performer_pytorch.py @@ -1,18 +1,18 
@@ import math +from contextlib import contextmanager +from distutils.version import LooseVersion +from functools import partial + +import dlas.torch_intermediary as ml import torch import torch.nn.functional as F +from axial_positional_embedding import AxialPositionalEmbedding +from dlas.models.lucidrains.performer.reversible import (ReversibleSequence, + SequentialSequence) +from einops import rearrange, repeat +from local_attention import LocalAttention from torch import nn from torch.cuda.amp import autocast -from einops import rearrange, repeat - -from functools import partial -from contextlib import contextmanager - -from local_attention import LocalAttention -from axial_positional_embedding import AxialPositionalEmbedding -from models.lucidrains.performer.reversible import ReversibleSequence, SequentialSequence - -from distutils.version import LooseVersion TORCH_GE_1_8_0 = LooseVersion(torch.__version__) >= LooseVersion('1.8.0') @@ -21,32 +21,39 @@ try: APEX_AVAILABLE = True except: APEX_AVAILABLE = False -import torch_intermediary as ml # helpers + def exists(val): return val is not None + def empty(tensor): return tensor.numel() == 0 + def default(val, d): return val if exists(val) else d + @contextmanager def null_context(): yield + def cast_tuple(val): return (val,) if not isinstance(val, tuple) else val + def get_module_device(module): return next(module.parameters()).device + def find_modules(nn_module, type): return [module for module in nn_module.modules() if isinstance(module, type)] + class Always(nn.Module): def __init__(self, val): super().__init__() @@ -57,14 +64,16 @@ class Always(nn.Module): # token shifting helper and classes -def shift(t, amount, mask = None): + +def shift(t, amount, mask=None): if amount == 0: return t if exists(mask): t = t.masked_fill(~mask[..., None], 0.) - return F.pad(t, (0, 0, amount, -amount), value = 0.) + return F.pad(t, (0, 0, amount, -amount), value=0.) + class PreShiftTokens(nn.Module): def __init__(self, shifts, fn): @@ -77,10 +86,11 @@ class PreShiftTokens(nn.Module): shifts = self.shifts segments = len(shifts) feats_per_shift = x.shape[-1] // segments - splitted = x.split(feats_per_shift, dim = -1) + splitted = x.split(feats_per_shift, dim=-1) segments_to_shift, rest = splitted[:segments], splitted[segments:] - segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts))) - x = torch.cat((*segments_to_shift, *rest), dim = -1) + segments_to_shift = list(map(lambda args: shift( + *args, mask=mask), zip(segments_to_shift, shifts))) + x = torch.cat((*segments_to_shift, *rest), dim=-1) return self.fn(x, **kwargs) # kernel functions @@ -88,17 +98,19 @@ class PreShiftTokens(nn.Module): # transcribed from jax to pytorch from # https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py -def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): + +def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None): b, h, *_ = data.shape data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 
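# (Explanatory comment only; a sketch of the math this kernel computes, not an
# upstream addition.) This is the FAVOR+ positive random-feature map: with
# projection rows w_i and x_hat = x / d**0.25, each feature is roughly
#   exp(w_i . x_hat - ||x_hat||**2 / 2) / sqrt(m)
# `ratio` below supplies the 1/sqrt(m) factor, `data_dash` holds w_i . x_hat,
# and `diag_data` holds ||x_hat||**2 / 2; the amax subtraction (per query row
# for queries, over all positions for keys) only stabilizes the exponential.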
ratio = (projection_matrix.shape[0] ** -0.5) - projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) + projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h) projection = projection.type_as(data) - data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) + data_dash = torch.einsum('...id,...jd->...ij', + (data_normalizer * data), projection) diag_data = data ** 2 diag_data = torch.sum(diag_data, dim=-1) @@ -108,14 +120,15 @@ def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, ep if is_query: data_dash = ratio * ( torch.exp(data_dash - diag_data - - torch.amax(data_dash, dim=-1, keepdim=True)) + eps) + torch.amax(data_dash, dim=-1, keepdim=True)) + eps) else: data_dash = ratio * ( torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True)) + eps) return data_dash.type_as(data) -def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None): + +def generalized_kernel(data, *, projection_matrix, kernel_fn=nn.ReLU(), kernel_epsilon=0.001, normalize_data=True, device=None): b, h, *_ = data.shape data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. @@ -123,43 +136,48 @@ def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel if projection_matrix is None: return kernel_fn(data_normalizer * data) + kernel_epsilon - projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) + projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h) projection = projection.type_as(data) - data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) + data_dash = torch.einsum('...id,...jd->...ij', + (data_normalizer * data), projection) data_prime = kernel_fn(data_dash) + kernel_epsilon return data_prime.type_as(data) -def orthogonal_matrix_chunk(cols, device = None): - unstructured_block = torch.randn((cols, cols), device = device) + +def orthogonal_matrix_chunk(cols, device=None): + unstructured_block = torch.randn((cols, cols), device=device) if TORCH_GE_1_8_0: - q, r = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced') + q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced') else: - q, r = torch.qr(unstructured_block.cpu(), some = True) + q, r = torch.qr(unstructured_block.cpu(), some=True) q, r = map(lambda t: t.to(device), (q, r)) return q.t() -def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None): + +def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, device=None): nb_full_blocks = int(nb_rows / nb_columns) block_list = [] for _ in range(nb_full_blocks): - q = orthogonal_matrix_chunk(nb_columns, device = device) + q = orthogonal_matrix_chunk(nb_columns, device=device) block_list.append(q) remaining_rows = nb_rows - nb_full_blocks * nb_columns if remaining_rows > 0: - q = orthogonal_matrix_chunk(nb_columns, device = device) + q = orthogonal_matrix_chunk(nb_columns, device=device) block_list.append(q[:remaining_rows]) final_matrix = torch.cat(block_list) if scaling == 0: - multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) + multiplier = torch.randn( + (nb_rows, nb_columns), device=device).norm(dim=1) elif scaling == 1: - multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) + multiplier = math.sqrt((float(nb_columns))) * \ + torch.ones((nb_rows,), device=device) else: raise ValueError(f'Invalid scaling {scaling}') @@ 
-168,8 +186,10 @@ def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = # linear attention classes with softmax kernel # non-causal linear attention + + def linear_attention(q, k, v): - k_cumsum = k.sum(dim = -2) + k_cumsum = k.sum(dim=-2) D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) context = torch.einsum('...nd,...ne->...de', k, v) out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) @@ -177,14 +197,18 @@ def linear_attention(q, k, v): # efficient causal linear attention, created by EPFL # TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back -def causal_linear_attention(q, k, v, eps = 1e-6): + + +def causal_linear_attention(q, k, v, eps=1e-6): from fast_transformers.causal_product import CausalDotProduct autocast_enabled = torch.is_autocast_enabled() is_half = isinstance(q, torch.cuda.HalfTensor) assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available' - cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False) + cuda_context = null_context if not autocast_enabled else partial( + autocast, enabled=False) - causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply + causal_dot_product_fn = amp.float_function( + CausalDotProduct.apply) if is_half else CausalDotProduct.apply k_cumsum = k.cumsum(dim=-2) + eps D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q)) @@ -200,35 +224,42 @@ def causal_linear_attention(q, k, v, eps = 1e-6): # inefficient causal linear attention, without cuda code, for reader's reference # not being used -def causal_linear_attention_noncuda(q, k, v, chunk_size = 128, eps = 1e-6): + + +def causal_linear_attention_noncuda(q, k, v, chunk_size=128, eps=1e-6): last_k_cumsum = 0 last_context_cumsum = 0 outs = [] - for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim = -2), (q, k, v))): + for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim=-2), (q, k, v))): k_cumsum = last_k_cumsum + k.cumsum(dim=-2) - D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q) + eps) + D_inv = 1. 
/ torch.einsum('...nd,...nd->...n', + q, k_cumsum.type_as(q) + eps) context = torch.einsum('...nd,...ne->...nde', k, v) context_cumsum = last_context_cumsum + context.cumsum(dim=-3) - out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv) + out = torch.einsum('...nde,...nd,...n->...ne', + context_cumsum, q, D_inv) last_k_cumsum = k_cumsum[:, :, -1:] last_context_cumsum = context_cumsum[:, :, -1:] outs.append(out) - return torch.cat(outs, dim = -2) + return torch.cat(outs, dim=-2) + class FastAttention(nn.Module): - def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False): + def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), no_projection=False): super().__init__() - nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) + nb_features = default(nb_features, int( + dim_heads * math.log(dim_heads))) self.dim_heads = dim_heads self.nb_features = nb_features self.ortho_scaling = ortho_scaling - self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling) + self.create_projection = partial( + gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling) projection_matrix = self.create_projection() self.register_buffer('projection_matrix', projection_matrix) @@ -250,7 +281,7 @@ class FastAttention(nn.Module): @torch.no_grad() def redraw_projection_matrix(self, device): - projections = self.create_projection(device = device) + projections = self.create_projection(device=device) self.projection_matrix.copy_(projections) del projections @@ -258,17 +289,19 @@ class FastAttention(nn.Module): device = q.device if self.no_projection: - q = q.softmax(dim = -1) - k = torch.exp(k) if self.causal else k.softmax(dim = -2) + q = q.softmax(dim=-1) + k = torch.exp(k) if self.causal else k.softmax(dim=-2) elif self.generalized_attention: - create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) + create_kernel = partial(generalized_kernel, kernel_fn=self.kernel_fn, + projection_matrix=self.projection_matrix, device=device) q, k = map(create_kernel, (q, k)) else: - create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) - q = create_kernel(q, is_query = True) - k = create_kernel(k, is_query = False) + create_kernel = partial( + softmax_kernel, projection_matrix=self.projection_matrix, device=device) + q = create_kernel(q, is_query=True) + k = create_kernel(k, is_query=False) attn_fn = linear_attention if not self.causal else self.causal_linear_fn out = attn_fn(q, k, v) @@ -276,6 +309,7 @@ class FastAttention(nn.Module): # a module for keeping track of when to update the projections + class ProjectionUpdater(nn.Module): def __init__(self, instance, feature_redraw_interval): super().__init__() @@ -309,6 +343,7 @@ class ProjectionUpdater(nn.Module): # classes + class ReZero(nn.Module): def __init__(self, fn): super().__init__() @@ -318,6 +353,7 @@ class ReZero(nn.Module): def forward(self, x, **kwargs): return self.fn(x, **kwargs) * self.g + class PreScaleNorm(nn.Module): def __init__(self, dim, fn, eps=1e-5): super().__init__() @@ -330,16 +366,19 @@ class PreScaleNorm(nn.Module): x = x / n * self.g return self.fn(x, **kwargs) + class 
PreLayerNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn + def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) + class Chunk(nn.Module): - def __init__(self, chunks, fn, along_dim = -1): + def __init__(self, chunks, fn, along_dim=-1): super().__init__() self.dim = along_dim self.chunks = chunks @@ -348,11 +387,12 @@ class Chunk(nn.Module): def forward(self, x, **kwargs): if self.chunks == 1: return self.fn(x, **kwargs) - chunks = x.chunk(self.chunks, dim = self.dim) - return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim) + chunks = x.chunk(self.chunks, dim=self.dim) + return torch.cat([self.fn(c, **kwargs) for c in chunks], dim=self.dim) + class FeedForward(nn.Module): - def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False): + def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False): super().__init__() activation = default(activation, nn.GELU) @@ -374,52 +414,58 @@ class FeedForward(nn.Module): x = self.w2(x) return x + class Attention(nn.Module): def __init__( self, dim, - causal = False, - heads = 8, - dim_head = 64, - local_heads = 0, - local_window_size = 256, - nb_features = None, - feature_redraw_interval = 1000, - generalized_attention = False, - kernel_fn = nn.ReLU(), - dropout = 0., - no_projection = False, - qkv_bias = False, - attn_out_bias = True + causal=False, + heads=8, + dim_head=64, + local_heads=0, + local_window_size=256, + nb_features=None, + feature_redraw_interval=1000, + generalized_attention=False, + kernel_fn=nn.ReLU(), + dropout=0., + no_projection=False, + qkv_bias=False, + attn_out_bias=True ): super().__init__() assert dim % heads == 0, 'dimension must be divisible by number of heads' dim_head = default(dim_head, dim // heads) inner_dim = dim_head * heads - self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection) + self.fast_attention = FastAttention( + dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, no_projection=no_projection) self.heads = heads self.global_heads = heads - local_heads - self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None + self.local_attn = LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int( + not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None - self.to_q = ml.Linear(dim, inner_dim, bias = qkv_bias) - self.to_k = ml.Linear(dim, inner_dim, bias = qkv_bias) - self.to_v = ml.Linear(dim, inner_dim, bias = qkv_bias) - self.to_out = ml.Linear(inner_dim, dim, bias = attn_out_bias) + self.to_q = ml.Linear(dim, inner_dim, bias=qkv_bias) + self.to_k = ml.Linear(dim, inner_dim, bias=qkv_bias) + self.to_v = ml.Linear(dim, inner_dim, bias=qkv_bias) + self.to_out = ml.Linear(inner_dim, dim, bias=attn_out_bias) self.dropout = nn.Dropout(dropout) - def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs): + def forward(self, x, pos_emb=None, context=None, mask=None, context_mask=None, **kwargs): b, n, _, h, gh = *x.shape, self.heads, self.global_heads cross_attend = exists(context) context = default(context, x) - context_mask = default(context_mask, 
mask) if not cross_attend else context_mask + context_mask = default( + context_mask, mask) if not cross_attend else context_mask q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) - (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) + q, k, v = map(lambda t: rearrange( + t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + (q, lq), (k, lk), (v, lv) = map( + lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) attn_outs = [] @@ -436,26 +482,29 @@ class Attention(nn.Module): if not empty(lq): assert not cross_attend, 'local attention is not compatible with cross attention' - out = self.local_attn(lq, lk, lv, input_mask = mask) + out = self.local_attn(lq, lk, lv, input_mask=mask) attn_outs.append(out) - out = torch.cat(attn_outs, dim = 1) + out = torch.cat(attn_outs, dim=1) out = rearrange(out, 'b h n d -> b n (h d)') - out = self.to_out(out) + out = self.to_out(out) return self.dropout(out) + class SelfAttention(Attention): - def forward(self, *args, context = None, **kwargs): + def forward(self, *args, context=None, **kwargs): assert not exists(context), 'self attention should not receive context' return super().forward(*args, **kwargs) + class CrossAttention(Attention): - def forward(self, *args, context = None, **kwargs): + def forward(self, *args, context=None, **kwargs): assert exists(context), 'cross attention should receive context' - return super().forward(*args, context = context, **kwargs) + return super().forward(*args, context=context, **kwargs) # positional embeddings + class AbsolutePositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() @@ -468,21 +517,24 @@ class AbsolutePositionalEmbedding(nn.Module): # rotary positional embedding helpers + def rotate_every_two(x): - x = rearrange(x, '... (d j) -> ... d j', j = 2) - x1, x2 = x.unbind(dim = -1) - x = torch.stack((-x2, x1), dim = -1) + x = rearrange(x, '... (d j) -> ... d j', j=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) return rearrange(x, '... d j -> ... 
(d j)') + def apply_rotary_pos_emb(q, k, sinu_pos): - sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2) - sin, cos = sinu_pos.unbind(dim = -2) - sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos)) + sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j=2) + sin, cos = sinu_pos.unbind(dim=-2) + sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2), (sin, cos)) q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) return q, k # sinusoidal positional embeddings + class FixedPositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() @@ -497,6 +549,7 @@ class FixedPositionalEmbedding(nn.Module): # performer + class Performer(nn.Module): def __init__( self, @@ -504,34 +557,37 @@ class Performer(nn.Module): depth, heads, dim_head, - local_attn_heads = 0, - local_window_size = 256, - causal = False, - ff_mult = 4, - nb_features = None, - feature_redraw_interval = 1000, - reversible = False, - ff_chunks = 1, - generalized_attention = False, - kernel_fn = nn.ReLU(), - use_scalenorm = False, - use_rezero = False, - ff_glu = False, - ff_dropout = 0., - attn_dropout = 0., - cross_attend = False, - no_projection = False, - auto_check_redraw = True, - qkv_bias = True, - attn_out_bias = True, - shift_tokens = False + local_attn_heads=0, + local_window_size=256, + causal=False, + ff_mult=4, + nb_features=None, + feature_redraw_interval=1000, + reversible=False, + ff_chunks=1, + generalized_attention=False, + kernel_fn=nn.ReLU(), + use_scalenorm=False, + use_rezero=False, + ff_glu=False, + ff_dropout=0., + attn_dropout=0., + cross_attend=False, + no_projection=False, + auto_check_redraw=True, + qkv_bias=True, + attn_out_bias=True, + shift_tokens=False ): super().__init__() layers = nn.ModuleList([]) local_attn_heads = cast_tuple(local_attn_heads) - local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads - assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth' - assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be less than the total number of heads' + local_attn_heads = local_attn_heads * \ + depth if len(local_attn_heads) == 1 else local_attn_heads + assert len( + local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth' + assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads) + ), 'local attention head value must be less than the total number of heads' if use_scalenorm: wrapper_fn = partial(PreScaleNorm, dim) @@ -542,8 +598,10 @@ class Performer(nn.Module): for _, local_heads in zip(range(depth), local_attn_heads): - attn = SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias) - ff = Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1) + attn = SelfAttention(dim, causal=causal, heads=heads, dim_head=dim_head, local_heads=local_heads, local_window_size=local_window_size, nb_features=nb_features, + generalized_attention=generalized_attention, kernel_fn=kernel_fn, dropout=attn_dropout, no_projection=no_projection, qkv_bias=qkv_bias, 
attn_out_bias=attn_out_bias) + ff = Chunk(ff_chunks, FeedForward(dim, mult=ff_mult, + dropout=ff_dropout, glu=ff_glu), along_dim=1) if shift_tokens: shift = (0, 1) if causal else (-1, 0, 1) @@ -556,8 +614,10 @@ class Performer(nn.Module): continue layers.append(nn.ModuleList([ - wrapper_fn(CrossAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias)), - wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) + wrapper_fn(CrossAttention(dim, heads=heads, dim_head=dim_head, nb_features=nb_features, generalized_attention=generalized_attention, + kernel_fn=kernel_fn, dropout=attn_dropout, no_projection=no_projection, qkv_bias=qkv_bias, attn_out_bias=attn_out_bias)), + wrapper_fn(Chunk(ff_chunks, FeedForward( + dim, mult=ff_mult, dropout=ff_dropout, glu=ff_glu), along_dim=1)) ])) execute_type = ReversibleSequence if reversible else SequentialSequence @@ -565,12 +625,15 @@ class Performer(nn.Module): route_attn = ((True, False),) * depth * (2 if cross_attend else 1) route_context = ((False, False), (True, False)) * depth attn_route_map = {'mask': route_attn, 'pos_emb': route_attn} - context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {} - self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map}) + context_route_map = {'context': route_context, + 'context_mask': route_context} if cross_attend else {} + self.net = execute_type( + layers, args_route={**attn_route_map, **context_route_map}) # keeping track of when to redraw projections for all attention layers self.auto_check_redraw = auto_check_redraw - self.proj_updater = ProjectionUpdater(self.net, feature_redraw_interval) + self.proj_updater = ProjectionUpdater( + self.net, feature_redraw_interval) def fix_projection_matrices_(self): self.proj_updater.feature_redraw_interval = None @@ -580,6 +643,7 @@ class Performer(nn.Module): self.proj_updater.redraw_projections() return self.net(x, **kwargs) + class PerformerLM(nn.Module): def __init__( self, @@ -589,33 +653,33 @@ class PerformerLM(nn.Module): dim, depth, heads, - dim_head = 64, - local_attn_heads = 0, - local_window_size = 256, - causal = False, - ff_mult = 4, - nb_features = None, - feature_redraw_interval = 1000, - reversible = False, - ff_chunks = 1, - ff_glu = False, - emb_dropout = 0., - ff_dropout = 0., - attn_dropout = 0., - generalized_attention = False, - kernel_fn = nn.ReLU(), - use_scalenorm = False, - use_rezero = False, - cross_attend = False, - no_projection = False, - tie_embed = False, - rotary_position_emb = True, - axial_position_emb = False, - axial_position_shape = None, - auto_check_redraw = True, - qkv_bias = False, - attn_out_bias = False, - shift_tokens = False + dim_head=64, + local_attn_heads=0, + local_window_size=256, + causal=False, + ff_mult=4, + nb_features=None, + feature_redraw_interval=1000, + reversible=False, + ff_chunks=1, + ff_glu=False, + emb_dropout=0., + ff_dropout=0., + attn_dropout=0., + generalized_attention=False, + kernel_fn=nn.ReLU(), + use_scalenorm=False, + use_rezero=False, + cross_attend=False, + no_projection=False, + tie_embed=False, + rotary_position_emb=True, + axial_position_emb=False, + axial_position_shape=None, + auto_check_redraw=True, + qkv_bias=False, + attn_out_bias=False, + shift_tokens=False ): 
super().__init__() local_attn_heads = cast_tuple(local_attn_heads) @@ -626,9 +690,11 @@ class PerformerLM(nn.Module): if rotary_position_emb: self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len) - self.layer_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len) + self.layer_pos_emb = FixedPositionalEmbedding( + dim_head, max_seq_len) elif axial_position_emb: - axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / 64), 64)) + axial_position_shape = default( + axial_position_shape, (math.ceil(max_seq_len / 64), 64)) self.pos_emb = AxialPositionalEmbedding(dim, axial_position_shape) self.layer_pos_emb = Always(None) else: @@ -637,7 +703,8 @@ class PerformerLM(nn.Module): self.dropout = nn.Dropout(emb_dropout) - self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens) + self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, + generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens) self.norm = nn.LayerNorm(dim) self.to_out = ml.Linear(dim, num_tokens) if not tie_embed else None @@ -647,7 +714,7 @@ class PerformerLM(nn.Module): def fix_projection_matrices_(self): self.performer.fix_projection_matrices_() - def forward(self, x, return_encodings = False, **kwargs): + def forward(self, x, return_encodings=False, **kwargs): b, n, device = *x.shape, x.device assert n <= self.max_seq_len, f'sequence length {n} must be less than the max sequence length {self.max_seq_len}' @@ -660,7 +727,7 @@ class PerformerLM(nn.Module): # performer layers layer_pos_emb = self.layer_pos_emb(x) - x = self.performer(x, pos_emb = layer_pos_emb, **kwargs) + x = self.performer(x, pos_emb=layer_pos_emb, **kwargs) # norm and to logits x = self.norm(x) diff --git a/dlas/models/lucidrains/performer/reversible.py b/dlas/models/lucidrains/performer/reversible.py index 0b740105..98ed6ecd 100644 --- a/dlas/models/lucidrains/performer/reversible.py +++ b/dlas/models/lucidrains/performer/reversible.py @@ -1,9 +1,11 @@ +from operator import itemgetter + import torch import torch.nn as nn -from operator import itemgetter from torch.autograd.function import Function from torch.utils.checkpoint import get_device_states, set_device_states + # for routing arguments into the functions of the reversible layer def route_args(router, args, depth): routed_args = [(dict(), dict()) for _ in range(depth)] @@ -12,11 +14,15 @@ def route_args(router, args, depth): for key in matched_keys: val = args[key] for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): - new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) - routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) + new_f_args, new_g_args = map(lambda route: ( + {key: val} if route else {}), routes) + routed_args[depth] = ({**f_args, **new_f_args}, + {**g_args, **new_g_args}) return routed_args # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html + + class Deterministic(nn.Module): def 
__init__(self, net): super().__init__() @@ -32,7 +38,7 @@ class Deterministic(nn.Module): self.cuda_in_fwd = True self.gpu_devices, self.gpu_states = get_device_states(*args) - def forward(self, *args, record_rng = False, set_rng = False, **kwargs): + def forward(self, *args, record_rng=False, set_rng=False, **kwargs): if record_rng: self.record_rng(*args) @@ -51,13 +57,15 @@ class Deterministic(nn.Module): # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py # once multi-GPU is confirmed working, refactor and send PR back to source + + class ReversibleBlock(nn.Module): def __init__(self, f, g): super().__init__() self.f = Deterministic(f) self.g = Deterministic(g) - def forward(self, x, f_args = {}, g_args = {}): + def forward(self, x, f_args={}, g_args={}): x1, x2 = torch.chunk(x, 2, dim=2) y1, y2 = None, None @@ -67,7 +75,7 @@ class ReversibleBlock(nn.Module): return torch.cat([y1, y2], dim=2) - def backward_pass(self, y, dy, f_args = {}, g_args = {}): + def backward_pass(self, y, dy, f_args={}, g_args={}): y1, y2 = torch.chunk(y, 2, dim=2) del y @@ -105,6 +113,7 @@ class ReversibleBlock(nn.Module): return x, dx + class _ReversibleFunction(Function): @staticmethod def forward(ctx, x, blocks, args): @@ -123,10 +132,12 @@ class _ReversibleFunction(Function): y, dy = block.backward_pass(y, dy, **kwargs) return dy, None, None + class SequentialSequence(nn.Module): - def __init__(self, layers, args_route = {}): + def __init__(self, layers, args_route={}): super().__init__() - assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' + assert all(len(route) == len(layers) for route in args_route.values( + )), 'each argument route map must have the same depth as the number of sequential layers' self.layers = layers self.args_route = args_route @@ -139,11 +150,13 @@ class SequentialSequence(nn.Module): x = x + g(x, **g_args) return x + class ReversibleSequence(nn.Module): - def __init__(self, blocks, args_route = {}): + def __init__(self, blocks, args_route={}): super().__init__() self.args_route = args_route - self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) + self.blocks = nn.ModuleList( + [ReversibleBlock(f=f, g=g) for f, g in blocks]) def forward(self, x, **kwargs): x = torch.cat([x, x], dim=-1) @@ -152,5 +165,5 @@ class ReversibleSequence(nn.Module): args = route_args(self.args_route, kwargs, len(blocks)) args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) - out = _ReversibleFunction.apply(x, blocks, args) + out = _ReversibleFunction.apply(x, blocks, args) return torch.stack(out.chunk(2, dim=-1)).sum(dim=0) diff --git a/dlas/models/lucidrains/vq.py b/dlas/models/lucidrains/vq.py index 13a4b2ae..77f79fdd 100644 --- a/dlas/models/lucidrains/vq.py +++ b/dlas/models/lucidrains/vq.py @@ -1,68 +1,81 @@ import functools +from contextlib import contextmanager import torch -from torch import nn, einsum -import torch.nn.functional as F import torch.distributed as distributed +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import einsum, nn from torch.cuda.amp import autocast -from einops import rearrange, repeat -from contextlib import contextmanager -import torch_intermediary as ml +import dlas.torch_intermediary as ml def par(t, nm): print(f'grad report {nm}: {t}') return t + def reg(t, nm): - l = torch.tensor([0], requires_grad=True, device=t.device, dtype=torch.float) + l = 
torch.tensor([0], requires_grad=True, + device=t.device, dtype=torch.float) l.register_hook(functools.partial(par, nm=nm)) t = t + l return t + def exists(val): return val is not None + def default(val, d): return val if exists(val) else d + def noop(*args, **kwargs): pass -def l2norm(t): - return F.normalize(t, p = 2, dim = -1) -def log(t, eps = 1e-20): - return torch.log(t.clamp(min = eps)) +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def log(t, eps=1e-20): + return torch.log(t.clamp(min=eps)) + def gumbel_noise(t): noise = torch.zeros_like(t).uniform_(0, 1) return -log(-log(noise)) -def gumbel_sample(t, temperature = 1., dim = -1): - if temperature == 0: - return t.argmax(dim = dim) - return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) +def gumbel_sample(t, temperature=1., dim=-1): + if temperature == 0: + return t.argmax(dim=dim) + + return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim) + def ema_inplace(moving_avg, new, decay): - moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) -def laplace_smoothing(x, n_categories, eps = 1e-5): + +def laplace_smoothing(x, n_categories, eps=1e-5): return (x + eps) / (x.sum() + n_categories * eps) + def sample_vectors(samples, num): num_samples, device = samples.shape[0], samples.device if num_samples >= num: - indices = torch.randperm(num_samples, device = device)[:num] + indices = torch.randperm(num_samples, device=device)[:num] else: - indices = torch.randint(0, num_samples, (num,), device = device) + indices = torch.randint(0, num_samples, (num,), device=device) return samples[indices] -def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): + +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): dim, dtype, device = samples.shape[-1], samples.dtype, samples.device means = sample_vectors(samples, num_clusters) @@ -72,16 +85,16 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): dists = samples @ means.t() else: diffs = rearrange(samples, 'n d -> n () d') \ - - rearrange(means, 'c d -> () c d') - dists = -(diffs ** 2).sum(dim = -1) + - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim=-1) - buckets = dists.max(dim = -1).indices - bins = torch.bincount(buckets, minlength = num_clusters) + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) zero_mask = bins == 0 bins_min_clamped = bins.masked_fill(zero_mask, 1) - new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype) - new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples) + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) new_means = new_means / bins_min_clamped[..., None] if use_cosine_sim: @@ -93,29 +106,31 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): # regularization losses + def orthgonal_loss_fn(t): # eq (2) from https://arxiv.org/abs/2112.00384 n = t.shape[0] normed_codes = l2norm(t) - identity = torch.eye(n, device = t.device) + identity = torch.eye(n, device=t.device) cosine_sim = einsum('i d, j d -> i j', normed_codes, normed_codes) return ((cosine_sim - identity) ** 2).sum() / (n ** 2) # distance types + class EuclideanCodebook(nn.Module): def __init__( self, dim, codebook_size, - kmeans_init = False, - kmeans_iters = 10, - decay = 0.8, - eps = 1e-5, - threshold_ema_dead_code = 2, - use_ddp = False, - learnable_codebook = False, - 
sample_codebook_temp = 0 + kmeans_init=False, + kmeans_iters=10, + decay=0.8, + eps=1e-5, + threshold_ema_dead_code=2, + use_ddp=False, + learnable_codebook=False, + sample_codebook_temp=0 ): super().__init__() self.decay = decay @@ -145,7 +160,8 @@ class EuclideanCodebook(nn.Module): if self.initted: return - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + embed, cluster_size = kmeans( + data, self.codebook_size, self.kmeans_iters) self.embed.data.copy_(embed) self.embed_avg.data.copy_(embed.clone()) self.cluster_size.data.copy_(cluster_size) @@ -167,9 +183,9 @@ class EuclideanCodebook(nn.Module): if not torch.any(expired_codes): return batch_samples = rearrange(batch_samples, '... d -> (...) d') - self.replace(batch_samples, mask = expired_codes) + self.replace(batch_samples, mask=expired_codes) - @autocast(enabled = False) + @autocast(enabled=False) def forward(self, x, used_codes=[]): shape, dtype = x.shape, x.dtype flatten = rearrange(x, '... d -> (...) d') @@ -186,9 +202,11 @@ class EuclideanCodebook(nn.Module): ) for uc in used_codes: - mask = torch.arange(0, self.codebook_size, device=x.device).unsqueeze(0).repeat(x.shape[0],1) == uc.unsqueeze(1) + mask = torch.arange(0, self.codebook_size, device=x.device).unsqueeze( + 0).repeat(x.shape[0], 1) == uc.unsqueeze(1) dist[mask] = -torch.inf - embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) + embed_ind = gumbel_sample( + dist, dim=-1, temperature=self.sample_codebook_temp) embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) embed_ind = embed_ind.view(*shape[:-1]) quantize = F.embedding(embed_ind, self.embed) @@ -207,26 +225,28 @@ class EuclideanCodebook(nn.Module): self.all_reduce_fn(embed_sum) ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() + cluster_size = laplace_smoothing( + self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) self.embed.data.copy_(embed_normalized) self.expire_codes_(x) return quantize, embed_ind + class CosineSimCodebook(nn.Module): def __init__( self, dim, codebook_size, - kmeans_init = False, - kmeans_iters = 10, - decay = 0.8, - eps = 1e-5, - threshold_ema_dead_code = 2, - use_ddp = False, - learnable_codebook = False, - sample_codebook_temp = 0. + kmeans_init=False, + kmeans_iters=10, + decay=0.8, + eps=1e-5, + threshold_ema_dead_code=2, + use_ddp=False, + learnable_codebook=False, + sample_codebook_temp=0. ): super().__init__() self.decay = decay @@ -258,7 +278,7 @@ class CosineSimCodebook(nn.Module): return embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters, - use_cosine_sim = True) + use_cosine_sim=True) self.embed.data.copy_(embed) self.cluster_size.data.copy_(cluster_size) self.initted.data.copy_(torch.Tensor([True])) @@ -280,9 +300,9 @@ class CosineSimCodebook(nn.Module): if not torch.any(expired_codes): return batch_samples = rearrange(batch_samples, '... d -> (...) d') - self.replace(batch_samples, mask = expired_codes) + self.replace(batch_samples, mask=expired_codes) - @autocast(enabled = False) + @autocast(enabled=False) def forward(self, x, used_codes=[]): shape, dtype = x.shape, x.dtype flatten = rearrange(x, '... d -> (...) 
d') @@ -294,9 +314,11 @@ class CosineSimCodebook(nn.Module): dist = flatten @ embed.t() for uc in used_codes: - mask = torch.arange(0, self.codebook_size, device=x.device).unsqueeze(0).repeat(x.shape[0],1) == uc.unsqueeze(1) + mask = torch.arange(0, self.codebook_size, device=x.device).unsqueeze( + 0).repeat(x.shape[0], 1) == uc.unsqueeze(1) dist[mask] = -torch.inf - embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) + embed_ind = gumbel_sample( + dist, dim=-1, temperature=self.sample_codebook_temp) embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) embed_ind = embed_ind.view(*shape[:-1]) @@ -328,28 +350,29 @@ class CosineSimCodebook(nn.Module): # main class + class VectorQuantize(nn.Module): def __init__( self, dim, codebook_size, - n_embed = None, - codebook_dim = None, - decay = 0.8, - eps = 1e-5, - kmeans_init = False, - kmeans_iters = 10, - use_cosine_sim = False, - threshold_ema_dead_code = 0, - channel_last = True, - accept_image_fmap = False, - commitment_weight = None, - commitment = 1., # deprecate in next version, turn off by default - orthogonal_reg_weight = 0., - orthogonal_reg_active_codes_only = False, - orthogonal_reg_max_codes = None, - sample_codebook_temp = 0., - sync_codebook = False + n_embed=None, + codebook_dim=None, + decay=0.8, + eps=1e-5, + kmeans_init=False, + kmeans_iters=10, + use_cosine_sim=False, + threshold_ema_dead_code=0, + channel_last=True, + accept_image_fmap=False, + commitment_weight=None, + commitment=1., # deprecate in next version, turn off by default + orthogonal_reg_weight=0., + orthogonal_reg_active_codes_only=False, + orthogonal_reg_max_codes=None, + sample_codebook_temp=0., + sync_codebook=False ): super().__init__() n_embed = default(n_embed, codebook_size) @@ -357,9 +380,9 @@ class VectorQuantize(nn.Module): codebook_dim = default(codebook_dim, dim) requires_projection = codebook_dim != dim self.project_in = ml.Linear(dim, codebook_dim) if requires_projection \ - else nn.Identity() + else nn.Identity() self.project_out = ml.Linear(codebook_dim, dim) if requires_projection \ - else nn.Identity() + else nn.Identity() self.eps = eps self.commitment_weight = default(commitment_weight, commitment) @@ -370,19 +393,19 @@ class VectorQuantize(nn.Module): self.orthogonal_reg_max_codes = orthogonal_reg_max_codes codebook_class = EuclideanCodebook if not use_cosine_sim \ - else CosineSimCodebook + else CosineSimCodebook self._codebook = codebook_class( - dim = codebook_dim, - codebook_size = n_embed, - kmeans_init = kmeans_init, - kmeans_iters = kmeans_iters, - decay = decay, - eps = eps, - threshold_ema_dead_code = threshold_ema_dead_code, - use_ddp = sync_codebook, - learnable_codebook = has_codebook_orthogonal_loss, - sample_codebook_temp = sample_codebook_temp + dim=codebook_dim, + codebook_size=n_embed, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + eps=eps, + threshold_ema_dead_code=threshold_ema_dead_code, + use_ddp=sync_codebook, + learnable_codebook=has_codebook_orthogonal_loss, + sample_codebook_temp=sample_codebook_temp ) self.codebook_size = codebook_size @@ -410,7 +433,7 @@ class VectorQuantize(nn.Module): quantize, embed_ind = self._codebook(x, used_codes) - loss = torch.tensor([0.], device = device, requires_grad = self.training) + loss = torch.tensor([0.], device=device, requires_grad=self.training) if self.training: if self.commitment_weight > 0: @@ -427,7 +450,8 @@ class VectorQuantize(nn.Module): num_codes = codebook.shape[0] if 
exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: - rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes] + rand_ids = torch.randperm(num_codes, device=device)[ + :self.orthogonal_reg_max_codes] codebook = codebook[rand_ids] orthogonal_reg_loss = orthgonal_loss_fn(codebook) @@ -439,7 +463,9 @@ class VectorQuantize(nn.Module): quantize = rearrange(quantize, 'b n d -> b d n') if self.accept_image_fmap: - quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width) - embed_ind = rearrange(embed_ind, 'b (h w) -> b h w', h = height, w = width) + quantize = rearrange( + quantize, 'b (h w) c -> b c h w', h=height, w=width) + embed_ind = rearrange( + embed_ind, 'b (h w) -> b h w', h=height, w=width) return quantize, embed_ind, loss diff --git a/dlas/models/lucidrains/x_transformers.py b/dlas/models/lucidrains/x_transformers.py index a49af93d..cc16aa8e 100644 --- a/dlas/models/lucidrains/x_transformers.py +++ b/dlas/models/lucidrains/x_transformers.py @@ -1,17 +1,17 @@ import functools import math -import torch -from torch import nn, einsum -import torch.nn.functional as F +from collections import namedtuple from functools import partial from inspect import isfunction -from collections import namedtuple -from einops import rearrange, repeat, reduce +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat from einops.layers.torch import Rearrange - +from torch import einsum, nn from torch.utils.checkpoint import checkpoint -import torch_intermediary as ml + +import dlas.torch_intermediary as ml DEFAULT_DIM_HEAD = 64 @@ -108,8 +108,10 @@ def group_by_key_prefix(prefix, d): def groupby_prefix_and_trim(prefix, d): - kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_with_prefix, kwargs = group_dict_by_key( + partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict( + map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) return kwargs_without_prefix, kwargs @@ -143,7 +145,8 @@ class FixedPositionalEmbedding(nn.Module): self.register_buffer('inv_freq', inv_freq) def forward(self, x, seq_dim=1, offset=0): - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + t = torch.arange(x.shape[seq_dim], device=x.device).type_as( + self.inv_freq) + offset sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) return rearrange(emb, 'n d -> () n d') @@ -174,9 +177,11 @@ class RelativePositionBias(nn.Module): is_small = n < max_exact val_if_large = max_exact + ( - torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) ).long() - val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret @@ -214,7 +219,7 @@ class AlibiPositionalBias(nn.Module): closest_power_of_2 = 2 ** math.floor(math.log2(heads)) return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ - :heads - closest_power_of_2] + :heads - closest_power_of_2] def forward(self, qk_dots): h, i, 
j, device = *qk_dots.shape[-3:], qk_dots.device @@ -254,13 +259,15 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias): else: i_arange = torch.arange(i, device=device) j_arange = torch.arange(j, device=device) - bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1') + bias = rearrange(j_arange, 'j -> 1 1 1 j') - \ + rearrange(i_arange, 'i -> 1 1 i 1') self.register_buffer('bias', bias, persistent=False) if self.bidirectional: past_slopes = get_slopes(self.learned_logslopes) future_slopes = get_slopes(self.learned_logslopes_future) - bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes) + bias = torch.tril(bias * past_slopes) + \ + torch.triu(bias * future_slopes) else: slopes = get_slopes(self.learned_logslopes) bias = bias * slopes @@ -303,7 +310,7 @@ class Scale(nn.Module): def forward(self, x, **kwargs): out = self.fn(x, **kwargs) - scale_fn = lambda t: t * self.value + def scale_fn(t): return t * self.value if not isinstance(out, tuple): return scale_fn(out) @@ -319,7 +326,7 @@ class Rezero(nn.Module): def forward(self, x, **kwargs): out = self.fn(x, **kwargs) - rezero_fn = lambda t: t * self.g + def rezero_fn(t): return t * self.g if not isinstance(out, tuple): return rezero_fn(out) @@ -359,7 +366,8 @@ class RMSScaleShiftNorm(nn.Module): self.eps = eps self.g = nn.Parameter(torch.ones(dim)) if conv_ch_order: - self.scale_shift_process = nn.Conv1d(embed_dim, dim*2, kernel_size=1, bias=bias) + self.scale_shift_process = nn.Conv1d( + embed_dim, dim*2, kernel_size=1, bias=bias) self.cdim = 1 self.pdim = -1 else: @@ -385,7 +393,8 @@ class RMSScaleShiftNorm(nn.Module): class Residual(nn.Module): def __init__(self, dim, scale_residual=False, mask_residual=False): super().__init__() - self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + self.residual_scale = nn.Parameter( + torch.ones(dim)) if scale_residual else None if mask_residual: self.residual_scale.data.zero_() @@ -400,7 +409,8 @@ class GRUGating(nn.Module): def __init__(self, dim, scale_residual=False): super().__init__() self.gru = nn.GRUCell(dim, dim) - self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + self.residual_scale = nn.Parameter( + torch.ones(dim)) if scale_residual else None def forward(self, x, residual): if exists(self.residual_scale): @@ -439,7 +449,8 @@ class ShiftTokens(nn.Module): feats_per_shift = x.shape[-1] // segments splitted = x.split(feats_per_shift, dim=-1) segments_to_shift, rest = splitted[:segments], splitted[segments:] - segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))) + segments_to_shift = list(map(lambda args: shift( + *args, mask=mask), zip(segments_to_shift, shifts))) x = torch.cat((*segments_to_shift, *rest), dim=-1) return self.fn(x, **kwargs) @@ -556,7 +567,8 @@ class Attention(nn.Module): if qk_norm: scale_init_value = default(scale_init_value, -3) # if not provided, initialize as though it were sequence length of 1024 - self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value) + self.scale = nn.Parameter(torch.ones( + 1, heads, 1, 1) * scale_init_value) # talking heads self.talking_heads = talking_heads @@ -584,7 +596,8 @@ class Attention(nn.Module): # attention on attention self.attn_on_attn = on_attn out_dim = default(out_dim, dim) - self.to_out = nn.Sequential(ml.Linear(v_dim, out_dim * 2), nn.GLU()) if on_attn else ml.Linear(v_dim, out_dim) + self.to_out = nn.Sequential(ml.Linear( + v_dim, out_dim * 2), nn.GLU()) if 
on_attn else ml.Linear(v_dim, out_dim) self.rel_pos_bias = rel_pos_bias if rel_pos_bias: @@ -632,7 +645,8 @@ class Attention(nn.Module): v = self.to_v(v_input) if not collab_heads: - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange( + t, 'b n (h d) -> b h n d', h=h), (q, k, v)) else: q = einsum('b i d, h d -> b h i d', q, self.collab_mixing) k = rearrange(k, 'b n d -> b () n d') @@ -647,25 +661,32 @@ class Attention(nn.Module): if exists(rotary_pos_emb) and not has_context: l = rotary_pos_emb.shape[-1] - (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) - ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) - q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + (ql, qr), (kl, kr), (vl, vr) = map( + lambda t: (t[..., :l], t[..., l:]), (q, k, v)) + ql, kl, vl = map(lambda t: apply_rotary_pos_emb( + t, rotary_pos_emb), (ql, kl, vl)) + q, k, v = map(lambda t: torch.cat(t, dim=-1), + ((ql, qr), (kl, kr), (vl, vr))) input_mask = None if any(map(exists, (mask, context_mask))): - q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + q_mask = default(mask, lambda: torch.ones( + (b, n), device=device).bool()) k_mask = q_mask if not exists(context) else context_mask - k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + k_mask = default(k_mask, lambda: torch.ones( + (b, k.shape[-2]), device=device).bool()) q_mask = rearrange(q_mask, 'b i -> b () i ()') k_mask = rearrange(k_mask, 'b j -> b () () j') input_mask = q_mask * k_mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + mem_k, mem_v = map(lambda t: repeat( + t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) k = torch.cat((mem_k, k), dim=-2) v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): - input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + input_mask = F.pad( + input_mask, (self.num_mem_kv, 0), value=True) if collab_heads: k = k.expand(-1, h, -1, -1) @@ -683,7 +704,8 @@ class Attention(nn.Module): pre_softmax_attn = dots.clone() if talking_heads: - dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + dots = einsum('b h i j, h k -> b k i j', dots, + self.pre_softmax_proj).contiguous() if self.rel_pos_bias: dots = self.rel_pos(dots) @@ -704,7 +726,8 @@ class Attention(nn.Module): i, j = dots.shape[-2:] range_q = torch.arange(j - i, j, device=device) range_k = torch.arange(j, device=device) - dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j') + dist = rearrange(range_q, 'i -> () () i ()') - \ + rearrange(range_k, 'j -> () () () j') mask = dist > self.max_attend_past dots.masked_fill_(mask, mask_value) del mask @@ -712,7 +735,8 @@ class Attention(nn.Module): if self.causal: i, j = dots.shape[-2:] r = torch.arange(i, device=device) - mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = rearrange( + r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') mask = F.pad(mask, (j - i, 0), value=False) dots.masked_fill_(mask, mask_value) del mask @@ -730,7 +754,8 @@ class Attention(nn.Module): attn = self.dropout(attn) if talking_heads: - attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + attn = einsum('b h i j, h k -> b k i j', attn, + self.post_softmax_proj).contiguous() out = einsum('b h i j, b h j d -> b h i d', attn, v) 
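For orientation, the rewrapped einsum/rearrange calls in this Attention.forward hunk are the standard multi-head attention shape pattern; the hunk only rewraps the calls, the computation is unchanged. A minimal self-contained sketch of just that pattern (plain PyTorch + einops, omitting the rotary, masking, memory key/value and talking-heads branches handled above; tiny_attention is illustrative and not a function in this file):

import torch
from einops import rearrange
from torch import einsum, nn

def tiny_attention(x, to_q, to_k, to_v, heads):
    # x: (batch, seq, dim); to_q/to_k/to_v: Linear(dim, heads * dim_head)
    q, k, v = to_q(x), to_k(x), to_v(x)
    # split heads: (b, n, h*d) -> (b, h, n, d)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=heads), (q, k, v))
    # attention logits and weighted sum of values, same einsum strings as above
    dots = einsum('b h i d, b h j d -> b h i j', q, k) * (q.shape[-1] ** -0.5)
    attn = dots.softmax(dim=-1)
    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    # merge heads back: (b, h, n, d) -> (b, n, h*d), ready for the output projection
    return rearrange(out, 'b h n d -> b n (h d)')

dim, heads, dim_head = 64, 4, 16
proj = [nn.Linear(dim, heads * dim_head) for _ in range(3)]
y = tiny_attention(torch.randn(2, 10, dim), *proj, heads=heads)  # -> (2, 10, 64)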
@@ -801,23 +826,27 @@ class AttentionLayers(nn.Module): rel_pos_bias = 'rel_pos_bias' in attn_kwargs self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb - self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.pia_pos_emb = FixedPositionalEmbedding( + dim) if position_infused_attn else None rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) - self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None + self.rotary_pos_emb = RotaryEmbedding( + rotary_emb_dim) if rotary_pos_emb else None assert not ( - alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' + alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' if alibi_pos_bias: alibi_num_heads = default(alibi_num_heads, heads) assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias - self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) + self.rel_pos = alibi_pos_klass( + heads=alibi_num_heads, bidirectional=not causal) else: self.rel_pos = None - assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' + assert not ( + not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' self.pre_norm = pre_norm self.sandwich_norm = sandwich_norm @@ -848,7 +877,8 @@ class AttentionLayers(nn.Module): if use_qk_norm_attn: attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists( qk_norm_attn_seq_len) else None - attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value} + attn_kwargs = {**attn_kwargs, 'qk_norm': True, + 'scale_init_value': attn_scale_init_value} # zero init @@ -865,15 +895,19 @@ class AttentionLayers(nn.Module): assert 1 < par_ratio <= par_depth, 'par ratio out of range' default_block = tuple(filter(not_equals('f'), default_block)) par_attn = par_depth // par_ratio - depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + # 2 / 3 attention layer cutoff suggested by PAR paper + depth_cut = par_depth * 2 // 3 par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * (par_width - len(default_block)) + assert len( + default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + \ + ('f',) * (par_width - len(default_block)) par_head = par_block * par_attn layer_types = par_head + ('f',) * (par_depth - len(par_head)) elif exists(sandwich_coef): assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + layer_types = ('a',) * sandwich_coef + default_block * \ + (depth - sandwich_coef) + ('f',) * sandwich_coef else: layer_types = default_block * depth @@ -890,7 +924,8 @@ class AttentionLayers(nn.Module): is_last_layer = ind == (len(self.layer_types) - 1) if layer_type == 'a': - layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + layer = Attention(dim, heads=heads, + causal=causal, **attn_kwargs) elif layer_type == 'c': layer = Attention(dim, 
heads=heads, **attn_kwargs) elif layer_type == 'f': @@ -902,7 +937,8 @@ class AttentionLayers(nn.Module): if layer_shift_tokens > 0: shift_range_upper = layer_shift_tokens + 1 shift_range_lower = -layer_shift_tokens if not causal else 0 - layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) + layer = ShiftTokens( + range(shift_range_lower, shift_range_upper), layer) if exists(branch_fn): layer = branch_fn(layer) @@ -966,8 +1002,10 @@ class AttentionLayers(nn.Module): seq_len = x.shape[1] if past_key_values is not None: seq_len += past_key_values[0][0].shape[-2] - max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]) - rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) + max_rotary_emb_length = max(list(map(lambda m: ( + m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]) + rotary_pos_emb = self.rotary_pos_emb( + max_rotary_emb_length, x.device) present_key_values = [] cross_attn_count = 0 @@ -995,13 +1033,14 @@ class AttentionLayers(nn.Module): if layer_type == 'a': out, inter, k, v = chkpt_fn(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, - prev_attn, layer_mem, layer_past) + prev_attn, layer_mem, layer_past) elif layer_type == 'c': if exists(full_context): out, inter, k, v = chkpt_fn(block, x, full_context[cross_attn_count], mask, context_mask, None, None, - None, prev_attn, None, layer_past) + None, prev_attn, None, layer_past) else: - out, inter, k, v = chkpt_fn(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past) + out, inter, k, v = chkpt_fn( + block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past) elif layer_type == 'f': out = chkpt_fn(block, x) @@ -1071,7 +1110,8 @@ class ViTransformerWrapper(nn.Module): emb_dropout=0. 
): super().__init__() - assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' + assert isinstance( + attn_layers, Encoder), 'attention layers must be an Encoder' assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' dim = attn_layers.dim num_patches = (image_size // patch_size) ** 2 @@ -1086,7 +1126,8 @@ class ViTransformerWrapper(nn.Module): self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) - self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None + self.mlp_head = FeedForward( + dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None def forward( self, @@ -1095,7 +1136,8 @@ class ViTransformerWrapper(nn.Module): ): p = self.patch_size - x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) + x = rearrange( + img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) x = self.patch_to_embedding(x) b, n, _ = x.shape @@ -1129,7 +1171,8 @@ class TransformerWrapper(nn.Module): use_pos_emb=True ): super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + assert isinstance( + attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' dim = attn_layers.dim emb_dim = default(emb_dim, dim) @@ -1141,22 +1184,25 @@ class TransformerWrapper(nn.Module): # nn.Embedding self.token_emb = ml.Embedding(num_tokens, emb_dim) self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) + use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) - self.project_emb = ml.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.project_emb = ml.Linear( + emb_dim, dim) if emb_dim != dim else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.init_() - self.to_logits = ml.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + self.to_logits = ml.Linear( + dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = default(num_memory_tokens, 0) self.num_memory_tokens = num_memory_tokens if num_memory_tokens > 0: - self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + self.memory_tokens = nn.Parameter( + torch.randn(num_memory_tokens, dim)) def init_(self): nn.init.kaiming_normal_(self.token_emb.weight) @@ -1191,7 +1237,8 @@ class TransformerWrapper(nn.Module): mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] mems = [*mems_r, *mems_l] - x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x, intermediates = self.attn_layers( + x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x = self.norm(x) mem, x = x[:, :num_mem], x[:, num_mem:] @@ -1204,7 +1251,8 @@ class TransformerWrapper(nn.Module): res = [out] if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = list(map(lambda t: t.post_softmax_attn, + intermediates.attn_intermediates)) res.append(attn_maps) if use_cache: res.append(intermediates.past_key_values) @@ -1227,22 +1275,25 @@ class ContinuousTransformerWrapper(nn.Module): use_pos_emb=True ): super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + assert 
isinstance( + attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' dim = attn_layers.dim self.max_seq_len = max_seq_len self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) + use_pos_emb and not attn_layers.has_pos_emb) else always(0) self.emb_dropout = nn.Dropout(emb_dropout) - self.project_in = ml.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() + self.project_in = ml.Linear(dim_in, dim) if exists( + dim_in) else nn.Identity() self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) - self.project_out = ml.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() + self.project_out = ml.Linear(dim, dim_out) if exists( + dim_out) else nn.Identity() def forward( self, @@ -1260,14 +1311,16 @@ class ContinuousTransformerWrapper(nn.Module): x = x + self.pos_emb(x) x = self.emb_dropout(x) - x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x, intermediates = self.attn_layers( + x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x = self.norm(x) out = self.project_out(x) if not return_embeddings else x res = [out] if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = list(map(lambda t: t.post_softmax_attn, + intermediates.attn_intermediates)) res.append(attn_maps) if use_cache: res.append(intermediates.past_key_values) diff --git a/dlas/models/optical_flow/PWCNet.py b/dlas/models/optical_flow/PWCNet.py index b5467922..fff40319 100644 --- a/dlas/models/optical_flow/PWCNet.py +++ b/dlas/models/optical_flow/PWCNet.py @@ -5,133 +5,142 @@ Jinwei Gu and Zhile Ren """ +import os + +# from spatial_correlation_sampler import spatial_correlation_sample +import numpy as np import torch import torch.nn as nn from torch.autograd import Variable -import os -#from spatial_correlation_sampler import spatial_correlation_sample -import numpy as np - - - - __all__ = [ 'pwc_dc_net', 'pwc_dc_net_old' - ] +] -from models.flownet2.networks.correlation_package.correlation import CorrelationFunction, Correlation +from models.flownet2.networks.correlation_package.correlation import ( + Correlation, CorrelationFunction) from models.flownet2.networks.resample2d_package.resample2d import Resample2d from trainer.networks import register_model def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): return nn.Sequential( - nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, bias=True), - nn.LeakyReLU(0.1)) + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.1)) + def predict_flow(in_planes): - return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) + return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=True) + def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True) + class PWCDCNet(nn.Module): """ PWC-DC net. add dilation convolution and densenet connections """ + def __init__(self, md=4, pretrained=False): """ input: md --- maximum displacement (for correlation. 
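The md argument documented above is the search radius of the cost-volume layer: the Correlation op imported above compares each feature vector of the first frame with every feature vector of the second frame within plus or minus md pixels, producing (2*md+1)**2 output channels (the nd computed below). A naive reference sketch of that operation, ignoring the CUDA kernel's exact normalization (naive_correlation is illustrative and not part of this file):

import torch
import torch.nn.functional as F

def naive_correlation(f1, f2, md=4):
    # f1, f2: (B, C, H, W) feature maps from the two frames.
    # Returns a (B, (2*md+1)**2, H, W) cost volume: one channel per
    # candidate displacement within +/- md pixels.
    B, C, H, W = f1.shape
    f2 = F.pad(f2, (md, md, md, md))
    vols = []
    for dy in range(2 * md + 1):
        for dx in range(2 * md + 1):
            shifted = f2[:, :, dy:dy + H, dx:dx + W]
            # dot product over the channel dimension at every spatial location
            vols.append((f1 * shifted).sum(dim=1, keepdim=True))
    return torch.cat(vols, dim=1)

out = naive_correlation(torch.randn(1, 16, 32, 32), torch.randn(1, 16, 32, 32))
assert out.shape == (1, 81, 32, 32)  # (2*4+1)**2 = 81 channels, matching nd for md=4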
default: 4), after warpping """ - super(PWCDCNet,self).__init__() + super(PWCDCNet, self).__init__() - self.upsample = nn.Upsample(scale_factor=4, mode='bilinear') + self.upsample = nn.Upsample(scale_factor=4, mode='bilinear') - self.conv1a = conv(3, 16, kernel_size=3, stride=2) + self.conv1a = conv(3, 16, kernel_size=3, stride=2) self.conv1aa = conv(16, 16, kernel_size=3, stride=1) - self.conv1b = conv(16, 16, kernel_size=3, stride=1) - self.conv2a = conv(16, 32, kernel_size=3, stride=2) + self.conv1b = conv(16, 16, kernel_size=3, stride=1) + self.conv2a = conv(16, 32, kernel_size=3, stride=2) self.conv2aa = conv(32, 32, kernel_size=3, stride=1) - self.conv2b = conv(32, 32, kernel_size=3, stride=1) - self.conv3a = conv(32, 64, kernel_size=3, stride=2) + self.conv2b = conv(32, 32, kernel_size=3, stride=1) + self.conv3a = conv(32, 64, kernel_size=3, stride=2) self.conv3aa = conv(64, 64, kernel_size=3, stride=1) - self.conv3b = conv(64, 64, kernel_size=3, stride=1) - self.conv4a = conv(64, 96, kernel_size=3, stride=2) + self.conv3b = conv(64, 64, kernel_size=3, stride=1) + self.conv4a = conv(64, 96, kernel_size=3, stride=2) self.conv4aa = conv(96, 96, kernel_size=3, stride=1) - self.conv4b = conv(96, 96, kernel_size=3, stride=1) - self.conv5a = conv(96, 128, kernel_size=3, stride=2) - self.conv5aa = conv(128,128, kernel_size=3, stride=1) - self.conv5b = conv(128,128, kernel_size=3, stride=1) - self.conv6aa = conv(128,196, kernel_size=3, stride=2) - self.conv6a = conv(196,196, kernel_size=3, stride=1) - self.conv6b = conv(196,196, kernel_size=3, stride=1) + self.conv4b = conv(96, 96, kernel_size=3, stride=1) + self.conv5a = conv(96, 128, kernel_size=3, stride=2) + self.conv5aa = conv(128, 128, kernel_size=3, stride=1) + self.conv5b = conv(128, 128, kernel_size=3, stride=1) + self.conv6aa = conv(128, 196, kernel_size=3, stride=2) + self.conv6a = conv(196, 196, kernel_size=3, stride=1) + self.conv6b = conv(196, 196, kernel_size=3, stride=1) - #self.corr = Correlation(padding=md, kernel_size=1, patch_size=md, stride=1) - self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) + # self.corr = Correlation(padding=md, kernel_size=1, patch_size=md, stride=1) + self.corr = Correlation( + pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) self.leakyRELU = nn.LeakyReLU(0.1) nd = (2*md+1)**2 - dd = np.cumsum([128,128,96,64,32]) + dd = np.cumsum([128, 128, 96, 64, 32]) od = nd self.conv6_0 = conv(od, 128, kernel_size=3, stride=1) - self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1) - self.conv6_2 = conv(od+dd[1],96, kernel_size=3, stride=1) - self.conv6_3 = conv(od+dd[2],64, kernel_size=3, stride=1) - self.conv6_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.conv6_1 = conv(od+dd[0], 128, kernel_size=3, stride=1) + self.conv6_2 = conv(od+dd[1], 96, kernel_size=3, stride=1) + self.conv6_3 = conv(od+dd[2], 64, kernel_size=3, stride=1) + self.conv6_4 = conv(od+dd[3], 32, kernel_size=3, stride=1) self.predict_flow6 = predict_flow(od+dd[4]) self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) od = nd+128+4 self.conv5_0 = conv(od, 128, kernel_size=3, stride=1) - self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1) - self.conv5_2 = conv(od+dd[1],96, kernel_size=3, stride=1) - self.conv5_3 = conv(od+dd[2],64, kernel_size=3, stride=1) - self.conv5_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.conv5_1 = 
conv(od+dd[0], 128, kernel_size=3, stride=1) + self.conv5_2 = conv(od+dd[1], 96, kernel_size=3, stride=1) + self.conv5_3 = conv(od+dd[2], 64, kernel_size=3, stride=1) + self.conv5_4 = conv(od+dd[3], 32, kernel_size=3, stride=1) self.predict_flow5 = predict_flow(od+dd[4]) self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) od = nd+96+4 self.conv4_0 = conv(od, 128, kernel_size=3, stride=1) - self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1) - self.conv4_2 = conv(od+dd[1],96, kernel_size=3, stride=1) - self.conv4_3 = conv(od+dd[2],64, kernel_size=3, stride=1) - self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.conv4_1 = conv(od+dd[0], 128, kernel_size=3, stride=1) + self.conv4_2 = conv(od+dd[1], 96, kernel_size=3, stride=1) + self.conv4_3 = conv(od+dd[2], 64, kernel_size=3, stride=1) + self.conv4_4 = conv(od+dd[3], 32, kernel_size=3, stride=1) self.predict_flow4 = predict_flow(od+dd[4]) self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) od = nd+64+4 self.conv3_0 = conv(od, 128, kernel_size=3, stride=1) - self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1) - self.conv3_2 = conv(od+dd[1],96, kernel_size=3, stride=1) - self.conv3_3 = conv(od+dd[2],64, kernel_size=3, stride=1) - self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.conv3_1 = conv(od+dd[0], 128, kernel_size=3, stride=1) + self.conv3_2 = conv(od+dd[1], 96, kernel_size=3, stride=1) + self.conv3_3 = conv(od+dd[2], 64, kernel_size=3, stride=1) + self.conv3_4 = conv(od+dd[3], 32, kernel_size=3, stride=1) self.predict_flow3 = predict_flow(od+dd[4]) self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) od = nd+32+4 self.conv2_0 = conv(od, 128, kernel_size=3, stride=1) - self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1) - self.conv2_2 = conv(od+dd[1],96, kernel_size=3, stride=1) - self.conv2_3 = conv(od+dd[2],64, kernel_size=3, stride=1) - self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.conv2_1 = conv(od+dd[0], 128, kernel_size=3, stride=1) + self.conv2_2 = conv(od+dd[1], 96, kernel_size=3, stride=1) + self.conv2_3 = conv(od+dd[2], 64, kernel_size=3, stride=1) + self.conv2_4 = conv(od+dd[3], 32, kernel_size=3, stride=1) self.predict_flow2 = predict_flow(od+dd[4]) self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) - self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1) - self.dc_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) - self.dc_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) - self.dc_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) - self.dc_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) - self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv1 = conv( + od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv2 = conv(128, 128, kernel_size=3, + stride=1, padding=2, dilation=2) + self.dc_conv3 = conv(128, 128, kernel_size=3, + stride=1, padding=4, dilation=4) + self.dc_conv4 = conv(128, 96, kernel_size=3, + stride=1, padding=8, dilation=8) + self.dc_conv5 = conv(96, 64, kernel_size=3, + stride=1, padding=16, dilation=16) + self.dc_conv6 = conv(64, 32, kernel_size=3, + stride=1, padding=1, dilation=1) self.dc_conv7 = predict_flow(32) 
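As a side note on the estimator widths declared above (an illustrative sketch, not part of the patch): with the default md=4 the correlation layer yields nd = (2*md+1)**2 = 81 channels, and because every conv output is concatenated back onto its input (the DenseNet-style connections the class docstring mentions), the input widths grow by the cumulative sums in dd. At the lower pyramid levels, od additionally includes that level's feature channels plus 4 channels of upsampled flow and features (od = nd + C + 4).

    import numpy as np

    md = 4
    nd = (2 * md + 1) ** 2                    # 81 correlation channels
    dd = np.cumsum([128, 128, 96, 64, 32])    # [128, 256, 352, 416, 448]

    od = nd                                   # level 6 starts from the cost volume alone
    widths = [od] + [od + int(d) for d in dd[:-1]]
    print(widths)                             # [81, 209, 337, 433, 497] -> inputs to conv6_0 .. conv6_4
    print(od + int(dd[-1]))                   # 529 -> input width of predict_flow6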
for m in self.modules(): @@ -190,8 +199,8 @@ class PWCDCNet(nn.Module): return [param for name, param in self.named_parameters() if 'bias' in name] def forward(self, x): - im1 = x[:,:3,:,:] - im2 = x[:,3:,:,:] + im1 = x[:, :3, :, :] + im2 = x[:, 3:, :, :] c11 = self.conv1b(self.conv1aa(self.conv1a(im1))) c21 = self.conv1b(self.conv1aa(self.conv1a(im2))) @@ -206,74 +215,67 @@ class PWCDCNet(nn.Module): c16 = self.conv6b(self.conv6a(self.conv6aa(c15))) c26 = self.conv6b(self.conv6a(self.conv6aa(c25))) - corr6 = self.corr(c16, c26) corr6 = self.leakyRELU(corr6) - - x = torch.cat((self.conv6_0(corr6), corr6),1) - x = torch.cat((self.conv6_1(x), x),1) - x = torch.cat((self.conv6_2(x), x),1) - x = torch.cat((self.conv6_3(x), x),1) - x = torch.cat((self.conv6_4(x), x),1) + x = torch.cat((self.conv6_0(corr6), corr6), 1) + x = torch.cat((self.conv6_1(x), x), 1) + x = torch.cat((self.conv6_2(x), x), 1) + x = torch.cat((self.conv6_3(x), x), 1) + x = torch.cat((self.conv6_4(x), x), 1) flow6 = self.predict_flow6(x) up_flow6 = self.deconv6(flow6) up_feat6 = self.upfeat6(x) - warp5 = self.warp(c25, up_flow6*0.625) corr5 = self.corr(c15, warp5) corr5 = self.leakyRELU(corr5) x = torch.cat((corr5, c15, up_flow6, up_feat6), 1) - x = torch.cat((self.conv5_0(x), x),1) - x = torch.cat((self.conv5_1(x), x),1) - x = torch.cat((self.conv5_2(x), x),1) - x = torch.cat((self.conv5_3(x), x),1) - x = torch.cat((self.conv5_4(x), x),1) + x = torch.cat((self.conv5_0(x), x), 1) + x = torch.cat((self.conv5_1(x), x), 1) + x = torch.cat((self.conv5_2(x), x), 1) + x = torch.cat((self.conv5_3(x), x), 1) + x = torch.cat((self.conv5_4(x), x), 1) flow5 = self.predict_flow5(x) up_flow5 = self.deconv5(flow5) up_feat5 = self.upfeat5(x) - warp4 = self.warp(c24, up_flow5*1.25) corr4 = self.corr(c14, warp4) corr4 = self.leakyRELU(corr4) x = torch.cat((corr4, c14, up_flow5, up_feat5), 1) - x = torch.cat((self.conv4_0(x), x),1) - x = torch.cat((self.conv4_1(x), x),1) - x = torch.cat((self.conv4_2(x), x),1) - x = torch.cat((self.conv4_3(x), x),1) - x = torch.cat((self.conv4_4(x), x),1) + x = torch.cat((self.conv4_0(x), x), 1) + x = torch.cat((self.conv4_1(x), x), 1) + x = torch.cat((self.conv4_2(x), x), 1) + x = torch.cat((self.conv4_3(x), x), 1) + x = torch.cat((self.conv4_4(x), x), 1) flow4 = self.predict_flow4(x) up_flow4 = self.deconv4(flow4) up_feat4 = self.upfeat4(x) - warp3 = self.warp(c23, up_flow4*2.5) corr3 = self.corr(c13, warp3) corr3 = self.leakyRELU(corr3) - x = torch.cat((corr3, c13, up_flow4, up_feat4), 1) - x = torch.cat((self.conv3_0(x), x),1) - x = torch.cat((self.conv3_1(x), x),1) - x = torch.cat((self.conv3_2(x), x),1) - x = torch.cat((self.conv3_3(x), x),1) - x = torch.cat((self.conv3_4(x), x),1) + x = torch.cat((self.conv3_0(x), x), 1) + x = torch.cat((self.conv3_1(x), x), 1) + x = torch.cat((self.conv3_2(x), x), 1) + x = torch.cat((self.conv3_3(x), x), 1) + x = torch.cat((self.conv3_4(x), x), 1) flow3 = self.predict_flow3(x) up_flow3 = self.deconv3(flow3) up_feat3 = self.upfeat3(x) - warp2 = self.warp(c22, up_flow3*5.0) corr2 = self.corr(c12, warp2) corr2 = self.leakyRELU(corr2) x = torch.cat((corr2, c12, up_flow3, up_feat3), 1) - x = torch.cat((self.conv2_0(x), x),1) - x = torch.cat((self.conv2_1(x), x),1) - x = torch.cat((self.conv2_2(x), x),1) - x = torch.cat((self.conv2_3(x), x),1) - x = torch.cat((self.conv2_4(x), x),1) + x = torch.cat((self.conv2_0(x), x), 1) + x = torch.cat((self.conv2_1(x), x), 1) + x = torch.cat((self.conv2_2(x), x), 1) + x = torch.cat((self.conv2_3(x), x), 1) + x = 
torch.cat((self.conv2_4(x), x), 1) flow2 = self.predict_flow2(x) x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x)))) @@ -286,10 +288,11 @@ class PWCDCNet(nn.Module): # flow6 = 20*4*self.upsample(flow6) if self.training: - return flow2,flow3,flow4,flow5,flow6 + return flow2, flow3, flow4, flow5, flow6 else: return flow2 + def pwc(data=None): model = PWCDCNet() @@ -300,6 +303,7 @@ def pwc(data=None): model.load_state_dict(data) return model + def pwc_dc_net(path=None): model = PWCDCNet() @@ -319,6 +323,6 @@ def register_pwc_humanflow(opt_net, opt): if __name__ == '__main__': pwc = pwc_dc_net('../../../experiments/pwc_humanflow.pth') - t = torch.randn(1,6,64,64) + t = torch.randn(1, 6, 64, 64) out = pwc(t) print(out.shape) diff --git a/dlas/models/vqvae/dvae.py b/dlas/models/vqvae/dvae.py index 5a0ba08f..f01065e9 100644 --- a/dlas/models/vqvae/dvae.py +++ b/dlas/models/vqvae/dvae.py @@ -1,17 +1,14 @@ import functools -import math from math import sqrt import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from torch import einsum -from models.vqvae.vector_quantizer import VectorQuantize -from models.vqvae.vqvae import Quantize -from trainer.networks import register_model -from utils.util import opt_get +from dlas.models.vqvae.vector_quantizer import VectorQuantize +from dlas.trainer.networks import register_model +from dlas.utils.util import opt_get def eval_decorator(fn): @@ -28,9 +25,9 @@ class ResBlock(nn.Module): def __init__(self, chan, conv, activation): super().__init__() self.net = nn.Sequential( - conv(chan, chan, 3, padding = 1), + conv(chan, chan, 3, padding=1), activation(), - conv(chan, chan, 3, padding = 1), + conv(chan, chan, 3, padding=1), activation(), conv(chan, chan, 1) ) @@ -48,7 +45,8 @@ class UpsampledConv(nn.Module): self.conv = conv(*args, **kwargs) def forward(self, x): - up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest') + up = nn.functional.interpolate( + x, scale_factor=self.stride, mode='nearest') return self.conv(up) @@ -56,18 +54,18 @@ class DiscreteVAE(nn.Module): def __init__( self, positional_dims=2, - num_tokens = 512, - codebook_dim = 512, - num_layers = 3, - num_resnet_blocks = 0, - hidden_dim = 64, - channels = 3, - stride = 2, - kernel_size = 3, - activation = 'relu', - straight_through = False, - record_codes = False, - discretization_loss_averaging_steps = 100, + num_tokens=512, + codebook_dim=512, + num_layers=3, + num_resnet_blocks=0, + hidden_dim=64, + channels=3, + stride=2, + kernel_size=3, + activation='relu', + straight_through=False, + record_codes=False, + discretization_loss_averaging_steps=100, quantizer_use_cosine_sim=True, quantizer_codebook_misses_to_expiration=40, quantizer_codebook_embedding_compression=None, @@ -81,7 +79,8 @@ class DiscreteVAE(nn.Module): self.straight_through = straight_through self.positional_dims = positional_dims - assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now. + # This VAE only supports 1d and 2d inputs for now. 
+ assert positional_dims > 0 and positional_dims < 3 if positional_dims == 2: conv = nn.Conv2d conv_transpose = functools.partial(UpsampledConv, conv) @@ -96,7 +95,6 @@ class DiscreteVAE(nn.Module): else: assert NotImplementedError() - enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)] dec_chans = list(reversed(enc_chans)) @@ -105,15 +103,18 @@ class DiscreteVAE(nn.Module): dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] dec_chans = [dec_init_chan, *dec_chans] - enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) + enc_chans_io, dec_chans_io = map(lambda t: list( + zip(t[:-1], t[1:])), (enc_chans, dec_chans)) enc_layers = [] dec_layers = [] pad = (kernel_size - 1) // 2 for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): - enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act())) - dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act())) + enc_layers.append(nn.Sequential( + conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act())) + dec_layers.append(nn.Sequential(conv_transpose( + dec_in, dec_out, kernel_size, stride=stride, padding=pad), act())) for _ in range(num_resnet_blocks): dec_layers.insert(0, ResBlock(dec_chans[1], conv, act)) @@ -149,7 +150,8 @@ class DiscreteVAE(nn.Module): @torch.no_grad() @eval_decorator def get_codebook_indices(self, images): - logits = self.encoder(images).permute((0,2,3,1) if len(images.shape) == 4 else (0,2,1)) + logits = self.encoder(images).permute( + (0, 2, 3, 1) if len(images.shape) == 4 else (0, 2, 1)) sampled, codes, commitment_loss = self.quantizer(logits) return codes @@ -175,7 +177,8 @@ class DiscreteVAE(nn.Module): return images[-1], images[-2] def infer(self, img): - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + logits = self.encoder(img).permute( + (0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) sampled, codes, commitment_loss = self.quantizer(logits) return self.decode(codes) @@ -186,9 +189,11 @@ class DiscreteVAE(nn.Module): self, img ): - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + logits = self.encoder(img).permute( + (0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) sampled, codes, commitment_loss = self.quantizer(logits) - sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1)) + sampled = sampled.permute( + (0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1)) if self.training: out = sampled @@ -211,7 +216,8 @@ class DiscreteVAE(nn.Module): if self.record_codes and self.internal_step % 50 == 0: codes = codes.flatten() l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + i = self.code_ind if ( + self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l self.codes[i:i+l] = codes.cpu() self.code_ind = self.code_ind + l if self.code_ind >= self.codes.shape[0]: @@ -225,15 +231,15 @@ def register_dvae(opt_net, opt): if __name__ == '__main__': - #v = DiscreteVAE() - #o=v(torch.randn(1,3,256,256)) - #print(o.shape) + # v = DiscreteVAE() + # o=v(torch.randn(1,3,256,256)) + # print(o.shape) v = DiscreteVAE(channels=80, positional_dims=1, num_tokens=4096, codebook_dim=1024, hidden_dim=512, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, quantizer_codebook_embedding_compression=64) - #v.eval() - loss, commitment, out = v(torch.randn(1,80,256)) + # 
v.eval() + loss, commitment, out = v(torch.randn(1, 80, 256)) print(out.shape) - codes = v.get_codebook_indices(torch.randn(1,80,256)) + codes = v.get_codebook_indices(torch.randn(1, 80, 256)) back, back_emb = v.decode(codes) print(back.shape) diff --git a/dlas/models/vqvae/gumbel_quantizer.py b/dlas/models/vqvae/gumbel_quantizer.py index e1d95f26..43a9f1d7 100644 --- a/dlas/models/vqvae/gumbel_quantizer.py +++ b/dlas/models/vqvae/gumbel_quantizer.py @@ -3,8 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from torch import einsum -from utils.weight_scheduler import LinearDecayWeightScheduler -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.utils.weight_scheduler import LinearDecayWeightScheduler class GumbelQuantizer(nn.Module): @@ -14,7 +14,8 @@ class GumbelQuantizer(nn.Module): # nn.Embedding self.codebook = ml.Embedding(num_tokens, codebook_dim) self.straight_through = straight_through - self.temperature_scheduler = LinearDecayWeightScheduler(10, 5000, .9, 2000) + self.temperature_scheduler = LinearDecayWeightScheduler( + 10, 5000, .9, 2000) self.step = 0 self.norm = SwitchNorm(num_tokens) @@ -33,27 +34,30 @@ class GumbelQuantizer(nn.Module): if hard: index = y_soft.max(dim, keepdim=True)[1] - y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) + y_hard = torch.zeros_like( + logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) ret = y_hard - y_soft.detach() + y_soft else: ret = y_soft return ret def forward(self, h): - h = h.permute(0,2,1) + h = h.permute(0, 2, 1) logits = self.to_logits(h) - logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step(self.step), dim=1, hard=self.straight_through) + logits = self.gumbel_softmax(logits, tau=self.temperature_scheduler.get_weight_for_step( + self.step), dim=1, hard=self.straight_through) logits = self.norm(logits) codes = logits.argmax(dim=1).flatten(1) sampled = einsum('b n l, n d -> b d l', logits, self.codebook.weight) - return sampled.permute(0,2,1), 0, codes + return sampled.permute(0, 2, 1), 0, codes + if __name__ == '__main__': - j = torch.randn(8,40,1024) + j = torch.randn(8, 40, 1024) m = GumbelQuantizer(1024, 1024, 4096) m2 = DiscreteDecoder(1024, (512, 256), 2) - l=m2(m(j)[0].permute(0,2,1)) + l = m2(m(j)[0].permute(0, 2, 1)) mean = 0 for ls in l: mean = mean + ls.mean() - mean.backward() \ No newline at end of file + mean.backward() diff --git a/dlas/models/vqvae/scaled_weight_conv.py b/dlas/models/vqvae/scaled_weight_conv.py index 1e5d6d54..d745c6fe 100644 --- a/dlas/models/vqvae/scaled_weight_conv.py +++ b/dlas/models/vqvae/scaled_weight_conv.py @@ -1,12 +1,11 @@ -from typing import Optional, List +from typing import List, Optional import torch import torch.nn as nn +import torch.nn.functional as F from torch import Tensor from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd from torch.nn.modules.utils import _ntuple -import torch.nn.functional as F - _pair = _ntuple(2) @@ -14,7 +13,7 @@ _pair = _ntuple(2) # Indexes the


index of input=b,c,h,w,p by the long tensor index=b,1,h,w. Result is b,c,h,w. # Frankly - IMO - this is what torch.gather should do. def index_2d(input, index): - index = index.repeat(1,input.shape[1],1,1) + index = index.repeat(1, input.shape[1], 1, 1) e = torch.eye(input.shape[-1], device=input.device) result = e[index] * input return result.sum(-1) @@ -27,9 +26,9 @@ class ScaledWeightConv(_ConvNd): in_channels: int, out_channels: int, kernel_size, - stride = 1, - padding = 0, - dilation = 1, + stride=1, + padding=0, + dilation=1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', @@ -39,11 +38,14 @@ class ScaledWeightConv(_ConvNd): padding = _pair(padding) dilation = _pair(dilation) super().__init__( - in_channels, out_channels, _pair(kernel_size), stride, padding, dilation, + in_channels, out_channels, _pair( + kernel_size), stride, padding, dilation, False, _pair(0), groups, bias, padding_mode) - self.weight_scales = nn.ParameterList([nn.Parameter(torch.ones(out_channels, in_channels, kernel_size, kernel_size)) for _ in range(breadth)]) - self.shifts = nn.ParameterList([nn.Parameter(torch.zeros(out_channels, in_channels, kernel_size, kernel_size)) for _ in range(breadth)]) + self.weight_scales = nn.ParameterList([nn.Parameter(torch.ones( + out_channels, in_channels, kernel_size, kernel_size)) for _ in range(breadth)]) + self.shifts = nn.ParameterList([nn.Parameter(torch.zeros( + out_channels, in_channels, kernel_size, kernel_size)) for _ in range(breadth)]) for w, s in zip(self.weight_scales, self.shifts): w.FOR_SCALE_SHIFT = True s.FOR_SCALE_SHIFT = True @@ -67,7 +69,8 @@ class ScaledWeightConv(_ConvNd): # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any # good at all, this can be made more efficient by performing a single conv pass with multiple masks. 
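For what it's worth, the one-hot trick in index_2d above can indeed be expressed with torch.gather, as its comment suggests. A minimal equivalence check (illustrative only, not part of the patch):

    import torch

    b, c, h, w, p = 2, 8, 16, 16, 4
    inp = torch.randn(b, c, h, w, p)
    idx = torch.randint(0, p, (b, 1, h, w))

    # one-hot formulation, as in index_2d
    one_hot = torch.eye(p)[idx.repeat(1, c, 1, 1)]
    via_eye = (one_hot * inp).sum(-1)

    # same per-pixel selection via gather along the last ("breadth") dimension
    via_gather = torch.gather(inp, -1, idx.repeat(1, c, 1, 1).unsqueeze(-1)).squeeze(-1)

    assert torch.allclose(via_eye, via_gather)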
- weighted_convs = [self._weighted_conv_forward(input, self.weight * scale + shift) for scale, shift in zip(self.weight_scales, self.shifts)] + weighted_convs = [self._weighted_conv_forward( + input, self.weight * scale + shift) for scale, shift in zip(self.weight_scales, self.shifts)] weighted_convs = torch.stack(weighted_convs, dim=-1) needed_mask = weighted_convs.shape[-2] @@ -97,9 +100,9 @@ class ScaledWeightConvTranspose(_ConvTransposeNd): in_channels: int, out_channels: int, kernel_size, - stride = 1, - padding = 0, - output_padding = 0, + stride=1, + padding=0, + output_padding=0, groups: int = 1, bias: bool = True, dilation: int = 1, @@ -111,11 +114,14 @@ class ScaledWeightConvTranspose(_ConvTransposeNd): dilation = _pair(dilation) output_padding = _pair(output_padding) super().__init__( - in_channels, out_channels, _pair(kernel_size), stride, padding, dilation, + in_channels, out_channels, _pair( + kernel_size), stride, padding, dilation, True, output_padding, groups, bias, padding_mode) - self.weight_scales = nn.ParameterList([nn.Parameter(torch.ones(in_channels, out_channels, kernel_size, kernel_size)) for _ in range(breadth)]) - self.shifts = nn.ParameterList([nn.Parameter(torch.zeros(in_channels, out_channels, kernel_size, kernel_size)) for _ in range(breadth)]) + self.weight_scales = nn.ParameterList([nn.Parameter(torch.ones( + in_channels, out_channels, kernel_size, kernel_size)) for _ in range(breadth)]) + self.shifts = nn.ParameterList([nn.Parameter(torch.zeros( + in_channels, out_channels, kernel_size, kernel_size)) for _ in range(breadth)]) for w, s in zip(self.weight_scales, self.shifts): w.FOR_SCALE_SHIFT = True s.FOR_SCALE_SHIFT = True @@ -125,7 +131,8 @@ class ScaledWeightConvTranspose(_ConvTransposeNd): def _conv_transpose_forward(self, input, weight, output_size) -> Tensor: if self.padding_mode != 'zeros': - raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + raise ValueError( + 'Only `zeros` padding mode is supported for ConvTranspose2d') output_padding = self._output_padding( input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) @@ -154,16 +161,16 @@ class ScaledWeightConvTranspose(_ConvTransposeNd): def create_wrapped_conv_transpose_from_template(conv: nn.Conv2d, breadth: int): wrapped = ScaledWeightConvTranspose(conv.in_channels, - conv.out_channels, - conv.kernel_size, - conv.stride, - conv.padding, - conv.output_padding, - conv.groups, - conv.bias, - conv.dilation, - conv.padding_mode, - breadth) + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.output_padding, + conv.groups, + conv.bias, + conv.dilation, + conv.padding_mode, + breadth) wrapped.weight = conv.weight wrapped.weight.DO_NOT_TRAIN = True wrapped.weight.requires_grad = False diff --git a/dlas/models/vqvae/vector_quantizer.py b/dlas/models/vqvae/vector_quantizer.py index 6d057c25..6e76eb4f 100644 --- a/dlas/models/vqvae/vector_quantizer.py +++ b/dlas/models/vqvae/vector_quantizer.py @@ -1,13 +1,13 @@ import torch -from torch import nn, einsum import torch.nn.functional as F from einops import rearrange, repeat +from torch import nn -from models.arch_util import l2norm, sample_vectors, default, ema_inplace -import torch_intermediary as ml +import dlas.torch_intermediary as ml +from dlas.models.arch_util import default, ema_inplace, l2norm, sample_vectors -def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): dim, dtype, device 
= samples.shape[-1], samples.dtype, samples.device means = sample_vectors(samples, num_clusters) @@ -16,16 +16,17 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): if use_cosine_sim: dists = samples @ means.t() else: - diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d') - dists = -(diffs ** 2).sum(dim = -1) + diffs = rearrange(samples, 'n d -> n () d') - \ + rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim=-1) - buckets = dists.max(dim = -1).indices - bins = torch.bincount(buckets, minlength = num_clusters) + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) zero_mask = bins == 0 bins = bins.masked_fill(zero_mask, 1) - new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype) - new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples) + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) new_means = new_means / bins[..., None] if use_cosine_sim: @@ -37,15 +38,16 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): # distance types + class EuclideanCodebook(nn.Module): def __init__( self, dim, codebook_size, - kmeans_init = False, - kmeans_iters = 10, - decay = 0.8, - eps = 1e-5 + kmeans_init=False, + kmeans_iters=10, + decay=0.8, + eps=1e-5 ): super().__init__() self.decay = decay @@ -68,7 +70,8 @@ class EuclideanCodebook(nn.Module): self.initted.data.copy_(torch.Tensor([True])) def replace(self, samples, mask): - modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + modified_codebook = torch.where(mask[..., None], sample_vectors( + samples, self.codebook_size), self.embed) self.embed.data.copy_(modified_codebook) def forward(self, x): @@ -85,7 +88,7 @@ class EuclideanCodebook(nn.Module): + embed.pow(2).sum(0, keepdim=True) ) - embed_ind = dist.max(dim = -1).indices + embed_ind = dist.max(dim=-1).indices embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype) embed_ind = embed_ind.view(*shape[:-1]) quantize = F.embedding(embed_ind, self.embed) @@ -94,21 +97,23 @@ class EuclideanCodebook(nn.Module): ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) embed_sum = flatten.t() @ embed_onehot ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() + cluster_size = laplace_smoothing( + self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) self.embed.data.copy_(embed_normalized) return quantize, embed_ind + class CosineSimCodebook(nn.Module): def __init__( self, dim, codebook_size, - kmeans_init = False, - kmeans_iters = 10, - decay = 0.8, - eps = 1e-5 + kmeans_init=False, + kmeans_iters=10, + decay=0.8, + eps=1e-5 ): super().__init__() self.decay = decay @@ -126,13 +131,15 @@ class CosineSimCodebook(nn.Module): self.register_buffer('embed', embed) def init_embed_(self, data): - embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True) + embed = kmeans(data, self.codebook_size, + self.kmeans_iters, use_cosine_sim=True) self.embed.data.copy_(embed) self.initted.data.copy_(torch.Tensor([True])) def replace(self, samples, mask): samples = l2norm(samples) - modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + 
modified_codebook = torch.where(mask[..., None], sample_vectors( + samples, self.codebook_size), self.embed) self.embed.data.copy_(modified_codebook) def forward(self, x): @@ -145,7 +152,7 @@ class CosineSimCodebook(nn.Module): embed = l2norm(self.embed) dist = flatten @ embed.t() - embed_ind = dist.max(dim = -1).indices + embed_ind = dist.max(dim=-1).indices embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) embed_ind = embed_ind.view(*shape[:-1]) @@ -159,46 +166,50 @@ class CosineSimCodebook(nn.Module): embed_sum = flatten.t() @ embed_onehot embed_normalized = (embed_sum / bins.unsqueeze(0)).t() embed_normalized = l2norm(embed_normalized) - embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized) + embed_normalized = torch.where( + zero_mask[..., None], embed, embed_normalized) ema_inplace(self.embed, embed_normalized, self.decay) return quantize, embed_ind # main class + class VectorQuantize(nn.Module): def __init__( self, dim, codebook_size, - n_embed = None, - codebook_dim = None, - decay = 0.8, - eps = 1e-5, - kmeans_init = False, - kmeans_iters = 10, - use_cosine_sim = False, - max_codebook_misses_before_expiry = 0 + n_embed=None, + codebook_dim=None, + decay=0.8, + eps=1e-5, + kmeans_init=False, + kmeans_iters=10, + use_cosine_sim=False, + max_codebook_misses_before_expiry=0 ): super().__init__() n_embed = default(n_embed, codebook_size) codebook_dim = default(codebook_dim, dim) requires_projection = codebook_dim != dim - self.project_in = ml.Linear(dim, codebook_dim) if requires_projection else nn.Identity() - self.project_out = ml.Linear(codebook_dim, dim) if requires_projection else nn.Identity() + self.project_in = ml.Linear( + dim, codebook_dim) if requires_projection else nn.Identity() + self.project_out = ml.Linear( + codebook_dim, dim) if requires_projection else nn.Identity() self.eps = eps klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook self._codebook = klass( - dim = codebook_dim, - codebook_size = n_embed, - kmeans_init = kmeans_init, - kmeans_iters = kmeans_iters, - decay = decay, - eps = eps + dim=codebook_dim, + codebook_size=n_embed, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + eps=eps ) self.codebook_size = codebook_size @@ -221,7 +232,7 @@ class VectorQuantize(nn.Module): return embed_ind = rearrange(embed_ind, '... -> (...)') - misses = torch.bincount(embed_ind, minlength = self.codebook_size) == 0 + misses = torch.bincount(embed_ind, minlength=self.codebook_size) == 0 self.codebook_misses += misses expired_codes = self.codebook_misses >= self.max_codebook_misses_before_expiry @@ -230,7 +241,7 @@ class VectorQuantize(nn.Module): self.codebook_misses.masked_fill_(expired_codes, 0) batch_samples = rearrange(batch_samples, '... d -> (...) 
d') - self._codebook.replace(batch_samples, mask = expired_codes) + self._codebook.replace(batch_samples, mask=expired_codes) def forward(self, x): x = self.project_in(x) diff --git a/dlas/models/vqvae/vqvae.py b/dlas/models/vqvae/vqvae.py index 259d2c11..c6d97f66 100644 --- a/dlas/models/vqvae/vqvae.py +++ b/dlas/models/vqvae/vqvae.py @@ -19,13 +19,12 @@ import torch +import torch.distributed as distributed from torch import nn from torch.nn import functional as F -import torch.distributed as distributed - -from trainer.networks import register_model -from utils.util import checkpoint, opt_get +from dlas.trainer.networks import register_model +from dlas.utils.util import checkpoint, opt_get class Quantize(nn.Module): @@ -69,10 +68,12 @@ class Quantize(nn.Module): self.cluster_size.data.mul_(self.decay).add_( embed_onehot_sum, alpha=1 - self.decay ) - self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + self.embed_avg.data.mul_(self.decay).add_( + embed_sum, alpha=1 - self.decay) n = self.cluster_size.sum() cluster_size = ( - (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + (self.cluster_size + self.eps) / + (n + self.n_embed * self.eps) * n ) embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) self.embed.data.copy_(embed_normalized) @@ -156,7 +157,8 @@ class Decoder(nn.Module): if stride == 4: blocks.extend( [ - conv_transpose_module(channel, channel // 2, 4, stride=2, padding=1), + conv_transpose_module( + channel, channel // 2, 4, stride=2, padding=1), nn.ReLU(inplace=True), conv_transpose_module( channel // 2, out_channel, 4, stride=2, padding=1 @@ -166,7 +168,8 @@ class Decoder(nn.Module): elif stride == 2: blocks.append( - conv_transpose_module(channel, out_channel, 4, stride=2, padding=1) + conv_transpose_module( + channel, out_channel, 4, stride=2, padding=1) ) self.blocks = nn.Sequential(*blocks) @@ -194,14 +197,17 @@ class VQVAE(nn.Module): in_channel = abs(in_channel) self.codebook_size = codebook_size - self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module) - self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module) + self.enc_b = Encoder(in_channel, channel, n_res_block, + n_res_channel, stride=4, conv_module=conv_module) + self.enc_t = Encoder(channel, channel, n_res_block, + n_res_channel, stride=2, conv_module=conv_module) self.quantize_conv_t = conv_module(channel, codebook_dim, 1) self.quantize_t = Quantize(codebook_dim, codebook_size) self.dec_t = Decoder( codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module, conv_transpose_module=conv_transpose_module ) - self.quantize_conv_b = conv_module(codebook_dim + channel, codebook_dim, 1) + self.quantize_conv_b = conv_module( + codebook_dim + channel, codebook_dim, 1) self.quantize_b = Quantize(codebook_dim, codebook_size) self.upsample_t = conv_transpose_module( codebook_dim, codebook_dim, 4, stride=2, padding=1 @@ -231,17 +237,21 @@ class VQVAE(nn.Module): enc_b = checkpoint(self.enc_b, input) enc_t = checkpoint(self.enc_t, enc_b) - quant_t = self.quantize_conv_t(enc_t).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1)) + quant_t = self.quantize_conv_t(enc_t).permute( + (0, 2, 3, 1) if len(input.shape) == 4 else (0, 2, 1)) quant_t, diff_t, id_t = self.quantize_t(quant_t) - quant_t = quant_t.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1)) + quant_t = quant_t.permute((0, 3, 1, 2) if len( + input.shape) == 4 else (0, 2, 1)) 
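A brief note on the permute pairs that recur in these hunks (dvae.py above and vqvae.py here): the quantizers treat the last axis as the embedding dimension, so features are moved channels-last for the codebook lookup and back to channels-first for the convolutional decoder. Illustrative sketch only, not part of the patch:

    import torch

    feat_2d = torch.randn(1, 64, 32, 32)   # (B, C, H, W) image features
    feat_1d = torch.randn(1, 64, 128)      # (B, C, T) spectrogram-style features

    # channels-last for the codebook lookup ...
    q2 = feat_2d.permute(0, 2, 3, 1)       # (B, H, W, C)
    q1 = feat_1d.permute(0, 2, 1)          # (B, T, C)

    # ... and back to channels-first before the decoder convolutions
    assert q2.permute(0, 3, 1, 2).shape == feat_2d.shape
    assert q1.permute(0, 2, 1).shape == feat_1d.shape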
diff_t = diff_t.unsqueeze(0) dec_t = checkpoint(self.dec_t, quant_t) enc_b = torch.cat([dec_t, enc_b], 1) - quant_b = checkpoint(self.quantize_conv_b, enc_b).permute((0,2,3,1) if len(input.shape) == 4 else (0,2,1)) + quant_b = checkpoint(self.quantize_conv_b, enc_b).permute( + (0, 2, 3, 1) if len(input.shape) == 4 else (0, 2, 1)) quant_b, diff_b, id_b = self.quantize_b(quant_b) - quant_b = quant_b.permute((0,3,1,2) if len(input.shape) == 4 else (0,2,1)) + quant_b = quant_b.permute((0, 3, 1, 2) if len( + input.shape) == 4 else (0, 2, 1)) diff_b = diff_b.unsqueeze(0) return quant_t, quant_b, diff_t + diff_b, id_t, id_b @@ -262,9 +272,11 @@ class VQVAE(nn.Module): def decode_code(self, code_t, code_b): quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1)) + quant_t = quant_t.permute((0, 3, 1, 2) if len( + code_t.shape) == 4 else (0, 2, 1)) quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1)) + quant_b = quant_b.permute((0, 3, 1, 2) if len( + code_t.shape) == 4 else (0, 2, 1)) dec = self.decode(quant_t, quant_b) @@ -273,12 +285,13 @@ class VQVAE(nn.Module): # Performs decode_code() with the outputs from encode_only_quantized. def decode_code_joined(self, input): b, s = input.shape - assert s % 3 == 0 # If not, this tensor didn't come from encode_only_quantized. + # If not, this tensor didn't come from encode_only_quantized. + assert s % 3 == 0 s = s // 3 # This doesn't work with batching. TODO: fixme. - t = input[:,:s] - self.codebook_size - b = input[:,s:] + t = input[:, :s] - self.codebook_size + b = input[:, s:] return self.decode_code(t, b) @@ -299,8 +312,9 @@ def register_vqvae_audio(opt_net, opt): if __name__ == '__main__': - model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d) - #res=model(torch.randn(1,80,2048)) + model = VQVAE(in_channel=80, conv_module=nn.Conv1d, + conv_transpose_module=nn.ConvTranspose1d) + # res=model(torch.randn(1,80,2048)) e = model.encode_only_quantized(torch.randn(1, 80, 2048)) k = model.decode_code_joined(e) - print(k.shape) \ No newline at end of file + print(k.shape) diff --git a/dlas/multi_modal_train.py b/dlas/multi_modal_train.py index c82787fa..9bf3c2ae 100644 --- a/dlas/multi_modal_train.py +++ b/dlas/multi_modal_train.py @@ -10,12 +10,12 @@ # state when re-started. 
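The hunks below repeat the import rewrite applied throughout this patch: modules that used to be imported relative to the old codes/ layout become absolute imports under the dlas package. One general Python detail worth keeping in mind (not specific to this file): `import dlas.train` binds only the name `dlas`, so code that calls a bare `train.Trainer()` needs an alias or a from-import. A minimal sketch, assuming the dlas package is installed:

    # Either form binds the short name `train` in the importing module:
    import dlas.train as train
    # from dlas import train

    # Plain `import dlas.train` would bind only `dlas`, requiring dlas.train.Trainer().
    trainer = train.Trainer()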
import argparse +import torch import yaml -import train -import utils.options as option -from utils.util import OrderedYaml -import torch +import dlas.train +import dlas.utils.options as option +from dlas.utils.util import OrderedYaml def main(master_opt, launcher): @@ -28,7 +28,7 @@ def main(master_opt, launcher): sub_opt_parsed = option.parse(sub_opt, is_train=True) trainer = train.Trainer() - #### distributed training settings + # distributed training settings if launcher == 'none': # disabled distributed training sub_opt_parsed['dist'] = False trainer.rank = -1 @@ -56,8 +56,10 @@ def main(master_opt, launcher): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') - parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', + default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') + parser.add_argument( + '--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/dlas/process_video.py b/dlas/process_video.py index 650c04cf..ee07db68 100644 --- a/dlas/process_video.py +++ b/dlas/process_video.py @@ -11,10 +11,10 @@ import torchvision.transforms.functional as F from PIL import Image from tqdm import tqdm -from trainer.ExtensibleTrainer import ExtensibleTrainer -from utils import options as option -import utils.util as util -from data import create_dataloader +import dlas.utils.util as util +from dlas.data import create_dataloader +from dlas.trainer.ExtensibleTrainer import ExtensibleTrainer +from dlas.utils import options as option class FfmpegBackedVideoDataset(data.Dataset): @@ -34,7 +34,8 @@ class FfmpegBackedVideoDataset(data.Dataset): self.max_working_files = 20 self.data_type = self.opt['data_type'] - self.vertical_splits = self.opt['vertical_splits'] if 'vertical_splits' in opt.keys() else 1 + self.vertical_splits = self.opt['vertical_splits'] if 'vertical_splits' in opt.keys( + ) else 1 def get_time_for_it(self, it): secs = it / self.frame_rate + self.start_at @@ -51,10 +52,13 @@ class FfmpegBackedVideoDataset(data.Dataset): actual_index = index # Extract the frame. Command template: `ffmpeg -ss 17:00.0323 -i