added tqdm bar for AR
parent 99be487482
commit 849de13f27
@@ -69,6 +69,7 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
 			bar.update( len(chunk) / scale )
 			f.write(chunk)
 	bar.close()
 
 # semi-necessary as a way to provide a mechanism for other portions of the program to access models
 @cache
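For context, the download loop above pairs a byte-scaled tqdm bar with chunked streaming: each chunk advances the bar by len(chunk) / scale so the bar counts in MiB rather than raw bytes. A minimal self-contained sketch of the same pattern (the requests-based fetch and the helper name are assumptions, not part of this commit):

import requests
from tqdm import tqdm

def download_file(url, save_path, chunkSize=1024, unit="MiB"):
	# hypothetical stand-in for the surrounding download_model();
	# scale converts raw byte counts into the display unit, mirroring
	# the bar.update( len(chunk) / scale ) call in the hunk above
	scale = 1024 * 1024 if unit == "MiB" else 1024
	r = requests.get(url, stream=True)
	total = int(r.headers.get("content-length", 0)) / scale
	bar = tqdm(total=total, unit=unit)
	with open(save_path, "wb") as f:
		for chunk in r.iter_content(chunk_size=chunkSize):
			bar.update(len(chunk) / scale)
			f.write(chunk)
	bar.close()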
@@ -626,7 +626,7 @@ class GaussianDiffusion:
 		img = torch.randn(*shape, device=device)
 		indices = list(range(self.num_timesteps))[::-1]
 
-		for i in tqdm(indices, disable=not progress):
+		for i in tqdm(indices, disable=not progress, desc="Diffusion"):
 			t = torch.tensor([i] * shape[0], device=device)
 			with torch.no_grad():
 				out = self.p_sample(
@@ -791,10 +791,7 @@ class GaussianDiffusion:
 		img = torch.randn(*shape, device=device)
 		indices = list(range(self.num_timesteps))[::-1]
 
-		if progress:
-			indices = tqdm(indices, disable=not progress)
-
-		for i in indices:
+		for i in tqdm(indices, disable=not progress, desc="Diffusion"):
 			t = torch.tensor([i] * shape[0], device=device)
 			with torch.no_grad():
 				out = self.ddim_sample(
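Both sampling loops now get a "Diffusion" description, and the DDIM loop drops its old if progress: guard entirely. That guard was redundant because tqdm(..., disable=True) already degrades to a plain pass-through iterator with no bar rendered, so a single line covers both the progress and no-progress cases. A quick illustration:

from tqdm import tqdm

# disable=True passes items through untouched and renders nothing,
# so no separate "if progress:" branch is needed
assert list(tqdm(range(3), disable=True)) == [0, 1, 2]

# disable=False shows a bar labeled with the given description
for i in tqdm(range(3), disable=False, desc="Diffusion"):
	pass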
@@ -12,6 +12,7 @@ from .arch_utils import AttentionBlock
 from transformers import LogitsWarper
 from transformers import GPT2Config, GPT2Model
+from tqdm import tqdm
 
 AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
 
@@ -217,6 +218,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 		lm_logits = self.lm_head(hidden_states)
 
+		if hasattr(self, "bar"):
+			self.bar.update( 1 )
+
 		if not return_dict:
 			return (lm_logits,) + transformer_outputs[1:]
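Because GPT2InferenceModel.forward runs once per decoded token inside HF generate(), bumping an optional bar attribute there turns the model itself into a per-token progress hook, and the hasattr guard keeps forward working when no bar was attached. A minimal sketch of the same pattern, using a hypothetical DummyModel in place of the real inference model:

from tqdm import tqdm

class DummyModel:
	def forward(self, token):
		# mirror the hunk above: only tick when a bar was
		# attached from the outside
		if hasattr(self, "bar"):
			self.bar.update(1)
		return token

model = DummyModel()
model.bar = tqdm(unit="it", total=10, desc="AR")
for step in range(10):  # stands in for generate()'s decoding loop
	model.forward(step)
model.bar.close()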
@@ -589,7 +593,7 @@ class UnifiedVoice(nn.Module):
 	def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
 						max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
 		seq_length = self.max_mel_tokens + self.max_text_tokens + 2
 
 		if not hasattr(self, 'inference_model'):
 			# TODO: Decouple gpt_config from this inference model.
 			self.post_init_gpt2_config(kv_cache = kv_cache)
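For reference, a hedged sketch of how this entry point is driven; the variable names and the max_generate_length value are illustrative assumptions, and only the keyword arguments come from the signature above:

# codes: generated mel-token IDs, already trimmed by inference_speech
codes = model.inference_speech(
	speech_conditioning_latent,  # conditioning latent computed elsewhere
	text_inputs,                 # batch of tokenized text IDs
	num_return_sequences=1,
	max_generate_length=500,
)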
@@ -616,9 +620,13 @@ class UnifiedVoice(nn.Module):
 		logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
 		max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
+
+		# yucky, why doesn't the base HF GenerationMixin have a tqdm exposed
+		self.inference_model.bar = tqdm( unit="it", total=max_length, desc="AR" )
 		gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
 											max_length=max_length, logits_processor=logits_processor,
 											num_return_sequences=num_return_sequences, **hf_generate_kwargs)
+		self.inference_model.bar.close()
 		return gen[:, trunc_index:]
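One caveat: the bar here is closed only on the success path, so an exception inside generate() would leave a stale bar attached to the inference model and a half-drawn bar on screen. A try/finally wrapper over the same lifecycle (a sketch of an alternative, not what the commit does) makes the cleanup unconditional:

from tqdm import tqdm

def generate_with_bar(model, total, desc="AR", **generate_kwargs):
	# attach the bar that model.forward ticks once per token, and
	# guarantee cleanup even if generate() raises mid-decode
	model.bar = tqdm(unit="it", total=total, desc=desc)
	try:
		return model.generate(**generate_kwargs)
	finally:
		model.bar.close()
		del model.bar  # so forward's hasattr(self, "bar") check goes quiet again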