diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py
index 8c718cc..29f179d 100755
--- a/tortoise_tts/config.py
+++ b/tortoise_tts/config.py
@@ -20,6 +20,8 @@ from .utils.distributed import world_size
# Yuck
from transformers import PreTrainedTokenizerFast
+from tokenizers import Tokenizer
+
@dataclass()
class BaseConfig:
@@ -494,6 +496,177 @@ class Inference:
return torch.float8_e4m3fn
return torch.float32
+import inflect
+import re
+
+from unidecode import unidecode
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r'\s+')
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+ ('mrs', 'misess'),
+ ('mr', 'mister'),
+ ('dr', 'doctor'),
+ ('st', 'saint'),
+ ('co', 'company'),
+ ('jr', 'junior'),
+ ('maj', 'major'),
+ ('gen', 'general'),
+ ('drs', 'doctors'),
+ ('rev', 'reverend'),
+ ('lt', 'lieutenant'),
+ ('hon', 'honorable'),
+ ('sgt', 'sergeant'),
+ ('capt', 'captain'),
+ ('esq', 'esquire'),
+ ('ltd', 'limited'),
+ ('col', 'colonel'),
+ ('ft', 'fort'),
+]]
+
+
+def expand_abbreviations(text):
+ for regex, replacement in _abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
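+# For reference (inputs are lowercased first by english_cleaners() below):
+# expand_abbreviations('dr. smith met mrs. jones') -> 'doctor smith met misess jones'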
+
+
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
+_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
+_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
+_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+_number_re = re.compile(r'[0-9]+')
+
+
+def _remove_commas(m):
+ return m.group(1).replace(',', '')
+
+
+def _expand_decimal_point(m):
+ return m.group(1).replace('.', ' point ')
+
+
+def _expand_dollars(m):
+ match = m.group(1)
+ parts = match.split('.')
+ if len(parts) > 2:
+ return match + ' dollars' # Unexpected format
+ dollars = int(parts[0]) if parts[0] else 0
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+ if dollars and cents:
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+ cent_unit = 'cent' if cents == 1 else 'cents'
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+ elif dollars:
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+ return '%s %s' % (dollars, dollar_unit)
+ elif cents:
+ cent_unit = 'cent' if cents == 1 else 'cents'
+ return '%s %s' % (cents, cent_unit)
+ else:
+ return 'zero dollars'
+
+
+def _expand_ordinal(m):
+ return _inflect.number_to_words(m.group(0))
+
+
+def _expand_number(m):
+ num = int(m.group(0))
+ if 1000 < num < 3000:
+ if num == 2000:
+ return 'two thousand'
+ elif num > 2000 and num < 2010:
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
+ elif num % 100 == 0:
+ return _inflect.number_to_words(num // 100) + ' hundred'
+ else:
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+ else:
+ return _inflect.number_to_words(num, andword='')
+
+
+def normalize_numbers(text):
+ text = re.sub(_comma_number_re, _remove_commas, text)
+ text = re.sub(_pounds_re, r'\1 pounds', text)
+ text = re.sub(_dollars_re, _expand_dollars, text)
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
+ text = re.sub(_number_re, _expand_number, text)
+ return text
+
+
+def expand_numbers(text):
+ return normalize_numbers(text)
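+# For reference, a few examples of what normalize_numbers() produces:
+# '$2.50' -> 'two dollars, fifty cents'
+# 'in 2007' -> 'in two thousand seven'
+# '21st' -> 'twenty-first'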
+
+
+def lowercase(text):
+ return text.lower()
+
+
+def collapse_whitespace(text):
+ return re.sub(_whitespace_re, ' ', text)
+
+
+def convert_to_ascii(text):
+ return unidecode(text)
+
+
+def basic_cleaners(text):
+ '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
+ text = lowercase(text)
+ text = collapse_whitespace(text)
+ return text
+
+
+def transliteration_cleaners(text):
+ '''Pipeline for non-English text that transliterates to ASCII.'''
+ text = convert_to_ascii(text)
+ text = lowercase(text)
+ text = collapse_whitespace(text)
+ return text
+
+
+def english_cleaners(text):
+ '''Pipeline for English text, including number and abbreviation expansion.'''
+ text = convert_to_ascii(text)
+ text = lowercase(text)
+ text = expand_numbers(text)
+ text = expand_abbreviations(text)
+ text = collapse_whitespace(text)
+ text = text.replace('"', '')
+ return text
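+# For reference, the full chain above behaves like:
+# english_cleaners('Mr. Müller owes $100.') -> 'mister muller owes one hundred dollars.'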
+
+class VoiceBpeTokenizer:
+ def __init__(self, tokenizer_file=None):
+ if tokenizer_file is not None:
+ self.tokenizer = Tokenizer.from_file(tokenizer_file)
+
+ def preprocess_text(self, txt):
+ txt = english_cleaners(txt)
+ return txt
+
+ def encode(self, txt):
+ txt = self.preprocess_text(txt)
+ txt = txt.replace(' ', '[SPACE]')
+ return self.tokenizer.encode(txt).ids
+
+ def decode(self, seq):
+ if isinstance(seq, torch.Tensor):
+ seq = seq.cpu().numpy()
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
+ txt = txt.replace('[SPACE]', ' ')
+ txt = txt.replace('[STOP]', '')
+ txt = txt.replace('[UNK]', '')
+ return txt
+
+ def get_vocab(self):
+ return self.tokenizer.get_vocab()
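+
+# Minimal usage sketch (assumes a trained BPE tokenizer.json that defines the
+# [SPACE]/[STOP]/[UNK] special tokens this class relies on):
+# tokenizer = VoiceBpeTokenizer(tokenizer_file="./data/tokenizer.json")
+# ids = tokenizer.encode("Mr. Smith counts to 3.")
+# text = tokenizer.decode(torch.tensor(ids))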
+
# should be renamed to optimizations
@dataclass()
class Optimizations:
@@ -667,39 +840,16 @@ class Config(BaseConfig):
# load tokenizer
try:
from transformers import PreTrainedTokenizerFast
- cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
- cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
+ #cfg.tokenizer = (cfg.rel_path if cfg.yaml_path is not None else Path("./data/")) / cfg.tokenizer
+ tokenizer_path = cfg.rel_path / cfg.tokenizer
+ if not tokenizer_path.exists():
+ tokenizer_path = Path("./data/") / cfg.tokenizer
+
+ #cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
+ cfg.tokenizer = VoiceBpeTokenizer(tokenizer_file=str(tokenizer_path))
except Exception as e:
- cfg.tokenizer = NaiveTokenizer()
print("Error while parsing tokenizer:", e)
- pass
-
-
-# Preserves the old behavior
-class NaiveTokenizer:
- def get_vocab( self ):
- """
- if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
- return json.loads( cfg.hdf5['symmap'].asstr()[()] )
- """
- return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, 'ᵝ': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, 'oˌ': 220, 'eˈ': 221, 'ʍ': 222, 'eˌ': 223, 'uˌ': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
-
- def encode( self, s ):
- symmap = self.get_vocab()
- phones = " ".join( list(s) )
-
- # do merge
- for merge in [ "\u02C8", "\u02CC", "\u02D0" ]:
- phones = phones.replace( f' {merge}', merge )
-
- phones = phones.split(" ")
- # cleanup
- phones = [ p for i, p in enumerate(phones) if p not in [" "] or ( p in [" "] and p != phones[i-1] ) ]
- # add bos / eos
- phones = [""] + [ " " if not p else p for p in phones ] + [""]
- # tokenize
- return [*map(symmap.get, phones)]
-
+ raise # re-raise with the original traceback
cfg = Config.from_cli()
diff --git a/tortoise_tts/data.py b/tortoise_tts/data.py
index 56227f4..4b15614 100755
--- a/tortoise_tts/data.py
+++ b/tortoise_tts/data.py
@@ -431,7 +431,7 @@ class Dataset(_Dataset):
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
- mel = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
+ mel = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16) # keep [:]: torch.from_numpy() expects a materialized ndarray, not a lazy h5py Dataset
else:
mel = _load_mels(path, return_metadata=False)
return mel
@@ -497,22 +497,22 @@ class Dataset(_Dataset):
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
- try:
- text = cfg.hdf5[key]["text"][:]
- mel = cfg.hdf5[key]["audio"][:]
- latents = cfg.hdf5[key]["latents"][:]
- except Exception as e:
- print( key, cfg.hdf5[key].keys() )
- raise e
+ text = cfg.hdf5[key]["text"][:]
+ mel = cfg.hdf5[key]["audio"][:]
+ conds = (cfg.hdf5[key]["conds_0"][:], cfg.hdf5[key]["conds_1"][:])
+ latents = (cfg.hdf5[key]["latents_0"][:], cfg.hdf5[key]["latents_1"][:])
text = torch.from_numpy(text).to(self.text_dtype)
mel = torch.from_numpy(mel).to(torch.int16)
- latents = torch.from_numpy(latents)
+ conds = (torch.from_numpy(conds[0]), torch.from_numpy(conds[1]))
+ latents = (torch.from_numpy(latents[0]), torch.from_numpy(latents[1]))
+
wav_length = cfg.hdf5[key].attrs["wav_length"]
else:
mel, metadata = _load_mels(path, return_metadata=True)
text = torch.tensor(metadata["text"]).to(self.text_dtype)
- latents = torch.from_numpy(metadata["latent"][0])
+ conds = (torch.from_numpy(metadata["conds"][0]), torch.from_numpy(metadata["conds"][1]))
+ latents = (torch.from_numpy(metadata["latent"][0]), torch.from_numpy(metadata["latent"][1]))
wav_length = metadata["wav_length"]
return dict(
@@ -521,7 +521,12 @@ class Dataset(_Dataset):
spkr_name=spkr_name,
spkr_id=spkr_id,
- latents=latents,
+ latents_0=latents[0][0],
+ latents_1=latents[1][0],
+
+ conds_0=conds[0][0, 0],
+ conds_1=conds[1][0, 0],
+
text=text,
mel=mel,
wav_length=wav_length,
@@ -612,9 +617,10 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
def unpack_audio( npz ):
- mel = npz["codes"].to(dtype=torch.int16, device="cpu")
- conds = npz["conds"][0].to(dtype=torch.int16, device="cpu")
- latent = npz["latent"][0].to(dtype=torch.int16, device="cpu")
+ mel = npz["codes"].to(device="cpu")
+
+ conds = npz["conds"][0].to(device="cpu"), npz["conds"][1].to(device="cpu")
+ latent = npz["latent"][0].to(device="cpu"), npz["latent"][1].to(device="cpu")
metadata = {}
@@ -774,13 +780,15 @@ def create_dataset_hdf5( skip_existing=True ):
mel, conds, latents, utterance_metadata = unpack_audio( npz )
if "audio" not in group:
- group.create_dataset('audio', data=mel.numpy().astype(np.int16), compression='lzf')
+ group.create_dataset('audio', data=mel.numpy(), compression='lzf')
- if "conds" not in group:
- group.create_dataset('conds', data=conds.numpy().astype(np.int16), compression='lzf')
-
- if "latents" not in group:
- group.create_dataset('latents', data=latents.numpy().astype(np.int16), compression='lzf')
+ for i, cond in enumerate(conds):
+ if f"conds_{i}" not in group:
+ group.create_dataset(f'conds_{i}', data=cond.numpy(), compression='lzf')
+
+ for i, latent in enumerate(latents):
+ if f"latents_{i}" not in group:
+ group.create_dataset(f'latents_{i}', data=latent.numpy(), compression='lzf')
# text
if texts:
@@ -859,14 +867,21 @@ if __name__ == "__main__":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
samples = {
- "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
- #"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
- #"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+ "training": next(iter(train_dl)),
+ #"evaluation": next(iter(subtrain_dl)),
+ #"validation": next(iter(val_dl)),
}
+
+ for sample_name, sample_batch in samples.items():
+ for name, batch in sample_batch.items():
+ #print( sample_name, name, [ x.shape if hasattr(x, "shape") else x for x in batch ] )
+ print( sample_name, name, [ x for x in batch ] )
+ """
for k, v in samples.items():
for i in range(len(v)):
print(f'{k}[{i}]:', v[i])
+ """
elif args.action == "tasks":
index = 0
diff --git a/tortoise_tts/models/diffusion.py b/tortoise_tts/models/diffusion.py
index 2a41976..00b7737 100644
--- a/tortoise_tts/models/diffusion.py
+++ b/tortoise_tts/models/diffusion.py
@@ -1,14 +1,1263 @@
+import enum
import math
import random
+from tqdm import tqdm
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
+import numpy as np
+
from torch import autocast
from .arch_utils import normalization, AttentionBlock
+
+"""
+This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, whose own header reads:
+
+This code started out as a PyTorch port of Ho et al's diffusion models:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
+
+Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
+"""
+
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = torch.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = torch.where(
+ x < -0.999,
+ log_cdf_plus,
+ torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif schedule_name == "cosine":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
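+
+# e.g. get_named_beta_schedule("linear", 4000) reproduces the training schedule
+# that the evaluation diffuser in train.py (below) is built from.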
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = 'previous_x' # the model predicts x_{t-1}
+ START_X = 'start_x' # the model predicts x_0
+ EPSILON = 'epsilon' # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = 'learned'
+ FIXED_SMALL = 'fixed_small'
+ FIXED_LARGE = 'fixed_large'
+ LEARNED_RANGE = 'learned_range'
+
+
+class LossType(enum.Enum):
+ MSE = 'mse' # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = 'kl' # use the variational lower-bound
+ RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+
+ Ported directly from here, and then adapted over time to further experimentation.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
+ :param model_var_type: a ModelVarType determining how variance is output.
+ :param loss_type: a LossType determining the loss function to use.
+ :param rescale_timesteps: if True, pass floating point timesteps into the
+ model so that they are always scaled like in the
+ original paper (0 to 1000).
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
+ conditioning_free=False,
+ conditioning_free_k=1,
+ ramp_conditioning_free=True,
+ ):
+ self.model_mean_type = ModelMeanType(model_mean_type)
+ self.model_var_type = ModelVarType(model_var_type)
+ self.loss_type = LossType(loss_type)
+ self.rescale_timesteps = rescale_timesteps
+ self.conditioning_free = conditioning_free
+ self.conditioning_free_k = conditioning_free_k
+ self.ramp_conditioning_free = ramp_conditioning_free
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # log calculation clipped because the posterior variance is 0 at the
+ # beginning of the diffusion chain.
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ )
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev)
+ * np.sqrt(alphas)
+ / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ )
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+
+ In other words, sample from q(x_t | x_0).
+
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = torch.randn_like(x_start)
+ assert noise.shape == x_start.shape
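+ # closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise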
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+
+ q(x_{t-1} | x_t, x_0)
+
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+ ):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+ if self.conditioning_free:
+ model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = torch.split(model_output, C, dim=1)
+ if self.conditioning_free:
+ model_output_no_conditioning, _ = torch.split(model_output_no_conditioning, C, dim=1)
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = torch.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape
+ )
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = torch.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ if self.conditioning_free:
+ if self.ramp_conditioning_free:
+ assert t.shape[0] == 1 # This should only be used in inference.
+ cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
+ else:
+ cfk = self.conditioning_free_k
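+ # classifier-free guidance: extrapolate away from the unconditional prediction,
+ # i.e. out = (1 + cfk) * cond - cfk * uncond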
+ model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+ )
+ model_mean = model_output
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t
+ )
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ assert (
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ )
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert x_t.shape == xprev.shape
+ return ( # (xprev - coef2*x_t) / coef1
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+ - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+ )
+ * x_t
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.float() * (1000.0 / self.num_timesteps)
+ return t
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
+ new_mean = (
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ )
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+
+ See condition_mean() for details on cond_fn.
+
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
+ x, self._scale_timesteps(t), **model_kwargs
+ )
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(
+ x_start=out["pred_xstart"], x_t=x, t=t
+ )
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = torch.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(
+ cond_fn, out, x, t, model_kwargs=model_kwargs
+ )
+ sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = torch.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ for i in tqdm(indices, disable=not progress):
+ t = torch.tensor([i] * shape[0], device=device)
+ with torch.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = torch.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
+ + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = (
+ out["pred_xstart"] * torch.sqrt(alpha_bar_next)
+ + torch.sqrt(1 - alpha_bar_next) * eps
+ )
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = torch.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Prefer the notebook-aware tqdm for the progress display; the redundant
+ # disable flag is dropped since this branch only runs when progress is set.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = torch.tensor([i] * shape[0], device=device)
+ with torch.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = torch.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = torch.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ # TODO: support multiple model outputs for this mode.
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
+ if isinstance(model_outputs, tuple):
+ model_output = model_outputs[0]
+ terms['extra_outputs'] = model_outputs[1:]
+ else:
+ model_output = model_outputs
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+ model_output, model_var_values = torch.split(model_output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ target = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0]
+ x_start_pred = torch.zeros_like(x_start) # Not supported.
+ elif self.model_mean_type == ModelMeanType.START_X:
+ target = x_start
+ x_start_pred = model_output
+ elif self.model_mean_type == ModelMeanType.EPSILON:
+ target = noise
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
+ else:
+ raise NotImplementedError(self.model_mean_type)
+ assert model_output.shape == target.shape == x_start.shape
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ terms["x_start_predicted"] = x_start_pred
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = torch.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+ terms = {}
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ assert False # not currently supported for this type of diffusion.
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
+ terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
+ model_output = terms[gd_out_key]
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C, 2, *x_t.shape[2:])
+ model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ target = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0]
+ x_start_pred = torch.zeros_like(x_start) # Not supported.
+ elif self.model_mean_type == ModelMeanType.START_X:
+ target = x_start
+ x_start_pred = model_output
+ elif self.model_mean_type == ModelMeanType.EPSILON:
+ target = noise
+ x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
+ else:
+ raise NotImplementedError(self.model_mean_type)
+ assert model_output.shape == target.shape == x_start.shape
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ terms["x_start_predicted"] = x_start_pred
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+
+ This term can't be optimized, as it only depends on the encoder.
+
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = torch.tensor([t] * batch_size, device=device)
+ noise = torch.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with torch.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = torch.stack(vb, dim=1)
+ xstart_mse = torch.stack(xstart_mse, dim=1)
+ mse = torch.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def autoregressive_training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model, autoregressive=False):
+ if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
+ return model
+ mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
+ return mod(
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
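+
+# For example:
+# space_timesteps(1000, [10]) -> 10 evenly strided steps out of 1000
+# space_timesteps(4000, [50]) -> the spacing used by the evaluation diffuser in train.py
+# space_timesteps(1000, "ddim50") -> the fixed striding from the DDIM paper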
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
+
+
+class _WrappedAutoregressiveModel:
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, x0, ts, **kwargs):
+ map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, x0, new_ts, **kwargs)
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
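+
+# e.g. _extract_into_tensor(self.alphas_cumprod, t, x.shape) picks alpha_bar[t]
+# for each batch element and expands it to x.shape for elementwise arithmetic.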
+
def is_latent(t):
return t.dtype == torch.float
diff --git a/tortoise_tts/train.py b/tortoise_tts/train.py
index c0c6bdb..7da8399 100755
--- a/tortoise_tts/train.py
+++ b/tortoise_tts/train.py
@@ -12,6 +12,7 @@ import json
import logging
import random
import torch
+import torchaudio
import torch.nn.functional as F
import traceback
import shutil
@@ -23,6 +24,9 @@ import argparse
from torch.nn.utils.rnn import pad_sequence
+from .models.arch_utils import denormalize_tacotron_mel
+from .models.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
+
_logger = logging.getLogger(__name__)
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
@@ -32,13 +36,18 @@ def train_feeder(engine, batch):
device = batch["text"][0].device
batch_size = len(batch["text"])
- conditioning_latents = pad_sequence([ latents[0] for latents in batch["latents"] ], batch_first = True)
- text_inputs = pad_sequence([ text for text in batch["text"] ], batch_first = True)
+ autoregressive_conds = torch.stack([ conds for conds in batch["conds_0"] ])
+ diffusion_conds = torch.stack([ conds for conds in batch["conds_1"] ])
+
+ autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
+ diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])
+
+ text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
- mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True)
+ mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = engine.module.stop_mel_token )
wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
- engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths)
+ engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
@@ -78,11 +87,9 @@ def run_eval(engines, eval_name, dl):
ref_path.parent.mkdir(parents=True, exist_ok=True)
prom_path.parent.mkdir(parents=True, exist_ok=True)
- """
- ref_audio, sr = qnt.decode_to_file(ref, ref_path)
- hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
- prom_audio, sr = qnt.decode_to_file(prom, prom_path)
- """
+ ref_audio, sr = emb.decode_to_file(ref, ref_path)
+ hyp_audio, sr = emb.decode_to_file(hyp, hyp_path)
+ prom_audio, sr = emb.decode_to_file(prom, prom_path)
# pseudo loss calculation since we don't get the logits during eval
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
@@ -90,17 +97,119 @@ def run_eval(engines, eval_name, dl):
hyp_audio = hyp_audio[..., 0:min_length]
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
+ autoregressive = None
+ diffusion = None
+ clvp = None
+ vocoder = None
+
+ for name in engines:
+ engine = engines[name]
+ if "autoregressive" in name:
+ autoregressive = engine.module
+ elif "diffusion" in name:
+ diffusion = engine.module
+ elif "clvp" in name:
+ clvp = engine.module
+ elif "vocoder" in name:
+ vocoder = engine.module
+
+ trained_diffusion_steps = 4000
+ desired_diffusion_steps = 50
+ cond_free = False
+ cond_free_k = 1
+ diffuser = SpacedDiffusion(
+ use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
+ model_mean_type='epsilon',
+ model_var_type='learned_range',
+ loss_type='mse',
+ betas=get_named_beta_schedule('linear', trained_diffusion_steps),
+ conditioning_free=cond_free,
+ conditioning_free_k=cond_free_k
+ )
+
processed = 0
+ temperature = 1.0
while processed < cfg.evaluation.size:
batch: dict = to_device(next(iter(dl)), cfg.device)
processed += len(batch["text"])
- for name in engines:
- engine = engines[name]
+ max_mel_tokens = 500
+ stop_mel_token = autoregressive.stop_mel_token
+ calm_token = 83
+ verbose = True
- ...
+ with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+ autoregressive_conds = torch.stack([ conds for conds in batch["conds_0"] ])
+ diffusion_conds = torch.stack([ conds for conds in batch["conds_1"] ])
- process( name, batch, resps_list )
+ autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
+ diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])
+
+ text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
+ text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
+ mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
+ wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
+
+ # autoregressive pass
+ if True: # toggle: sample mel codes from the AR model; set to False to reuse the ground-truth codes below
+ codes = autoregressive.inference_speech(
+ autoregressive_latents,
+ text_tokens,
+ do_sample=True,
+ #top_p=top_p,
+ temperature=temperature,
+ num_return_sequences=1,
+ #length_penalty=length_penalty,
+ #repetition_penalty=repetition_penalty,
+ max_generate_length=max_mel_tokens,
+ )
+ padding_needed = max_mel_tokens - codes.shape[1]
+ codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+ else:
+ codes = mel_codes
+
+ latents = autoregressive.forward(
+ autoregressive_latents,
+ text_tokens,
+ text_lengths,
+ codes,
+ wav_lengths,
+ return_latent=True,
+ clip_inputs=False
+ )
+
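+ # trim the latents once a run of "calm" tokens (token 83 codes silence in tortoise)
+ # is seen, so the diffusion decoder isn't fed a long trailing silence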
+ calm_tokens = 0
+ for k in range( codes.shape[-1] ):
+ if codes[0, k] == calm_token:
+ calm_tokens += 1
+ else:
+ calm_tokens = 0
+ if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+ latents = latents[:, :k]
+ break
+
+ # diffusion pass
+ output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+ output_shape = (latents.shape[0], 100, output_seq_len)
+ precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)
+
+ noise = torch.randn(output_shape, device=latents.device) * temperature
+ mel = diffuser.p_sample_loop(
+ diffusion,
+ output_shape,
+ noise=noise,
+ model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
+ progress=verbose
+ )
+ mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
+
+ # vocoder pass
+ wavs = vocoder.inference(mels)
+
+ for i, wav in enumerate( wavs ):
+ torchaudio.save( f"./data/{cfg.start_time}[{i}].wav", wav.cpu(), 24_000 )
+
+ # process( name, batch, resps_list )
stats = {k: sum(v) / len(v) for k, v in stats.items()}