From 849de13f27db1f1d945ec12093de536a2003379a Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 19 Jun 2024 15:00:14 -0500
Subject: [PATCH] added tqdm bar for AR

---
 tortoise_tts/models/__init__.py      |  1 +
 tortoise_tts/models/diffusion.py     |  7 ++-----
 tortoise_tts/models/unified_voice.py | 10 +++++++++-
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/tortoise_tts/models/__init__.py b/tortoise_tts/models/__init__.py
index a061906..d5da18d 100755
--- a/tortoise_tts/models/__init__.py
+++ b/tortoise_tts/models/__init__.py
@@ -69,6 +69,7 @@ def download_model( save_path, chunkSize = 1024, unit = "MiB" ):
 			bar.update( len(chunk) / scale )
 			f.write(chunk)
+	bar.close()
 
 
 # semi-necessary as a way to provide a mechanism for other portions of the program to access models
 @cache
diff --git a/tortoise_tts/models/diffusion.py b/tortoise_tts/models/diffusion.py
index 2b1926a..1d65270 100644
--- a/tortoise_tts/models/diffusion.py
+++ b/tortoise_tts/models/diffusion.py
@@ -626,7 +626,7 @@ class GaussianDiffusion:
 			img = torch.randn(*shape, device=device)
 		indices = list(range(self.num_timesteps))[::-1]
 
-		for i in tqdm(indices, disable=not progress):
+		for i in tqdm(indices, disable=not progress, desc="Diffusion"):
 			t = torch.tensor([i] * shape[0], device=device)
 			with torch.no_grad():
 				out = self.p_sample(
@@ -791,10 +791,7 @@ class GaussianDiffusion:
 			img = torch.randn(*shape, device=device)
 		indices = list(range(self.num_timesteps))[::-1]
 
-		if progress:
-			indices = tqdm(indices, disable=not progress)
-
-		for i in indices:
+		for i in tqdm(indices, disable=not progress, desc="Diffusion"):
 			t = torch.tensor([i] * shape[0], device=device)
 			with torch.no_grad():
 				out = self.ddim_sample(
diff --git a/tortoise_tts/models/unified_voice.py b/tortoise_tts/models/unified_voice.py
index 72c55fb..ad37413 100644
--- a/tortoise_tts/models/unified_voice.py
+++ b/tortoise_tts/models/unified_voice.py
@@ -12,6 +12,7 @@ from .arch_utils import AttentionBlock
 
 from transformers import LogitsWarper
 from transformers import GPT2Config, GPT2Model
+from tqdm import tqdm
 
 AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
 
@@ -217,6 +218,9 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 
 		lm_logits = self.lm_head(hidden_states)
 
+		if hasattr(self, "bar"):
+			self.bar.update( 1 )
+
 		if not return_dict:
 			return (lm_logits,) + transformer_outputs[1:]
 
@@ -589,7 +593,7 @@ class UnifiedVoice(nn.Module):
 
 	def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
 						 max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
-		seq_length = self.max_mel_tokens + self.max_text_tokens + 2
+
 		if not hasattr(self, 'inference_model'):
 			# TODO: Decouple gpt_config from this inference model.
 			self.post_init_gpt2_config(kv_cache = kv_cache)
@@ -616,9 +620,13 @@
 
 		logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
 		max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
+
+		# yucky, why doesn't the base HF GenerationMixin have a tqdm exposed
+		self.inference_model.bar = tqdm( unit="it", total=max_length, desc="AR" )
 		gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
 											max_length=max_length, logits_processor=logits_processor,
 											num_return_sequences=num_return_sequences, **hf_generate_kwargs)
+		self.inference_model.bar.close()
 
 		return gen[:, trunc_index:]
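
Note (not part of the commit): against a stock HF model, the same per-token progress can be had without patching forward(), by ticking the bar from a StoppingCriteria, which generate() invokes once per decoding step. A minimal sketch; the gpt2 checkpoint, prompt, and max_new_tokens value are placeholders, not anything from this patch:

import torch
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

class ProgressCriteria(StoppingCriteria):
    """Never stops generation; just ticks a tqdm bar once per generation step."""
    def __init__(self, bar: tqdm):
        self.bar = bar

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        self.bar.update(1)
        return False  # stopping is left to eos / max_new_tokens

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello there", return_tensors="pt")
max_new_tokens = 64

bar = tqdm(unit="it", total=max_new_tokens, desc="AR")
gen = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    pad_token_id=tokenizer.eos_token_id,  # gpt2 has no pad token
    stopping_criteria=StoppingCriteriaList([ProgressCriteria(bar)]),
)
bar.close()
print(tokenizer.decode(gen[0], skip_special_tokens=True))

Totaling the bar with max_new_tokens keeps it exact; totaling with max_length, as the patch does, leaves the bar short by roughly the prompt length, since forward() only runs once per newly generated token.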