forked from mrq/DL-Art-School
Commit 341f28dd82: "It works!"
Parent: 36c7c1fbdb
@@ -1,24 +1,22 @@
import os
import random
import numpy as np

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import LongTensor
from tqdm import tqdm

import models.tacotron2.layers as layers
from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text

from models.tacotron2.text import text_to_sequence
from utils.util import opt_get
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import symbols
import torch.nn.functional as F
from models.tacotron2.text import text_to_sequence


class GptTtsDataset(torch.utils.data.Dataset):
    NUMBER_SYMBOLS = len(symbols)+3
    TEXT_START_TOKEN = LongTensor([NUMBER_SYMBOLS-3])
    TEXT_STOP_TOKEN = LongTensor([NUMBER_SYMBOLS-2])
    MAX_SYMBOLS_PER_PHRASE = 200
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
    TEXT_START_TOKEN = LongTensor([NUMBER_TEXT_TOKENS-1])
    TEXT_STOP_TOKEN = LongTensor([NUMBER_TEXT_TOKENS-2])

    def __init__(self, opt):
        self.path = os.path.dirname(opt['path'])
@@ -49,11 +47,11 @@ class GptTtsDataset(torch.utils.data.Dataset):


class GptTtsCollater():
    NUMBER_SYMBOLS = len(symbols)+3
    TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
    MAX_SYMBOLS_PER_PHRASE = 200
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2

    def __init__(self, opt):

        self.MEL_DICTIONARY_SIZE = opt['mel_vocab_size']+3
        self.MEL_PAD_TOKEN = self.MEL_DICTIONARY_SIZE-1
@@ -64,9 +62,13 @@ class GptTtsCollater():
        max_mel_len = max(mel_lens)
        texts = []
        qmels = []
        # This is the sequential "background" tokens that are used as padding for text tokens, as specified in the DALLE paper.
        text_range_embedding = torch.arange(max_text_len) + self.NUMBER_SYMBOLS
        for b in batch:
            text, qmel, _ = b
            texts.append(F.pad(text, (0, max_text_len-len(text)), value=self.TEXT_PAD_TOKEN))
            text = F.pad(text, (0, max_text_len-len(text)), value=0)
            text = torch.where(text == 0, text_range_embedding, text)
            texts.append(text)
            qmels.append(F.pad(qmel, (0, max_mel_len-len(qmel)), value=self.MEL_PAD_TOKEN))

        filenames = [j[2] for j in batch]
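Note on the collater change above: instead of a single pad token, each padded text position now receives its own "background" token drawn from a range above the real symbol vocabulary, following the DALL-E paper. A minimal sketch of the idea, using a toy vocabulary size and a made-up input (names and values here are illustrative, and it assumes real symbol ids are never 0):

    import torch
    import torch.nn.functional as F

    NUMBER_SYMBOLS = 10      # assumed toy symbol count
    max_text_len = 6

    # each padded position gets a unique token id above the real vocabulary
    text_range_embedding = torch.arange(max_text_len) + NUMBER_SYMBOLS

    text = torch.tensor([3, 7, 2])                           # real (non-zero) symbol ids
    text = F.pad(text, (0, max_text_len - len(text)), value=0)
    text = torch.where(text == 0, text_range_embedding, text)
    print(text)                                              # tensor([ 3,  7,  2, 13, 14, 15])
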
@@ -1,48 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
from torch import LongTensor
from tqdm import tqdm

from models.arch_util import ConvGnSilu
from models.gpt_voice.pixelshuffle_1d import PixelUnshuffle1D, PixelShuffle1D
from models.tacotron2 import hparams
from models.gpt_voice.lucidrains_gpt import Transformer
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.tacotron2 import Postnet
from models.tacotron2.text import symbols
from models.gpt_voice.min_gpt import GPT, GPTConfig
from trainer.networks import register_model


class GptTts(nn.Module):
    NUMBER_SYMBOLS = len(symbols)+3
    TEXT_START_TOKEN = NUMBER_SYMBOLS-3
    TEXT_STOP_TOKEN = NUMBER_SYMBOLS-2
    TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
    MAX_SYMBOLS_PER_PHRASE = 200
    NUMBER_SYMBOLS = len(symbols)
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
    MEL_DICTIONARY_SIZE = 512+3
    MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
    MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
    MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE-1

    def __init__(self):
        super().__init__()
        model_dim = 512
        max_symbols_per_phrase = 200
        max_mel_frames = 900 * 3 // 8 # The VQVAE outputs 3/8 of the input mel as tokens.
        mel_dim=80
        max_mel_frames = 900 * 3 // 8 # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the input mel as tokens.

        self.model_dim = model_dim
        self.max_mel_frames = max_mel_frames
        self.text_embedding = nn.Embedding(self.NUMBER_SYMBOLS, model_dim)
        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
        self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
        # *_tags are additively applied to
        self.text_pos_embedding = nn.Embedding(max_symbols_per_phrase, model_dim)
        self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
        self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
        self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
        #self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
        self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)

        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_SYMBOLS)
        self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
        self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

    def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
@@ -55,8 +44,8 @@ class GptTts(nn.Module):

        # Compute logits for text and mel heads
        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
        text_logits = self.text_head(text_logits)
        mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
        text_logits = self.text_head(text_logits)
        mel_logits = self.mel_head(mel_logits)

        # Compute loss
@@ -67,25 +56,18 @@ class GptTts(nn.Module):
        mel_logits = mel_logits.permute(0,2,1)[:,:,:-1]
        loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')

        # Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
        pad_loss_reduction_factor = .01
        text_pad_mask = ~get_mask_from_lengths(text_lengths-1, text_inputs.shape[1]-1) # -1 to strip off <BOS>, which is accounted for in text_lengths and output_lengths.
        mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1])
        loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask, pad_loss_reduction_factor)
        loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask, pad_loss_reduction_factor)

        # Fix up mel_logits so it can go into a VAE decoder as well.
        mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1)
        mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1])
        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0)
        mel_codes = mel_codes[:,:-1] # Strip off <EOS> token too (or padding). The important part is that the output sequence length is identical to the VAE input.
        extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD
        mel_codes = mel_codes * extra_mask

        return loss_text.mean(), loss_mel.mean(), mel_codes
        # This class also returns the mel_targets for validation purposes. Format those.
        mel_targets = mel_targets[:,:-1]
        mel_targets = mel_targets * (mel_targets < self.MEL_DICTIONARY_SIZE-3)
        return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets

    def inference(self, text_inputs):
        text_emb = self.text_embedding(text_inputs)
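Note on the loss change above: rather than ignoring pad positions, the per-token cross-entropy is multiplied by a small factor wherever the target is padding. A minimal sketch of that down-weighting, assuming get_mask_from_lengths(lengths, max_len) returns True inside the valid region (the mask helper below is a stand-in written to that assumption, not the repo's implementation):

    import torch

    def get_mask_from_lengths(lengths, max_len):
        # True where the position is inside the valid (unpadded) region -- assumed behavior
        ids = torch.arange(max_len, device=lengths.device)
        return ids[None, :] < lengths[:, None]

    loss = torch.ones(2, 5)                      # stand-in per-token loss
    lengths = torch.tensor([3, 5])
    pad_mask = ~get_mask_from_lengths(lengths, 5)
    pad_loss_reduction_factor = .01
    weights = torch.ones_like(loss).masked_fill_(pad_mask, pad_loss_reduction_factor)
    loss = loss * weights                        # pad positions now contribute 1% of their loss
    print(loss.mean())
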
codes/models/gpt_voice/lucidrains_gpt.py (new file, 193 lines)
@@ -0,0 +1,193 @@
from inspect import isfunction

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

# helpers
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
from utils.util import sequential_checkpoint


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cast_tuple(val, depth = 1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth


class DivideMax(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim = self.dim, keepdim = True)
        return x / maxes

# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0., mult = 4.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)


def exists(val):
    return val is not None


def uniq(arr):
    return{el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True)
    return (t * alpha).softmax(dim = dim)


# classes
class Attention(nn.Module):
    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5

        self.stable = stable
        self.causal = causal

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, _, h, device = *x.shape, self.heads, x.device
        softmax = torch.softmax if not self.stable else stable_softmax

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        q = q * self.scale

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        mask_value = max_neg_value(dots)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.causal:
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        attn = softmax(dots, dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        reversible = False,
        causal = True,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        sparse_attn = False,
        stable = False
    ):
        super().__init__()
        layers = nn.ModuleList([])
        sparse_layer = cast_tuple(sparse_attn, depth)

        for ind, sparse_attn in zip(range(depth), sparse_layer):
            attn = Attention(dim, stable=stable, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)

            layers.append(nn.ModuleList([
                LayerScale(dim, ind + 1, PreNorm(dim, attn)),
                LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
            ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence
        route_attn = ((True, False),) * depth
        attn_route_map = {'mask': route_attn}

        self.layers = execute_type(layers, args_route = attn_route_map)

    def forward(self, x):
        return self.layers(x)
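A quick usage sketch of the Transformer added above, with the same hyperparameters GptTts passes in (dim=512, depth=8, heads=8). The batch size and inputs are placeholders, and it assumes the repo's modules (models.gpt_voice.reversible, utils.util) are importable:

    import torch
    from models.gpt_voice.lucidrains_gpt import Transformer

    seq_len = 1 + 200 + 337           # 1 + MAX_SYMBOLS_PER_PHRASE + max_mel_frames (900*3//8)
    model = Transformer(dim=512, depth=8, seq_len=seq_len, heads=8)

    x = torch.randn(2, seq_len, 512)  # already-embedded tokens; positional embeddings are added by the caller
    y = model(x)                      # causal self-attention + feedforward stack
    print(y.shape)                    # torch.Size([2, 538, 512])
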
codes/models/gpt_voice/reversible.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# for routing arguments into the functions of the reversible layer
def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args


# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

        x = torch.cat([x1, x2.detach()], dim=2)
        dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None

class SequentialSequence(nn.Module):
    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        for (f, g), (f_args, g_args) in layers_and_args:
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x

class ReversibleSequence(nn.Module):
    def __init__(self, blocks, args_route = {}):
        super().__init__()
        self.args_route = args_route
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        out = _ReversibleFunction.apply(x, blocks, args)
        return torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
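The reversible blocks above rely on the additive coupling being exactly invertible: given y1 = x1 + f(x2) and y2 = x2 + g(y1), the backward pass recomputes x2 = y2 - g(y1) and x1 = y1 - f(x2) instead of storing activations. A tiny numeric check of that identity (f and g below are arbitrary stand-ins, not the attention/feedforward pair used in the model):

    import torch

    f = torch.nn.Linear(4, 4)
    g = torch.nn.Linear(4, 4)

    x1, x2 = torch.randn(1, 4), torch.randn(1, 4)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)

    # invert, as ReversibleBlock.backward_pass does
    x2_rec = y2 - g(y1)
    x1_rec = y1 - f(x2_rec)
    print(torch.allclose(x1, x1_rec), torch.allclose(x2, x2_rec))   # True True
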
@@ -18,7 +18,8 @@ def get_scheduler_for_name(name, optimizers, scheduler_opt):
                                       weights=scheduler_opt['restart_weights'],
                                       gamma=scheduler_opt['lr_gamma'],
                                       clear_state=scheduler_opt['clear_state'],
                                       force_lr=scheduler_opt['force_lr'])
                                       force_lr=scheduler_opt['force_lr'],
                                       warmup_steps=scheduler_opt['warmup_steps'])
    elif name == 'ProgressiveMultiStepLR':
        sched = ProgressiveMultiStepLR(o, scheduler_opt['gen_lr_steps'],
                                       scheduler_opt['progressive_starts'],
@@ -55,7 +56,7 @@ class ProgressiveMultiStepLR(_LRScheduler):

class MultiStepLR_Restart(_LRScheduler):
    def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
                 clear_state=False, force_lr=False, last_epoch=-1):
                 clear_state=False, force_lr=False, last_epoch=-1, warmup_steps=0):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
@@ -63,11 +64,13 @@ class MultiStepLR_Restart(_LRScheduler):
        self.restarts = [v + 1 for v in self.restarts]
        self.restart_weights = weights if weights else [1]
        self.force_lr = force_lr
        self.warmup_steps = warmup_steps
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Note to self: for the purposes of this trainer, "last_epoch" should read "last_step"
        if self.force_lr:
            return [group['initial_lr'] for group in self.optimizer.param_groups]
        if self.last_epoch in self.restarts:
@@ -75,6 +78,9 @@ class MultiStepLR_Restart(_LRScheduler):
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            factor = 1 - (self.warmup_steps - self.last_epoch) / self.warmup_steps
            return [group['initial_lr'] * factor for group in self.optimizer.param_groups]
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [
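The warmup branch added above scales the initial LR linearly from 0 up to full strength over warmup_steps. A small illustration of the factor it computes (the step values and base LR below are just examples):

    warmup_steps = 20000
    initial_lr = 1e-4

    for step in (0, 5000, 10000, 19999):
        factor = 1 - (warmup_steps - step) / warmup_steps
        print(step, initial_lr * factor)
    # 0 -> 0.0, 5000 -> 2.5e-05, 10000 -> 5e-05, 19999 -> ~9.9995e-05
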
@@ -148,8 +154,8 @@ if __name__ == "__main__":
    restart_weights = [1, 1, 1]

    scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
                                    clear_state=False)
                                    clear_state=False, warmup_steps=20000)
    '''
    ##############################
    # Cosine Annealing Restart
    ##############################
@@ -165,11 +171,12 @@ if __name__ == "__main__":

    scheduler = CosineAnnealingLR_Restart(optimizer, T_period, warmup=10000, eta_min=1e-8, restarts=restarts,
                                          weights=restart_weights)
    '''

    ##############################
    # Draw figure
    ##############################
    N_iter = 500000
    N_iter = 100000
    lr_l = list(range(N_iter))
    for i in range(N_iter):
        scheduler.step()