added tqdm bar for AR

mrq 2024-06-19 15:00:14 -05:00
parent 99be487482
commit 849de13f27
3 changed files with 12 additions and 6 deletions

View File

@@ -69,6 +69,7 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
             bar.update( len(chunk) / scale )
             f.write(chunk)
+    bar.close()
 
 # semi-necessary as a way to provide a mechanism for other portions of the program to access models
 @cache
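
For context, the hunk above sits inside a chunked download loop. A minimal sketch of that pattern, assuming a requests-style streaming download; the URL, the scale computation, and the surrounding structure are illustrative, not verbatim from the repo:

import requests
from tqdm import tqdm

def download_model(save_path, chunkSize=1024, unit="MiB"):
    url = "https://example.com/model.pth"  # hypothetical URL, for illustration only
    scale = 1024 * 1024 if unit == "MiB" else 1024  # assumed bytes-per-`unit` factor
    response = requests.get(url, stream=True)
    total = int(response.headers.get("content-length", 0)) / scale
    bar = tqdm(total=total, unit=unit)
    with open(save_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=chunkSize):
            bar.update(len(chunk) / scale)  # advance by this chunk's share of `unit`
            f.write(chunk)
    bar.close()  # an unclosed bar can linger and garble later terminal output

The @cache decorator after the loop is presumably functools.cache; per the comment, it memoizes the accessor so other parts of the program reuse the same loaded model instead of reloading it.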

View File

@@ -626,7 +626,7 @@ class GaussianDiffusion:
         img = torch.randn(*shape, device=device)
         indices = list(range(self.num_timesteps))[::-1]
 
-        for i in tqdm(indices, disable=not progress):
+        for i in tqdm(indices, disable=not progress, desc="Diffusion"):
             t = torch.tensor([i] * shape[0], device=device)
             with torch.no_grad():
                 out = self.p_sample(
@@ -791,10 +791,7 @@ class GaussianDiffusion:
         img = torch.randn(*shape, device=device)
         indices = list(range(self.num_timesteps))[::-1]
 
-        if progress:
-            indices = tqdm(indices, disable=not progress)
-
-        for i in indices:
+        for i in tqdm(indices, disable=not progress, desc="Diffusion"):
             t = torch.tensor([i] * shape[0], device=device)
             with torch.no_grad():
                 out = self.ddim_sample(
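
Both diffusion hunks make the same change: rather than conditionally rewrapping indices, the loop always goes through tqdm, relying on disable=not progress to silence the bar and desc to label it. A standalone sketch of the resulting behavior; the timestep count and the loop body are placeholders:

from tqdm import tqdm

progress = True     # mirrors the `progress` flag passed into the samplers
num_timesteps = 50  # placeholder; the real value comes from the diffusion schedule

indices = list(range(num_timesteps))[::-1]  # denoise from t = T-1 down to 0
for i in tqdm(indices, disable=not progress, desc="Diffusion"):
    pass  # one p_sample() / ddim_sample() step per timestep runs here

The removed `if progress:` guard in the DDIM path was redundant, since `disable=not progress` already turns the bar off; wrapping unconditionally also keeps the two samplers' progress handling identical.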

View File

@@ -12,6 +12,7 @@ from .arch_utils import AttentionBlock
 from transformers import LogitsWarper
 from transformers import GPT2Config, GPT2Model
+from tqdm import tqdm
 
 AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
@@ -217,6 +218,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
 
+        if hasattr(self, "bar"):
+            self.bar.update( 1 )
+
         if not return_dict:
             return (lm_logits,) + transformer_outputs[1:]
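
The bar is ticked inside the model's forward() because Hugging Face's GenerationMixin.generate() calls forward() once per generated token and, as the comment in the final hunk puts it, exposes no progress hook of its own. A toy, self-contained illustration of the pattern; the class and the driving loop stand in for the real model and generate():

import torch
from tqdm import tqdm

class ToyDecoder(torch.nn.Module):
    def forward(self, x):
        # If a caller attached a bar, advance it: one forward call == one new token.
        if hasattr(self, "bar"):
            self.bar.update(1)
        return x

model = ToyDecoder()
model.bar = tqdm(unit="it", total=5, desc="AR")
for _ in range(5):          # stands in for generate()'s token-by-token loop
    model(torch.zeros(1))
model.bar.close()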
@@ -589,7 +593,7 @@ class UnifiedVoice(nn.Module):
     def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
                          max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
 
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
             self.post_init_gpt2_config(kv_cache = kv_cache)
@@ -616,9 +620,13 @@ class UnifiedVoice(nn.Module):
         logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
         max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
+
+        # yucky, why doesn't the base HF GenerationMixin have a tqdm exposed
+        self.inference_model.bar = tqdm( unit="it", total=max_length, desc="AR" )
         gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
                                             max_length=max_length, logits_processor=logits_processor,
                                             num_return_sequences=num_return_sequences, **hf_generate_kwargs)
+        self.inference_model.bar.close()
         return gen[:, trunc_index:]
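
One possible hardening, not something this commit does: if generate() raises, the attached bar is never closed. A sketch wrapping the same attach/close wiring in try/finally; the helper name and signature are invented for illustration:

from tqdm import tqdm

def generate_with_bar(model, inputs, max_length, **generate_kwargs):
    model.bar = tqdm(unit="it", total=max_length, desc="AR")
    try:
        return model.generate(inputs, max_length=max_length, **generate_kwargs)
    finally:
        model.bar.close()
        del model.bar  # remove the attribute so later forward() calls skip the update

Note also that the bar's total counts sequence positions up to max_length while forward() only runs for newly generated tokens, so the bar can stop short of 100% unless the total is offset by the prompt length.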