integrate new autoregressive model and fix new diffusion bug

This commit is contained in:
James Betker 2022-04-04 16:51:35 -06:00
parent 9043dde3f9
commit 33e4bc7907
5 changed files with 549 additions and 10 deletions

7
api.py
View File

@ -117,13 +117,14 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
cond_mels.append(cond_mel)
cond_mels = torch.stack(cond_mels, dim=1)
output_shape = (mel_codes.shape[0], 100, mel_codes.shape[-1]*4)
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, False)
output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (mel_codes.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
return denormalize_tacotron_mel(mel)[:,:,:mel_codes.shape[-1]*4]
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
class TextToSpeech:

245
api_new_autoregressive.py Normal file
View File

@ -0,0 +1,245 @@
import argparse
import os
import random
from urllib import request
import torch
import torch.nn.functional as F
import torchaudio
import progressbar
import ocotillo
from models.diffusion_decoder import DiffusionTts
from models.autoregressive import UnifiedVoice
from tqdm import tqdm
from models.arch_util import TorchMelSpectrogram
from models.new_autoregressive import AutoregressiveCodegen
from models.text_voice_clip import VoiceCLIP
from models.vocoder import UnivNetGenerator
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from utils.tokenizer import VoiceBpeTokenizer, lev_distance
pbar = None
def download_models():
MODELS = {
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
}
os.makedirs('.models', exist_ok=True)
def show_progress(block_num, block_size, total_size):
global pbar
if pbar is None:
pbar = progressbar.ProgressBar(maxval=total_size)
pbar.start()
downloaded = block_num * block_size
if downloaded < total_size:
pbar.update(downloaded)
else:
pbar.finish()
pbar = None
for model_name, url in MODELS.items():
if os.path.exists(f'.models/{model_name}'):
continue
print(f'Downloading {model_name} from {url}...')
request.urlretrieve(url, f'.models/{model_name}', show_progress)
print('Done.')
def pad_or_truncate(t, length):
if t.shape[-1] == length:
return t
elif t.shape[-1] < length:
return F.pad(t, (0, length-t.shape[-1]))
else:
return t[..., :length]
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
"""
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
"""
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
def load_conditioning(clip, cond_length=132300):
gap = clip.shape[-1] - cond_length
if gap < 0:
clip = F.pad(clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
clip = clip[:, rand_start:rand_start + cond_length]
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
return mel_clip.unsqueeze(0).cuda()
def fix_autoregressive_output(codes, stop_token):
"""
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
trained on and what the autoregressive code generator creates (which has no padding or end).
This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
and copying out the last few codes.
Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
"""
# Strip off the autoregressive stop token and add padding.
stop_token_indices = (codes == stop_token).nonzero()
if len(stop_token_indices) == 0:
print("No stop tokens found, enjoy that output of yours!")
return codes
else:
codes[stop_token_indices] = 83
stm = stop_token_indices.min().item()
codes[stm:] = 83
if stm - 3 < codes.shape[0]:
codes[-3] = 45
codes[-2] = 45
codes[-1] = 248
return codes
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1):
"""
Uses the specified diffusion model to convert discrete codes into a spectrogram.
"""
with torch.no_grad():
cond_mels = []
for sample in conditioning_samples:
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False)
cond_mels.append(cond_mel)
cond_mels = torch.stack(cond_mels, dim=1)
output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (mel_codes.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
class TextToSpeech:
def __init__(self, autoregressive_batch_size=32):
self.autoregressive_batch_size = autoregressive_batch_size
self.tokenizer = VoiceBpeTokenizer()
download_models()
self.autoregressive = AutoregressiveCodegen(512, 12).cpu().eval()
self.autoregressive.load_state_dict(torch.load('D:\\dlas\\experiments\\train_autoregressive_codegen\\models\\23000_codegen_ema.pth'))
self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
text_seq_len=350, text_heads=8,
num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
use_xformers=True).cpu().eval()
self.clip.load_state_dict(torch.load('.models/clip.pth'))
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
layer_drop=0, unconditioned_percentage=0).cpu().eval()
self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
self.vocoder = UnivNetGenerator().cpu()
self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
self.vocoder.eval(inference=True)
def tts(self, text, voice_samples, k=1,
# autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.5, length_penalty=2, repetition_penalty=2.0, top_p=.5,
typical_sampling=False, typical_mass=.9,
# diffusion generation parameters follow
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=.7,):
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
text = F.pad(text, (0, 1)) # This may not be necessary.
conds = []
if not isinstance(voice_samples, list):
voice_samples = [voice_samples]
for vs in voice_samples:
conds.append(load_conditioning(vs))
conds = torch.stack(conds, dim=1)
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
with torch.no_grad():
samples = []
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
stop_mel_token = self.autoregressive.STOP_TOKEN
self.autoregressive = self.autoregressive.cuda()
for _ in tqdm(range(num_batches)):
codes = self.autoregressive.generate(conds, text,
do_sample=True,
top_p=top_p,
temperature=temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
typical_sampling=typical_sampling,
typical_mass=typical_mass)
padding_needed = 250 - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
#self.autoregressive = self.autoregressive.cpu()
clip_results = []
self.clip = self.clip.cuda()
for batch in samples:
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
bad_toks = batch >= 8192
batch = batch * bad_toks.logical_not()
clip_results.append(self.clip(text.repeat(batch.shape[0], 1), batch, return_loss=False))
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
self.clip = self.clip.cpu()
del samples
print("Performing vocoding..")
wav_candidates = []
self.diffusion = self.diffusion.cuda()
self.vocoder = self.vocoder.cuda()
for b in range(best_results.shape[0]):
code = best_results[b].unsqueeze(0)
mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, voice_samples, temperature=diffusion_temperature)
wav = self.vocoder.inference(mel)
wav_candidates.append(wav.cpu())
self.diffusion = self.diffusion.cpu()
self.vocoder = self.vocoder.cpu()
if len(wav_candidates) > 1:
return wav_candidates
return wav_candidates[0]
def refine_for_intellibility(self, wav_candidates, corresponding_codes, output_path):
"""
Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
TODO: finish this function
:param wav_candidates:
:return:
"""
transcriber = ocotillo.Transcriber(on_cuda=True)
transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
best = 99999999
for i, transcription in enumerate(transcriptions):
dist = lev_distance(transcription, args.text.lower())
if dist < best:
best = dist
best_codes = corresponding_codes[i].unsqueeze(0)
best_wav = wav_candidates[i]
del transcriber
torchaudio.save(os.path.join(output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
# Perform diffusion again with the high-quality diffuser.
mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
wav = vocoder.inference(mel)
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)

View File

@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F
import torchaudio
from api import TextToSpeech, load_conditioning
from api_new_autoregressive import TextToSpeech, load_conditioning
from utils.audio import load_audio
from utils.tokenizer import VoiceBpeTokenizer
@ -28,7 +28,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=32)
parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')

View File

@ -212,7 +212,7 @@ class DiffusionTts(nn.Module):
}
return groups
def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
# Shuffle aligned_latent to BxCxS format
if is_latent(aligned_conditioning):
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
@ -227,7 +227,7 @@ class DiffusionTts(nn.Module):
cond_emb = conds.mean(dim=-1)
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
if is_latent(aligned_conditioning):
code_emb = self.latent_converter(aligned_conditioning)
code_emb = self.autoregressive_latent_converter(aligned_conditioning)
else:
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
code_emb = self.code_converter(code_emb)
@ -240,7 +240,7 @@ class DiffusionTts(nn.Module):
device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
code_emb)
expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
if not return_code_pred:
return expanded_code_emb
@ -250,7 +250,6 @@ class DiffusionTts(nn.Module):
mel_pred = mel_pred * unconditioned_batches.logical_not()
return expanded_code_emb, mel_pred
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
"""
Apply the model to an input batch.
@ -275,11 +274,12 @@ class DiffusionTts(nn.Module):
if precomputed_aligned_embeddings is not None:
code_emb = precomputed_aligned_embeddings
else:
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
if is_latent(aligned_conditioning):
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
else:
unused_params.extend(list(self.latent_converter.parameters()))
unused_params.append(self.unconditioned_embedding)
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

View File

@ -0,0 +1,293 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2PreTrainedModel, GPT2Config
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from x_transformers import TransformerWrapper, Encoder, Decoder
from models.arch_util import AttentionBlock
class InferenceModel(GPT2PreTrainedModel):
"""
Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with
this transformer.
"""
def __init__(self, model):
super().__init__(GPT2Config())
self.transformer = model
self.context = None
def parallelize(self, device_map=None):
# Not implemented.
pass
def deparallelize(self):
# Not implemented.
pass
def get_output_embeddings(self):
assert False, "Unsupported operation."
def set_output_embeddings(self, new_embeddings):
assert False, "Unsupported operation."
def store_context(self, context):
self.context = context
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.context is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
logits = self.transformer.decoder.transformer.to_logits(hidden_states)
if not return_dict:
return (logits, )
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=hidden_states,
attentions=None,
cross_attentions=None,
)
@staticmethod
def _reorder_cache(past, beam_idx):
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan//8, chan)
)
def forward(self, x):
return F.relu(self.net(x) + x)
class ConditioningEncoder(nn.Module):
def __init__(self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False):
super().__init__()
attn = []
self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2),
nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2),
ResBlock(embedding_dim//2),
nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2))
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
def forward(self, x):
h = self.init(x)
h = self.attn(h)
return h.mean(dim=2)
class CheckpointedLayer(nn.Module):
"""
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
checkpoint for all other args.
"""
def __init__(self, wrap):
super().__init__()
self.wrap = wrap
def forward(self, x, *args, **kwargs):
for k, v in kwargs.items():
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
partial = functools.partial(self.wrap, **kwargs)
return torch.utils.checkpoint.checkpoint(partial, x, *args)
class CheckpointedXTransformerWrapper(nn.Module):
"""
Wraps a TransformerWrapper and applies CheckpointedLayer to each layer.
"""
def __init__(self, checkpoint=True, **xtransformer_kwargs):
super().__init__()
self.transformer = TransformerWrapper(**xtransformer_kwargs)
if not checkpoint:
return
for i in range(len(self.transformer.attn_layers.layers)):
n, b, r = self.transformer.attn_layers.layers[i]
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
def forward(self, x, **kwargs):
return self.transformer(x, **kwargs)
class AutoregressiveCodegen(nn.Module):
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000,
max_mel_tokens=4000, dropout=.1):
super().__init__()
self.START_TOKEN=8192
self.STOP_TOKEN=8193
self.max_mel_tokens = max_mel_tokens
self.minicoder = ConditioningEncoder(80, model_dim, do_checkpointing=False)
self.encoder = CheckpointedXTransformerWrapper(
num_tokens=num_text_tokens,
max_seq_len=max_text_tokens,
attn_layers = Encoder(
depth=depth//2,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=1,
rotary_pos_emb=True,
rel_pos_bias=True,
))
self.decoder = CheckpointedXTransformerWrapper(
num_tokens=num_mel_tokens,
max_seq_len=max_mel_tokens,
attn_layers=Decoder(
depth=depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
ff_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
ff_mult=1,
rotary_pos_emb=True,
rel_pos_bias=True,
cross_attend=True,
))
def get_grad_norm_parameter_groups(self):
return {
'encoder': list(self.encoder.parameters()),
'decoder': list(self.decoder.parameters()),
'minicoder': list(self.minicoder.parameters()),
}
def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
# Format mel_codes with a stop token on the end.
mel_lengths = wav_lengths // 1024 + 1
for b in range(mel_codes.shape[0]):
mel_codes[b, mel_lengths[b]:] = self.STOP_TOKEN
mel_codes = F.pad(mel_codes, (0, 1), value=self.STOP_TOKEN)
# Build the context
if len(conditioning_signal.shape) != 4:
conditioning_signal = conditioning_signal.unsqueeze(1)
cond_embs = []
for i in range(conditioning_signal.shape[1]):
cond_embs.append(self.minicoder(conditioning_signal[:, i]))
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
enc_text = self.encoder(text_codes, return_embeddings=True)
context = torch.cat([cond_emb, enc_text], dim=1)
# Execute the decoder
dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
dec = self.decoder(dec_inputs, context=context)
if not return_loss:
return dec
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
return loss_mel
def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs):
if not hasattr(self, 'inference_model'):
self.inference_model = InferenceModel(self)
if len(conditioning_signal.shape) != 4:
conditioning_signal = conditioning_signal.unsqueeze(1)
cond_embs = []
for i in range(conditioning_signal.shape[1]):
cond_embs.append(self.minicoder(conditioning_signal[:, i]))
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
enc_text = self.encoder(text_codes, return_embeddings=True)
context = torch.cat([cond_emb, enc_text], dim=1)
self.inference_model.store_context(context)
gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
max_length=250, output_attentions=False, return_dict_in_generate=True,
**hf_generate_kwargs)
return gen.sequences
if __name__ == '__main__':
codegen = AutoregressiveCodegen(1024, 20)
codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
codegen(torch.randint(0,256, (2,200)),
torch.randn(2,80,120),
torch.randint(0,8192, (2,350)),
torch.tensor([192,350]))