added tqdm bar for AR

parent 99be487482
commit 849de13f27
@@ -69,6 +69,7 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
 			bar.update( len(chunk) / scale )
 			f.write(chunk)
+	bar.close()
 
 # semi-necessary as a way to provide a mechanism for other portions of the program to access models
 @cache
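The hunk above advances the bar by each chunk's size divided by a byte-to-unit scale, and the new bar.close() ends the bar once the file is written. A minimal self-contained sketch of that pattern follows; the requests usage, URL handling, and variable names are illustrative assumptions, not the repository's actual download code.

    import requests
    from tqdm import tqdm

    def download_file(url, save_path, chunk_size=1024, unit="MiB"):
        scale = 1024 * 1024  # bytes per MiB, so the bar counts display units rather than raw bytes
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total = int(r.headers.get("content-length", 0)) / scale
            bar = tqdm(total=total, unit=unit, desc="Downloading")
            with open(save_path, "wb") as f:
                for chunk in r.iter_content(chunk_size):
                    bar.update(len(chunk) / scale)  # same per-chunk update as the hunk above
                    f.write(chunk)
            bar.close()  # mirrors the added bar.close() so the bar does not linger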
@@ -626,7 +626,7 @@ class GaussianDiffusion:
         img = torch.randn(*shape, device=device)
         indices = list(range(self.num_timesteps))[::-1]
 
-        for i in tqdm(indices, disable=not progress):
+        for i in tqdm(indices, disable=not progress, desc="Diffusion"):
             t = torch.tensor([i] * shape[0], device=device)
             with torch.no_grad():
                 out = self.p_sample(
@@ -791,10 +791,7 @@ class GaussianDiffusion:
         img = torch.randn(*shape, device=device)
         indices = list(range(self.num_timesteps))[::-1]
 
-        if progress:
-            indices = tqdm(indices, disable=not progress)
-
-        for i in indices:
+        for i in tqdm(indices, disable=not progress, desc="Diffusion"):
             t = torch.tensor([i] * shape[0], device=device)
             with torch.no_grad():
                 out = self.ddim_sample(
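Both diffusion hunks lean on tqdm's disable flag: wrapping the timestep list unconditionally with disable=not progress behaves like the old `if progress: indices = tqdm(...)` branch but in one line, and desc labels the bar. A small standalone sketch of the loop shape, assuming a stand-in step function rather than the repository's p_sample/ddim_sample:

    import torch
    from tqdm import tqdm

    def reverse_loop(shape, num_timesteps, step, progress=True, device="cpu"):
        img = torch.randn(*shape, device=device)
        indices = list(range(num_timesteps))[::-1]  # iterate timesteps from T-1 down to 0
        for i in tqdm(indices, disable=not progress, desc="Diffusion"):
            t = torch.tensor([i] * shape[0], device=device)
            with torch.no_grad():
                img = step(img, t)  # stand-in for p_sample / ddim_sample
        return img

    # usage: an identity "denoiser" just to show the loop and bar run
    out = reverse_loop((2, 3), num_timesteps=10, step=lambda x, t: x)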
@@ -12,6 +12,7 @@ from .arch_utils import AttentionBlock
 
 from transformers import LogitsWarper
 from transformers import GPT2Config, GPT2Model
+from tqdm import tqdm
 
 AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
 
@@ -217,6 +218,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 
         lm_logits = self.lm_head(hidden_states)
 
+        if hasattr(self, "bar"):
+            self.bar.update( 1 )
+
         if not return_dict:
             return (lm_logits,) + transformer_outputs[1:]
 
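This is the core of the commit: Hugging Face generate() runs the model's forward once per decoding step, so a tqdm bar attached to the model object and ticked inside forward reports roughly per-token progress without modifying GenerationMixin. A standalone sketch of the same trick against stock GPT-2; the class and variable names are illustrative, not part of this repository:

    from tqdm import tqdm
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    class ProgressGPT2(GPT2LMHeadModel):
        def forward(self, *args, **kwargs):
            # each decoding step is one forward pass (the first also covers the prompt),
            # so ticking an attached bar here approximates per-token progress
            if hasattr(self, "bar"):
                self.bar.update(1)
            return super().forward(*args, **kwargs)

    tok = GPT2Tokenizer.from_pretrained("gpt2")
    model = ProgressGPT2.from_pretrained("gpt2")
    inputs = tok("Hello there", return_tensors="pt")
    max_length = 32
    model.bar = tqdm(unit="it", total=max_length, desc="AR")
    out = model.generate(**inputs, max_length=max_length, pad_token_id=tok.eos_token_id)
    model.bar.close()
    print(tok.decode(out[0]))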
@@ -589,7 +593,7 @@ class UnifiedVoice(nn.Module):
 
     def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
                          max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
         seq_length = self.max_mel_tokens + self.max_text_tokens + 2
         if not hasattr(self, 'inference_model'):
             # TODO: Decouple gpt_config from this inference model.
             self.post_init_gpt2_config(kv_cache = kv_cache)
@@ -616,9 +620,13 @@ class UnifiedVoice(nn.Module):
 
         logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
         max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
 
+        # yucky, why doesn't the base HF GenerationMixin have a tqdm exposed
+        self.inference_model.bar = tqdm( unit="it", total=max_length, desc="AR" )
+
         gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
                                             max_length=max_length, logits_processor=logits_processor,
                                             num_return_sequences=num_return_sequences, **hf_generate_kwargs)
+        self.inference_model.bar.close()
 
         return gen[:, trunc_index:]
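One caveat with the change above: the bar's total is max_length, so it stops short of 100% whenever generation hits stop_mel_token early, and the bar attribute stays on the model between calls. A hedged hardening, not part of the commit, is to close the bar in a finally block so an exception inside generate() cannot leave a dangling bar:

    from tqdm import tqdm

    def generate_with_bar(model, inputs, max_length, **kwargs):
        # attach the bar, then guarantee it is closed even if generate() raises
        model.bar = tqdm(unit="it", total=max_length, desc="AR")
        try:
            return model.generate(inputs, max_length=max_length, **kwargs)
        finally:
            model.bar.close()
            del model.bar  # avoid a stale bar on later calls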