diff --git a/tortoise/api_fast.py b/tortoise/api_fast.py new file mode 100644 index 0000000..26aeaec --- /dev/null +++ b/tortoise/api_fast.py @@ -0,0 +1,740 @@ +import os +import random +import uuid +from time import time +from urllib import request + +import torch +import torch.nn.functional as F +import progressbar +import torchaudio +import numpy as np +from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead +from tortoise.models.diffusion_decoder import DiffusionTts +from tortoise.models.autoregressive import UnifiedVoice +from tqdm import tqdm +from tortoise.models.arch_util import TorchMelSpectrogram +from tortoise.models.clvp import CLVP +from tortoise.models.cvvp import CVVP +from tortoise.models.hifigan_decoder import HifiganGenerator +from tortoise.models.random_latent_generator import RandomLatentConverter +from tortoise.models.vocoder import UnivNetGenerator +from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel +from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule +from tortoise.utils.tokenizer import VoiceBpeTokenizer +from tortoise.utils.wav2vec_alignment import Wav2VecAlignment +from contextlib import contextmanager +# from tortoise.models.stream_generator import init_stream_support +from huggingface_hub import hf_hub_download + +from tortoise.utils.device import get_device, get_device_name, get_device_batch_size, print_stats, do_gc + +pbar = None +# init_stream_support() +STOP_SIGNAL = False +DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models') +MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR) + +MODELS = { + 'autoregressive.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/autoregressive.pth', + 'classifier.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/classifier.pth', + 'rlg_auto.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/rlg_auto.pth', + 'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth', +} + +def download_models(specific_models=None): + """ + Call to download all the models that Tortoise uses. + """ + os.makedirs(MODELS_DIR, exist_ok=True) + + def show_progress(block_num, block_size, total_size): + global pbar + if pbar is None: + pbar = progressbar.ProgressBar(maxval=total_size) + pbar.start() + + downloaded = block_num * block_size + if downloaded < total_size: + pbar.update(downloaded) + else: + pbar.finish() + pbar = None + + for model_name, url in MODELS.items(): + if specific_models is not None and model_name not in specific_models: + continue + model_path = os.path.join(MODELS_DIR, model_name) + if os.path.exists(model_path): + continue + print(f'Downloading {model_name} from {url}...') + request.urlretrieve(url, model_path, show_progress) + print('Done.') + +def get_model_path(model_name, models_dir=MODELS_DIR): + """ + Get path to given model, download it if it doesn't exist. 
+ """ + if model_name not in MODELS: + raise ValueError(f'Model {model_name} not found in available models.') + model_path = os.path.join(models_dir, model_name) + if not os.path.exists(model_path) and models_dir == MODELS_DIR: + download_models([model_name]) + # Add the logic to download models if not available + # model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir) + return model_path + +def check_for_kill_signal(): + global STOP_SIGNAL + if STOP_SIGNAL: + STOP_SIGNAL = False + raise Exception("Kill signal detected") + +def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ + if t.shape[-1] == length: + return t + elif t.shape[-1] < length: + return F.pad(t, (0, length-t.shape[-1])) + else: + return t[..., :length] + + +def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1): + """ + Helper function to load a GaussianDiffusion instance configured for use as a vocoder. + """ + return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps), + conditioning_free=cond_free, conditioning_free_k=cond_free_k) + + +def format_conditioning(clip, cond_length=132300, device="cuda" if not torch.backends.mps.is_available() else 'mps'): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. + """ + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start:rand_start + cond_length] + mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).to(device) + + +def fix_autoregressive_output(codes, stop_token, complain=True): + """ + This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was + trained on and what the autoregressive code generator creates (which has no padding or end). + This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with + a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE + and copying out the last few codes. + + Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar. + """ + # Strip off the autoregressive stop token and add padding. + stop_token_indices = (codes == stop_token).nonzero() + if len(stop_token_indices) == 0: + if complain: + print("No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " + "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " + "try breaking up your input text.") + return codes + else: + codes[stop_token_indices] = 83 + stm = stop_token_indices.min().item() + codes[stm:] = 83 + if stm - 3 < codes.shape[0]: + codes[-3] = 45 + codes[-2] = 45 + codes[-1] = 248 + + return codes + + +def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True): + """ + Uses the specified diffusion model to convert discrete codes into a spectrogram. 
+ """ + with torch.no_grad(): + output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False) + + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise, + model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, + progress=verbose) + return denormalize_tacotron_mel(mel)[:,:,:output_seq_len] + + +def classify_audio_clip(clip): + """ + Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. + :param clip: torch tensor containing audio waveform data (get it from load_audio) + :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. + """ + classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, + resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, + dropout=0, kernel_size=5, distribute_zero_label=False) + classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu'))) + clip = clip.cpu().unsqueeze(0) + results = F.softmax(classifier(clip), dim=-1) + return results[0][0] + + +def pick_best_batch_size_for_gpu(): + """ + Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give + you a good shot. + """ + if torch.cuda.is_available(): + _, available = torch.cuda.mem_get_info() + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 + if torch.backends.mps.is_available(): + import psutil + available = psutil.virtual_memory().total + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 + return 1 + +# Taken from MRQ's api +@torch.inference_mode() +def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=22050): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. 
+ """ + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start:rand_start + cond_length] + mel_clip = TorchMelSpectrogram(sampling_rate=sampling_rate)(clip.unsqueeze(0)).squeeze(0) + mel_clip = mel_clip.unsqueeze(0) + return migrate_to_device(mel_clip, device) + +# Taken from MRQ's api +def hash_file(path, algo="md5", buffer_size=0): + import hashlib + + hash = None + if algo == "md5": + hash = hashlib.md5() + elif algo == "sha1": + hash = hashlib.sha1() + else: + raise Exception(f'Unknown hash algorithm specified: {algo}') + + if not os.path.exists(path): + raise Exception(f'Path not found: {path}') + + with open(path, 'rb') as f: + if buffer_size > 0: + while True: + data = f.read(buffer_size) + if not data: + break + hash.update(data) + else: + hash.update(f.read()) + + return "{0}".format(hash.hexdigest()) + +# Taken from MRQ's api +def migrate_to_device( t, device ): + if t is None: + return t + + if not hasattr(t, 'device'): + t.device = device + t.manually_track_device = True + elif t.device == device: + return t + + if hasattr(t, 'manually_track_device') and t.manually_track_device: + t.device = device + + t = t.to(device) + + do_gc() + + return t + +class TextToSpeech: + """ + Main entry point into Tortoise. + """ + + def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, + enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None, + tokenizer_vocab_file=None, tokenizer_basic=False, + autoregressive_model_path=None, tokenizer_json=None, + minor_optimizations=True, + input_sample_rate=22050, output_sample_rate=24000, + ): + + """ + Constructor + :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing + GPU OOM errors. Larger numbers generates slightly faster. + :param models_dir: Where model weights are stored. This should only be specified if you are providing your own + models, otherwise use the defaults. + :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output + (but are still rendered by the model). This can be used for prompt engineering. + Default is true. + :param device: Device to use when running the model. If omitted, the device will be automatically chosen. + """ + self.use_deepspeed = use_deepspeed # Store deepspeed + self.use_kv_cache = kv_cache # Store KV cache + self.preloaded_tensors = minor_optimizations + self.input_sample_rate = input_sample_rate + self.output_sample_rate = output_sample_rate + + self.models_dir = models_dir + self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size + self.enable_redaction = enable_redaction + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + + self.load_tokenizer_json(tokenizer_json) + + self.half = half + if os.path.exists(f'{models_dir}/autoregressive.ptt'): + # Assume this is a traced directory. 
+ self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt') + else: + if not autoregressive_model_path or not os.path.exists(autoregressive_model_path): + autoregressive_model_path = get_model_path('autoregressive.pth', models_dir) + + self.load_autoregressive_model(autoregressive_model_path) + + # self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, + # model_dim=1024, + # heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, + # train_solo_embeddings=False).to(self.device).eval() + # self.autoregressive.load_state_dict(torch.load(autoregressive_model_path, weights_only=True), strict=False) + # self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half) + # self.autoregressive = migrate_to_device(self.autoregressive, self.device) + # print(f"Loaded autoregressive model") + + self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1", + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11], + upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2], + cond_channels=1024).to(self.device).eval() + hifi_model = torch.load(get_model_path('hifidecoder.pth')) + self.hifi_decoder.load_state_dict(hifi_model, strict=False) + self.hifi_decoder.to(self.device) + # Random latent generators (RLGs) are loaded lazily. + self.rlg_auto = None + + # Taken from MRQ's api.py + def load_autoregressive_model(self, autoregressive_model_path, is_xtts=False): + if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path): + return + + self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir) + new_hash = hash_file(self.autoregressive_model_path) + + if hasattr(self,"autoregressive_model_hash") and self.autoregressive_model_hash == new_hash: + return + + self.autoregressive_model_hash = new_hash + + self.loading = True + print(f"Loading autoregressive model: {self.autoregressive_model_path}") + + if hasattr(self, 'autoregressive'): + del self.autoregressive + + # XTTS requires a different "dimensionality" for its autoregressive model + if new_hash == "e4ce21eae0043f7691d6a6c8540b74b8" or is_xtts: + dimensionality = { + "max_mel_tokens": 605, + "max_text_tokens": 402, + "max_prompt_tokens": 70, + "max_conditioning_inputs": 1, + "layers": 30, + "model_dim": 1024, + "heads": 16, + "number_text_tokens": 5023, # -1 + "start_text_token": 261, + "stop_text_token": 0, + "number_mel_codes": 8194, + "start_mel_token": 8192, + "stop_mel_token": 8193, + } + else: + dimensionality = { + "max_mel_tokens": 604, + "max_text_tokens": 402, + "max_conditioning_inputs": 2, + "layers": 30, + "model_dim": 1024, + "heads": 16, + "number_text_tokens": 255, + "start_text_token": 255, + "checkpointing": False, + "train_solo_embeddings": False + } + + self.autoregressive = UnifiedVoice(**dimensionality).cpu().eval() + self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path)) + self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache) + if self.preloaded_tensors: + self.autoregressive = migrate_to_device( self.autoregressive, self.device ) + + self.loading = False + print(f"Loaded autoregressive model") + + # Taken 
from MRQ's modified api.py + def load_tokenizer_json(self, tokenizer_json): + if hasattr(self,"tokenizer_json") and os.path.samefile(self.tokenizer_json, tokenizer_json): + return + + self.loading = True + self.tokenizer_json = tokenizer_json if tokenizer_json else os.path.join(os.path.dirname(os.path.realpath(__file__)), '../tortoise/data/tokenizer.json') + print("Loading tokenizer JSON:", self.tokenizer_json) + + if hasattr(self, 'tokenizer'): + del self.tokenizer + + self.tokenizer = VoiceBpeTokenizer(vocab_file=self.tokenizer_json) + self.loading = False + print(f"Loaded tokenizer") + + def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False, original_ar=False, original_diffusion=False): + """ + Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). + These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic + properties. + :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data. + """ + + with torch.no_grad(): + # computing conditional latents requires being done on the CPU if using DML because M$ still hasn't implemented some core functions + if get_device_name() == "dml": + force_cpu = True + device = torch.device('cpu') if force_cpu else self.device + + if not isinstance(voice_samples, list): + voice_samples = [voice_samples] + + resampler_22K = torchaudio.transforms.Resample( + self.input_sample_rate, + 22050, + lowpass_filter_width=16, + rolloff=0.85, + resampling_method="kaiser_window", + beta=8.555504641634386, + ).to(device) + + resampler_24K = torchaudio.transforms.Resample( + self.input_sample_rate, + 24000, + lowpass_filter_width=16, + rolloff=0.85, + resampling_method="kaiser_window", + beta=8.555504641634386, + ).to(device) + + voice_samples = [migrate_to_device(v, device) for v in voice_samples] + + auto_conds = [] + diffusion_conds = [] + + if original_ar: + samples = [resampler_22K(sample) for sample in voice_samples] + for sample in tqdm(samples, desc="Computing AR conditioning latents..."): + auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate, cond_length=132300)) + else: + samples = [resampler_22K(sample) for sample in voice_samples] + concat = torch.cat(samples, dim=-1) + chunk_size = concat.shape[-1] + + if slices == 0: + slices = 1 + elif max_chunk_size is not None and chunk_size > max_chunk_size: + slices = 1 + while int(chunk_size / slices) > max_chunk_size: + slices = slices + 1 + + chunks = torch.chunk(concat, slices, dim=1) + chunk_size = chunks[0].shape[-1] + + for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."): + auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size)) + + auto_conds = torch.stack(auto_conds, dim=1) + self.autoregressive = migrate_to_device( self.autoregressive, device ) + auto_latent = self.autoregressive.get_conditioning(auto_conds) + self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' ) + + if return_mels: + return auto_latent, auto_conds, diffusion_conds + else: + return auto_latent + + def get_random_conditioning_latents(self): + # Lazy-load the RLG models. 
+ if self.rlg_auto is None: + self.rlg_auto = RandomLatentConverter(1024).eval() + self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu'))) + with torch.no_grad(): + return self.rlg_auto(torch.tensor([0.0])) + + + # taken from here https://github.com/coqui-ai/TTS/blob/d21f15cc850788f9cdf93dac0321395138665287/TTS/tts/models/xtts.py#L666 + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + + + def tts_stream(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None, + return_deterministic_state=False, overlap_wav_len=1024, stream_chunk_size=40, + # autoregressive generation parameters follow + num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, + # CVVP parameters follow + cvvp_amount=.0, + # diffusion generation parameters follow + diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, + **hf_generate_kwargs): + """ + Produces an audio clip of the given text being spoken with the given reference voice. + :param text: Text to be spoken. + :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data. + :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which + can be provided in lieu of voice_samples. This is ignored unless voice_samples=None. + Conditioning latents can be retrieved via get_conditioning_latents(). + :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true. + ~~AUTOREGRESSIVE KNOBS~~ + :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + :param temperature: The softmax temperature of the autoregressive model. + :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence + of long silences or "uhhhhhhs", etc. + :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. + ~~DIFFUSION KNOBS~~ + :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. 
More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. + :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output + of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and + dramatically improves realism. + :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k + :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + ~~OTHER STUFF~~ + :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' 
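+        # Conditioning is resolved in priority order: reference clips are converted to
+        # autoregressive conditioning latents via get_conditioning_latents(); otherwise a
+        # caller-supplied latent tuple is used as-is (2 elements) or unpacked to its
+        # autoregressive latent (3 elements); failing both, a random latent is drawn,
+        # which yields an arbitrary synthetic voice. Generation below is streamed: GPT
+        # latents accumulate from the generator and are flushed through the HiFi-GAN
+        # decoder once stream_chunk_size codes are buffered, with handle_chunks()
+        # crossfading consecutive chunks over overlap_wav_len samples.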
+ if voice_samples is not None: + auto_conditioning = self.get_conditioning_latents(voice_samples, return_mels=False) + elif conditioning_latents is not None: + latent_tuple = conditioning_latents + if len(latent_tuple) == 2: + auto_conditioning = conditioning_latents + else: + auto_conditioning, auto_conds, _ = conditioning_latents + else: + auto_conditioning = self.get_random_conditioning_latents() + + auto_conditioning = migrate_to_device( auto_conditioning, self.device ) + + + with torch.no_grad(): + calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + if verbose: + print("Generating autoregressive samples..") + with torch.autocast( + device_type="cuda" , dtype=torch.float16, enabled=self.half + ): + fake_inputs = self.autoregressive.compute_embeddings( + auto_conditioning, + text_tokens, + ) + gpt_generator = self.autoregressive.get_generator( + fake_inputs=fake_inputs, + top_k=50, + top_p=top_p, + temperature=temperature, + do_sample=True, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + all_latents = [] + codes_ = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + first_buffer = 60 + while not is_end: + try: + with torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=self.half + ): + codes, latent = next(gpt_generator) + all_latents += [latent] + codes_ += [codes] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)): + first_buffer = 0 + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning) + wav_gen = wav_gen.squeeze() + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + codes_ = [] + yield wav_chunk + + def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None, + # autoregressive generation parameters follow + num_autoregressive_samples=512, temperature=.8, length_penalty=6, repetition_penalty=8.0, + top_p=.8, max_mel_tokens=500, + # CVVP parameters follow + cvvp_amount=.0, + **hf_generate_kwargs): + """ + Produces an audio clip of the given text being spoken with the given reference voice. + :param text: Text to be spoken. + :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data. + :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which + can be provided in lieu of voice_samples. This is ignored unless voice_samples=None. + Conditioning latents can be retrieved via get_conditioning_latents(). + :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true. + ~~AUTOREGRESSIVE KNOBS~~ + :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + :param temperature: The softmax temperature of the autoregressive model. 
+ :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence + of long silences or "uhhhhhhs", etc. + :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. + ~~DIFFUSION KNOBS~~ + :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. + :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output + of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and + dramatically improves realism. + :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k + :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + ~~OTHER STUFF~~ + :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + + assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' 
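+        # From here the fast path proceeds in three steps: resolve conditioning latents
+        # (reference clips, a supplied latent tuple, or a random latent), sample mel codes
+        # with inference_speech(), then re-run the autoregressive model with
+        # return_latent=True and hand the resulting GPT latents to the HiFi-GAN decoder,
+        # which renders a 24kHz waveform directly -- no CLVP re-ranking and no diffusion
+        # vocoder are used on this path.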
+ if voice_samples is not None: + auto_conditioning = self.get_conditioning_latents(voice_samples, return_mels=False) + elif conditioning_latents is not None: + auto_conditioning = conditioning_latents + else: + auto_conditioning = self.get_random_conditioning_latents() + + auto_conditioning = migrate_to_device(auto_conditioning, self.device) + + with torch.no_grad(): + calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + if verbose: + print("Generating autoregressive samples..") + with torch.autocast( + device_type="cuda" , dtype=torch.float16, enabled=self.half + ): + # print("Autoregressive model device:", next(self.autoregressive.parameters()).device) + # print("Hifi Decoder model device:", next(self.hifi_decoder.parameters()).device) + + codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, + top_k=50, + top_p=top_p, + temperature=temperature, + do_sample=True, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs) + gpt_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + torch.tensor([codes.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), + return_latent=True, clip_inputs=False) + if verbose: + print("generating audio..") + wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning) + return wav_gen + def deterministic_state(self, seed=None): + """ + Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be + reproduced. + """ + seed = int(time()) if seed is None else seed + torch.manual_seed(seed) + random.seed(seed) + # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary. 
+ # torch.use_deterministic_algorithms(True) + + return seed diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index 7d63e4a..fcd1a94 100755 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -9,9 +9,6 @@ from transformers.utils.model_parallel_utils import get_device_map, assert_devic from tortoise.models.arch_util import AttentionBlock from tortoise.utils.typical_sampling import TypicalLogitsWarper -from tortoise.utils.device import get_device_count - -import tortoise.utils.torch_intermediary as ml def null_position_embeddings(range, dim): return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) @@ -36,23 +33,22 @@ class ResBlock(nn.Module): class GPT2InferenceModel(GPT2PreTrainedModel): - def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache): + def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False): super().__init__(config) self.transformer = gpt self.text_pos_embedding = text_pos_emb self.embeddings = embeddings + self.final_norm = norm self.lm_head = nn.Sequential(norm, linear) - self.kv_cache = kv_cache - + # Model parallel self.model_parallel = False self.device_map = None self.cached_mel_emb = None - def parallelize(self, device_map=None): self.device_map = ( - get_device_map(len(self.transformer.h), range(get_device_count())) + get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count()))) if device_map is None else device_map ) @@ -67,22 +63,24 @@ class GPT2InferenceModel(GPT2PreTrainedModel): self.lm_head = self.lm_head.to("cpu") self.model_parallel = False torch.cuda.empty_cache() - + if torch.backends.mps.is_available(): + torch.mps.empty_cache() + def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - + def store_mel_emb(self, mel_emb): self.cached_mel_emb = mel_emb - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - - token_type_ids = kwargs.get("token_type_ids", None) - if not self.kv_cache: past = None + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None # only last token for inputs_ids if past is defined in kwargs - if past: + if past_key_values: input_ids = input_ids[:, -1].unsqueeze(-1) if token_type_ids is not None: token_type_ids = token_type_ids[:, -1].unsqueeze(-1) @@ -94,13 +92,13 @@ class GPT2InferenceModel(GPT2PreTrainedModel): # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - if past: + if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) else: position_ids = None return { "input_ids": input_ids, - "past_key_values": past, + "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, @@ -127,7 +125,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel): assert self.cached_mel_emb is not None assert inputs_embeds is None # Not supported by this inference model. assert labels is None # Training not supported by this inference model. 
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # Create embedding mel_len = self.cached_mel_emb.shape[1] @@ -136,14 +136,17 @@ class GPT2InferenceModel(GPT2PreTrainedModel): text_emb = self.embeddings(text_inputs) text_emb = text_emb + self.text_pos_embedding(text_emb) if self.cached_mel_emb.shape[0] != text_emb.shape[0]: - mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0) - else: + mel_emb = self.cached_mel_emb.repeat_interleave( + text_emb.shape[0] // self.cached_mel_emb.shape[0], 0 + ) + else: # this outcome only occurs once per loop in most cases mel_emb = self.cached_mel_emb emb = torch.cat([mel_emb, text_emb], dim=1) else: emb = self.embeddings(input_ids) - emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device) - + emb = emb + self.text_pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - mel_len, attention_mask.device + ) transformer_outputs = self.transformer( inputs_embeds=emb, past_key_values=past_key_values, @@ -162,7 +165,10 @@ class GPT2InferenceModel(GPT2PreTrainedModel): # Set device for model parallelism if self.model_parallel: - torch.cuda.set_device(self.transformer.first_device) + if torch.backends.mps.is_available(): + self.to(self.transformer.first_device) + else: + torch.cuda.set_device(self.transformer.first_device) hidden_states = hidden_states.to(self.lm_head.weight.device) lm_logits = self.lm_head(hidden_states) @@ -187,7 +193,10 @@ class GPT2InferenceModel(GPT2PreTrainedModel): called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
""" return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) for layer_past in past ) @@ -222,8 +231,7 @@ class ConditioningEncoder(nn.Module): class LearnedPositionEmbeddings(nn.Module): def __init__(self, seq_len, model_dim, init=.02): super().__init__() - # ml.Embedding - self.emb = ml.Embedding(seq_len, model_dim) + self.emb = nn.Embedding(seq_len, model_dim) # Initializing this way is standard for GPT-2 self.emb.weight.data.normal_(mean=0.0, std=init) @@ -232,7 +240,7 @@ class LearnedPositionEmbeddings(nn.Module): return self.emb(torch.arange(0, sl, device=x.device)) def get_fixed_embedding(self, ind, dev): - return self.emb(torch.arange(0, ind, device=dev))[ind-1:ind] + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): @@ -283,9 +291,9 @@ class MelEncoder(nn.Module): class UnifiedVoice(nn.Module): - def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_prompt_tokens=2, max_mel_tokens=250, max_conditioning_inputs=1, + def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, mel_length_compression=1024, number_text_tokens=256, - start_text_token=None, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192, + start_text_token=None, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, checkpointing=True, types=1): """ @@ -295,7 +303,6 @@ class UnifiedVoice(nn.Module): heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 max_text_tokens: Maximum number of text tokens that will be encountered by model. max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. - max_prompt_tokens: compat set to 2, 70 for XTTS max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. 
number_text_tokens: @@ -312,7 +319,7 @@ class UnifiedVoice(nn.Module): self.number_text_tokens = number_text_tokens self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token - self.stop_text_token = stop_text_token + self.stop_text_token = 0 self.number_mel_codes = number_mel_codes self.start_mel_token = start_mel_token self.stop_mel_token = stop_mel_token @@ -320,16 +327,13 @@ class UnifiedVoice(nn.Module): self.heads = heads self.max_mel_tokens = max_mel_tokens self.max_text_tokens = max_text_tokens - self.max_prompt_tokens = max_prompt_tokens self.model_dim = model_dim self.max_conditioning_inputs = max_conditioning_inputs self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) - # ml.Embedding - self.text_embedding = ml.Embedding(self.number_text_tokens*types+1, model_dim) + self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim) if use_mel_codes_as_input: - # ml.Embedding - self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) + self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) else: self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ @@ -342,10 +346,8 @@ class UnifiedVoice(nn.Module): self.text_solo_embedding = 0 self.final_norm = nn.LayerNorm(model_dim) - # nn.Linear - self.text_head = ml.Linear(model_dim, self.number_text_tokens*types+1) - # nn.Linear - self.mel_head = ml.Linear(model_dim, self.number_mel_codes) + self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) + self.mel_head = nn.Linear(model_dim, self.number_mel_codes) # Initialize the embeddings per the GPT-2 scheme embeddings = [self.text_embedding] @@ -353,20 +355,35 @@ class UnifiedVoice(nn.Module): embeddings.append(self.mel_embedding) for module in embeddings: module.weight.data.normal_(mean=0.0, std=.02) - - def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False): - seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens - gpt_config = GPT2Config(vocab_size=self.max_mel_tokens, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=self.model_dim, - n_layer=self.layers, - n_head=self.heads, - gradient_checkpointing=False, - use_cache=True) - self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache) - #print(f'use_deepspeed autoregressive_debug {use_deepspeed}') - if use_deepspeed and torch.cuda.is_available(): + def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False): + seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + if use_deepspeed and half and torch.cuda.is_available(): + import deepspeed + self.ds_engine = deepspeed.init_inference(model=self.inference_model, + mp_size=1, + replace_with_kernel_inject=True, + dtype=torch.float16) + self.inference_model = self.ds_engine.module.eval() + elif use_deepspeed 
and torch.cuda.is_available(): import deepspeed self.ds_engine = deepspeed.init_inference(model=self.inference_model, mp_size=1, @@ -375,9 +392,9 @@ class UnifiedVoice(nn.Module): self.inference_model = self.ds_engine.module.eval() else: self.inference_model = self.inference_model.eval() - - self.gpt.wte = self.mel_embedding + # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.gpt.wte = self.mel_embedding def build_aligned_inputs_and_targets(self, input, start_token, stop_token): inp = F.pad(input, (1,0), value=start_token) tar = F.pad(input, (0,1), value=stop_token) @@ -493,16 +510,33 @@ class UnifiedVoice(nn.Module): loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits - + def compute_embeddings( + self, + cond_latents, + text_inputs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + conds = cond_latents.unsqueeze(1) + emb = torch.cat([conds, emb], dim=1) + self.inference_model.store_mel_emb(emb) + gpt_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_mel_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + gpt_inputs[:, -1] = self.start_mel_token + return gpt_inputs def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, - max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): - seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens - if not hasattr(self, 'inference_model'): - self.post_init_gpt2_config(kv_cache=self.kv_cache) - + max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) + text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) conds = speech_conditioning_latent.unsqueeze(1) @@ -528,7 +562,16 @@ class UnifiedVoice(nn.Module): num_return_sequences=num_return_sequences, **hf_generate_kwargs) return gen[:, trunc_index:] - + def get_generator(self, fake_inputs, **hf_generate_kwargs): + return self.inference_model.generate_stream( + fake_inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + max_length=500, + do_stream=True, + **hf_generate_kwargs, + ) if __name__ == '__main__': gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4) l = gpt(torch.randn(2, 3, 80, 800), @@ -536,4 +579,4 @@ if __name__ == '__main__': torch.tensor([32, 120]), torch.randint(high=8192, size=(2,250)), torch.tensor([250*256,195*256])) - gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80])) \ No newline at end of file + gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80])) diff --git a/tortoise/models/hifigan_decoder.py 
b/tortoise/models/hifigan_decoder.py new file mode 100644 index 0000000..ae2f627 --- /dev/null +++ b/tortoise/models/hifigan_decoder.py @@ -0,0 +1,303 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +LRELU_SLOPE = 0.1 + + +def get_padding(k, d): + return int((k * d - d) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. 
+ """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. 
+ """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_weight_norm(self.conv_pre) + + if not conv_post_weight_norm: + remove_weight_norm(self.conv_post) + + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c, g=None): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. 
+ + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + # c = c.to(self.conv_pre.weight.device) + # c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + up_1 = torch.nn.functional.interpolate( + c.transpose(1,2), + scale_factor=[1024 / 256], + mode="linear", + ) + up_2 = torch.nn.functional.interpolate( + up_1, + scale_factor=[24000 / 22050], + mode="linear", + ) + g = g.unsqueeze(0) + return self.forward(up_2.to(self.device), g.transpose(1,2)) + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/tortoise/models/stream_generator.py b/tortoise/models/stream_generator.py new file mode 100644 index 0000000..a8dd07b --- /dev/null +++ b/tortoise/models/stream_generator.py @@ -0,0 +1,1057 @@ +# Adapted from: https://github.com/LowinLi/transformers-stream-generator + +from transformers import ( + GenerationConfig, + GenerationMixin, + LogitsProcessorList, + StoppingCriteriaList, + DisjunctiveConstraint, + BeamSearchScorer, + PhrasalConstraint, + ConstrainedBeamSearchScorer, + PreTrainedModel, +) +import numpy as np +import random +import warnings +import inspect +from transformers.generation.utils import GenerateOutput, SampleOutput, logger +import torch +from typing import Callable, List, Optional, Union +from torch import nn +import torch.distributed as dist +import copy + + +def setup_seed(seed): + if seed == -1: + return + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +class StreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.do_stream = kwargs.pop("do_stream", False) + + +class NewGenerationMixin(GenerationMixin): + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[StreamGenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = False, + seed=0, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. 
+ generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + setup_seed(seed) + # 1. 
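Of the parameters documented above, `prefix_allowed_tokens_fn` has the least obvious shape. A minimal sketch of one; the token ids and toy vocabulary size are invented for illustration:

```python
# Sketch of a prefix_allowed_tokens_fn as described in the docstring above.
from typing import List

import torch

VOCAB_SIZE = 100          # toy vocabulary
WHITELIST = [5, 17, 42]   # invented ids

def prefix_allowed_tokens_fn(batch_id: int, input_ids: torch.Tensor) -> List[int]:
    # input_ids is the 1-D sequence generated so far for this batch element.
    # Example policy: after emitting token 42, only the whitelist may follow.
    if int(input_ids[-1]) == 42:
        return WHITELIST
    return list(range(VOCAB_SIZE))
```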
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = StreamGenerationConfig.from_model_config( + self.config + ) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update( + **kwargs + ) # All unused kwargs must be model kwargs + # self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = eos_token_id + + # 3. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. 
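Step 2's pad-token fallback is worth seeing in isolation, since it silently changes the configuration the caller passed in. `GenerationConfig` is the real transformers class; the token ids are invented:

```python
# Sketch of step 2's fallback: with no pad token but an EOS token defined,
# open-ended generation pads finished sequences with EOS.
from transformers import GenerationConfig

cfg = GenerationConfig(eos_token_id=[50256, 50257], pad_token_id=None)

if cfg.pad_token_id is None and cfg.eos_token_id is not None:
    eos = cfg.eos_token_id
    if isinstance(eos, list):
        eos = eos[0]          # a list of EOS ids falls back to its first entry
    cfg.pad_token_id = eos

print(cfg.pad_token_id)       # 50256
```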
Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if ( + model_kwargs.get("attention_mask", None) is None + and requires_attention_mask + and accepts_attention_mask + ): + model_kwargs[ + "attention_mask" + ] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, + ) + + # decoder-only models should use left-padding for generation + if not self.config.is_encoder_decoder: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length + ) + elif ( + not has_default_max_length and generation_config.max_new_tokens is not None + ): + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. 
" + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + + if ( + generation_config.min_length is not None + and generation_config.min_length > generation_config.max_length + ): + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ( + "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + ) + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 7. determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None + or generation_config.force_words_ids is not None + ) + + is_contrastive_search_gen_mode = ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ) + + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and generation_config.do_stream is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_stream_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_stream is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_sample_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + + if generation_config.num_beam_groups > generation_config.num_beams: + raise ValueError( + "`num_beam_groups` has to be smaller or equal to `num_beams`" + ) + if is_group_beam_gen_mode and generation_config.do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + # 8. 
prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + # 10. go into different generation modes + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." + ) + + # 11. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_contrastive_search_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_stream_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. 
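The `_expand_inputs_for_generation` step in the sampling branches above amounts to repeating every prompt along the batch dimension so that each one is sampled `num_return_sequences` times. With invented token ids:

```python
# Sketch of what expanding input_ids by num_return_sequences does; the real
# _expand_inputs_for_generation also expands attention masks and encoder
# outputs, which are omitted here.
import torch

input_ids = torch.tensor([[101, 7592], [101, 2088]])   # two prompts
num_return_sequences = 3

expanded = input_ids.repeat_interleave(num_return_sequences, dim=0)
print(expanded.shape)   # torch.Size([6, 2]): each prompt now appears 3 times in a row
```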
run sample + return self.sample_stream( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams + * generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if generation_config.num_beams % generation_config.num_beam_groups != 0: + raise ValueError( + "`num_beams` should be divisible by `num_beam_groups` for group beam search." 
+ ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + has_default_typical_p = ( + kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + ) + if not has_default_typical_p: + raise ValueError( + "Decoder argument `typical_p` is not supported with beam groups." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + if generation_config.num_beams <= 1: + raise ValueError( + "`num_beams` needs to be greater than 1 for constrained generation." + ) + + if generation_config.do_sample: + raise ValueError( + "`do_sample` needs to be false for constrained generation." + ) + + if ( + generation_config.num_beam_groups is not None + and generation_config.num_beam_groups > 1 + ): + raise ValueError( + "`num_beam_groups` not supported yet for constrained generation." + ) + + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {generation_config.force_words_ids}." + ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + not isinstance(token_ids, list) for token_ids in word_ids + ): + typeerror() + if any( + any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids + ) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in word_ids + ): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. 
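The validation loop above accepts two nestings for `force_words_ids`, each mapped to a different constraint class. Expressed directly with the classes it builds (token ids invented):

```python
# Sketch of the two force_words_ids shapes accepted above.
from transformers import DisjunctiveConstraint, PhrasalConstraint

# List[List[int]]: each inner list is a phrase that must appear verbatim.
must_appear = PhrasalConstraint([318, 257])

# List[List[List[int]]]: the constraint is satisfied by any one of the variants.
one_of = DisjunctiveConstraint([[318, 257], [318, 262]])
```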
prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + @torch.no_grad() + def sample_stream( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + For an overview of generation strategies and code examples, check the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. 
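The docstring above describes `sample_stream` as plain multinomial sampling; stripped of the bookkeeping, each iteration of its loop reduces to the few lines below. The logits are random placeholders and the warpers are chosen arbitrarily here:

```python
# One sampling step in isolation: warp the last-position logits, softmax,
# then draw one token per batch row.
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

logits_warper = LogitsProcessorList([TopKLogitsWarper(50), TemperatureLogitsWarper(0.7)])

input_ids = torch.tensor([[1, 2, 3]])                 # ids generated so far
next_token_logits = torch.randn(1, 1000)              # [batch, vocab] placeholder

scores = logits_warper(input_ids, next_token_logits)
probs = torch.nn.functional.softmax(scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)   # shape [batch]
```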
+ output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> model.generation_config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... 
) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long() + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + +def init_stream_support(): + """Overload PreTrainedModel for streaming.""" + PreTrainedModel.generate_stream = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + + +if __name__ == "__main__": + from transformers import PreTrainedModel + from transformers import AutoTokenizer, AutoModelForCausalLM + + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + model = AutoModelForCausalLM.from_pretrained( + "bigscience/bloom-560m", torch_dtype=torch.float16 + ) + + tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + model = model.to("cuda:0") + model = model.eval() + prompt_text = "hello? 
\n" + input_ids = tokenizer( + prompt_text, return_tensors="pt", add_special_tokens=False + ).input_ids + input_ids = input_ids.to("cuda:0") + + with torch.no_grad(): + result = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + ) + print(tokenizer.decode(result, skip_special_tokens=True)) + generator = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + do_stream=True, + ) + stream_result = "" + for x in generator: + chunk = tokenizer.decode(x, skip_special_tokens=True) + stream_result += chunk + print(stream_result)