Compare commits


21 Commits
master ... main

Author SHA1 Message Date
mrq
95f679f4ba possible fix for when candidates >= samples 2023-10-10 15:30:08 +00:00
mrq
bf3b6c87aa added compat for coqui's XTTS 2023-09-16 03:38:21 +00:00
mrq
d7e6914fb8 Merge pull request 'main' (#47) from ken11o2/tortoise-tts:main into main
Reviewed-on: mrq/tortoise-tts#47
2023-09-04 20:01:14 +00:00
ken11o2
b7c7fd1c5f add arg use_deepspeed 2023-09-04 19:14:53 +00:00
ken11o2
2478dc255e update TextToSpeech 2023-09-04 19:13:45 +00:00
ken11o2
18adfaf785 add use_deepspeed to constructor and update method post_init_gpt2_config 2023-09-04 19:12:13 +00:00
ken11o2
ac97c17bf7 add use_deepspeed 2023-09-04 19:10:27 +00:00
mrq
b10c58436d pesky dot 2023-08-20 22:41:55 -05:00
mrq
cbd3c95c42 possible speedup with one simple trick (it worked for valle inferencing), also backported the voice list loading from aivc 2023-08-20 22:32:01 -05:00
mrq
9afa71542b little sloppy hack to try and not load the same model when it was already loaded 2023-08-11 04:02:36 +00:00
mrq
e2cd07d560 Fix for redaction at end of text (#45) 2023-06-10 21:16:21 +00:00
mrq
5ff00bf3bf added flags to revert to the default method of latent generation (separately for the AR and Diffusion latents, as some voices don't play nicely with the chunk-for-all method) 2023-05-21 01:46:55 +00:00
mrq
c90ee7c529 removed kludgy wrappers for passing progress when I was a pythonlet and didn't know gradio can hook into tqdm outputs anyways 2023-05-04 23:39:39 +00:00
mrq
086aad5b49 quick hotfix to remove offending codesmell (will actually clean it when I finish eating) 2023-05-04 22:59:57 +00:00
mrq
04b7049811 freeze numpy to 1.23.5 because the latest version will moan about deprecating complex 2023-05-04 01:54:41 +00:00
mrq
b6a213bbbd removed some CPU fallback wrappers because directml seems to work now without them 2023-04-29 00:46:36 +00:00
mrq
2f7d9ab932 disable BNB for inferencing by default because I'm pretty sure it makes zero difference (can be force-enabled with env vars if you're relying on this for some reason) 2023-04-29 00:38:18 +00:00
mrq
f025470d60 Merge pull request 'Update tortoise/utils/devices.py vram issue' (#44) from aJoe/tortoise-tts:main into main
Reviewed-on: mrq/tortoise-tts#44
2023-04-12 19:58:02 +00:00
aJoe
eea4c68edc Update tortoise/utils/devices.py vram issue
Added line 85 to set the name variable, as it was 'None', causing VRAM to be reported incorrectly
2023-04-12 05:33:30 +00:00
mrq
815ae5d707 Merge pull request 'feat: support .flac voice files' (#43) from NtTestAlert/tortoise-tts:support_flac_voice into main
Reviewed-on: mrq/tortoise-tts#43
2023-04-01 16:37:56 +00:00
2cd7b72688 feat: support .flac voice files 2023-04-01 15:08:31 +02:00
10 changed files with 362 additions and 232 deletions

View File

@ -11,5 +11,5 @@ librosa==0.8.1
torchaudio torchaudio
threadpoolctl threadpoolctl
appdirs appdirs
numpy numpy<=1.23.5
numba numba
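
The pin reflects the "freeze numpy to 1.23.5" commit above: NumPy 1.20 deprecated the builtin aliases such as np.complex and NumPy 1.24 removes them, so older code paths that still touch them break on newer releases. A minimal sketch of the failure mode, assuming only that the deprecated alias is what newer versions complain about:

import numpy as np

try:
    np.complex  # still present (with a DeprecationWarning) on numpy <= 1.23.x
except AttributeError:
    print(f"numpy {np.__version__} removed np.complex; pin to <=1.23.5 as in requirements.txt")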

View File

@ -83,16 +83,6 @@ def check_for_kill_signal():
STOP_SIGNAL = False STOP_SIGNAL = False
raise Exception("Kill signal detected") raise Exception("Kill signal detected")
def tqdm_override(arr, verbose=False, progress=None, desc=None):
check_for_kill_signal()
if verbose and desc is not None:
print(desc)
if progress is None:
return tqdm(arr, disable=not verbose)
return progress.tqdm(arr, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)
def download_models(specific_models=None): def download_models(specific_models=None):
""" """
Call to download all the models that Tortoise uses. Call to download all the models that Tortoise uses.
@ -160,7 +150,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps), model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
conditioning_free=cond_free, conditioning_free_k=cond_free_k) conditioning_free=cond_free, conditioning_free_k=cond_free_k)
@torch.inference_mode()
def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=22050): def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=22050):
""" """
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
@ -204,8 +194,8 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
return codes return codes
@torch.inference_mode()
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True, progress=None, desc=None, sampler="P", input_sample_rate=22050, output_sample_rate=24000): def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True, desc=None, sampler="P", input_sample_rate=22050, output_sample_rate=24000):
""" """
Uses the specified diffusion model to convert discrete codes into a spectrogram. Uses the specified diffusion model to convert discrete codes into a spectrogram.
""" """
@ -218,8 +208,7 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_la
diffuser.sampler = sampler.lower() diffuser.sampler = sampler.lower()
mel = diffuser.sample_loop(diffusion_model, output_shape, noise=noise, mel = diffuser.sample_loop(diffusion_model, output_shape, noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, desc=desc)
verbose=verbose, progress=progress, desc=desc)
mel = denormalize_tacotron_mel(mel)[:,:,:output_seq_len] mel = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
if get_device_name() == "dml": if get_device_name() == "dml":
@ -270,7 +259,8 @@ class TextToSpeech:
unsqueeze_sample_batches=False, unsqueeze_sample_batches=False,
input_sample_rate=22050, output_sample_rate=24000, input_sample_rate=22050, output_sample_rate=24000,
autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None, autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None,
): # ):
use_deepspeed=False): # Add use_deepspeed parameter
""" """
Constructor Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -291,7 +281,8 @@ class TextToSpeech:
self.output_sample_rate = output_sample_rate self.output_sample_rate = output_sample_rate
self.minor_optimizations = minor_optimizations self.minor_optimizations = minor_optimizations
self.unsqueeze_sample_batches = unsqueeze_sample_batches self.unsqueeze_sample_batches = unsqueeze_sample_batches
self.use_deepspeed = use_deepspeed # Store use_deepspeed as an instance variable
print(f'use_deepspeed api_debug {use_deepspeed}')
# for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations # for clarity, it's simpler to split these up and just predicate them on requesting VRAM-consuming optimizations
self.preloaded_tensors = minor_optimizations self.preloaded_tensors = minor_optimizations
self.use_kv_cache = minor_optimizations self.use_kv_cache = minor_optimizations
@ -347,25 +338,58 @@ class TextToSpeech:
self.loading = False self.loading = False
def load_autoregressive_model(self, autoregressive_model_path): def load_autoregressive_model(self, autoregressive_model_path, is_xtts=False):
if hasattr(self,"autoregressive_model_path") and self.autoregressive_model_path == autoregressive_model_path: if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path):
return return
self.loading = True
self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir) self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
self.autoregressive_model_hash = hash_file(self.autoregressive_model_path) new_hash = hash_file(self.autoregressive_model_path)
if hasattr(self,"autoregressive_model_hash") and self.autoregressive_model_hash == new_hash:
return
self.autoregressive_model_hash = new_hash
self.loading = True
print(f"Loading autoregressive model: {self.autoregressive_model_path}") print(f"Loading autoregressive model: {self.autoregressive_model_path}")
if hasattr(self, 'autoregressive'): if hasattr(self, 'autoregressive'):
del self.autoregressive del self.autoregressive
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, # XTTS requires a different "dimensionality" for its autoregressive model
model_dim=1024, if new_hash == "e4ce21eae0043f7691d6a6c8540b74b8" or is_xtts:
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, dimensionality = {
train_solo_embeddings=False).cpu().eval() "max_mel_tokens": 605,
"max_text_tokens": 402,
"max_prompt_tokens": 70,
"max_conditioning_inputs": 1,
"layers": 30,
"model_dim": 1024,
"heads": 16,
"number_text_tokens": 5023, # -1
"start_text_token": 261,
"stop_text_token": 0,
"number_mel_codes": 8194,
"start_mel_token": 8192,
"stop_mel_token": 8193,
}
else:
dimensionality = {
"max_mel_tokens": 604,
"max_text_tokens": 402,
"max_conditioning_inputs": 2,
"layers": 30,
"model_dim": 1024,
"heads": 16,
"number_text_tokens": 255,
"start_text_token": 255,
"checkpointing": False,
"train_solo_embeddings": False
}
self.autoregressive = UnifiedVoice(**dimensionality).cpu().eval()
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path)) self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache) self.autoregressive.post_init_gpt2_config(use_deepspeed=self.use_deepspeed, kv_cache=self.use_kv_cache)
if self.preloaded_tensors: if self.preloaded_tensors:
self.autoregressive = migrate_to_device( self.autoregressive, self.device ) self.autoregressive = migrate_to_device( self.autoregressive, self.device )
@ -373,7 +397,7 @@ class TextToSpeech:
print(f"Loaded autoregressive model") print(f"Loaded autoregressive model")
def load_diffusion_model(self, diffusion_model_path): def load_diffusion_model(self, diffusion_model_path):
if hasattr(self,"diffusion_model_path") and self.diffusion_model_path == diffusion_model_path: if hasattr(self,"diffusion_model_path") and os.path.samefile(self.diffusion_model_path, diffusion_model_path):
return return
self.loading = True self.loading = True
@ -384,9 +408,21 @@ class TextToSpeech:
if hasattr(self, 'diffusion'): if hasattr(self, 'diffusion'):
del self.diffusion del self.diffusion
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, # XTTS does not require a different "dimensionality" for its diffusion model
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, dimensionality = {
layer_drop=0, unconditioned_percentage=0).cpu().eval() "model_channels": 1024,
"num_layers": 10,
"in_channels": 100,
"out_channels": 200,
"in_latent_channels": 1024,
"in_tokens": 8193,
"dropout": 0,
"use_fp16": False,
"num_heads": 16,
"layer_drop": 0,
"unconditioned_percentage": 0
}
self.diffusion = DiffusionTts(**dimensionality)
self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', self.models_dir))) self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', self.models_dir)))
if self.preloaded_tensors: if self.preloaded_tensors:
self.diffusion = migrate_to_device( self.diffusion, self.device ) self.diffusion = migrate_to_device( self.diffusion, self.device )
@ -395,7 +431,7 @@ class TextToSpeech:
print(f"Loaded diffusion model") print(f"Loaded diffusion model")
def load_vocoder_model(self, vocoder_model): def load_vocoder_model(self, vocoder_model):
if hasattr(self,"vocoder_model_path") and self.vocoder_model_path == vocoder_model: if hasattr(self,"vocoder_model_path") and os.path.samefile(self.vocoder_model_path, vocoder_model):
return return
self.loading = True self.loading = True
@ -435,7 +471,7 @@ class TextToSpeech:
print(f"Loaded vocoder model") print(f"Loaded vocoder model")
def load_tokenizer_json(self, tokenizer_json): def load_tokenizer_json(self, tokenizer_json):
if hasattr(self,"tokenizer_json") and self.tokenizer_json == tokenizer_json: if hasattr(self,"tokenizer_json") and os.path.samefile(self.tokenizer_json, tokenizer_json):
return return
self.loading = True self.loading = True
@ -459,13 +495,15 @@ class TextToSpeech:
if self.preloaded_tensors: if self.preloaded_tensors:
self.cvvp = migrate_to_device( self.cvvp, self.device ) self.cvvp = migrate_to_device( self.cvvp, self.device )
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, slices=1, max_chunk_size=None, force_cpu=False): @torch.inference_mode()
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, slices=1, max_chunk_size=None, force_cpu=False, original_ar=False, original_diffusion=False):
""" """
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
properties. properties.
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data. :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
""" """
with torch.no_grad(): with torch.no_grad():
# computing conditional latents requires being done on the CPU if using DML because M$ still hasn't implemented some core functions # computing conditional latents requires being done on the CPU if using DML because M$ still hasn't implemented some core functions
if get_device_name() == "dml": if get_device_name() == "dml":
@ -475,50 +513,72 @@ class TextToSpeech:
if not isinstance(voice_samples, list): if not isinstance(voice_samples, list):
voice_samples = [voice_samples] voice_samples = [voice_samples]
voice_samples = [migrate_to_device(v, device) for v in voice_samples] resampler_22K = torchaudio.transforms.Resample(
resampler = torchaudio.transforms.Resample(
self.input_sample_rate, self.input_sample_rate,
self.output_sample_rate, 22050,
lowpass_filter_width=16, lowpass_filter_width=16,
rolloff=0.85, rolloff=0.85,
resampling_method="kaiser_window", resampling_method="kaiser_window",
beta=8.555504641634386, beta=8.555504641634386,
).to(device) ).to(device)
samples = [resampler(sample) for sample in voice_samples] resampler_24K = torchaudio.transforms.Resample(
chunks = [] self.input_sample_rate,
24000,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
).to(device)
concat = torch.cat(samples, dim=-1) voice_samples = [migrate_to_device(v, device) for v in voice_samples]
chunk_size = concat.shape[-1]
if slices == 0:
slices = 1
elif max_chunk_size is not None and chunk_size > max_chunk_size:
slices = 1
while int(chunk_size / slices) > max_chunk_size:
slices = slices + 1
chunks = torch.chunk(concat, slices, dim=1)
chunk_size = chunks[0].shape[-1]
auto_conds = [] auto_conds = []
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing AR conditioning latents..."): diffusion_conds = []
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
auto_conds = torch.stack(auto_conds, dim=1)
if original_ar:
samples = [resampler_22K(sample) for sample in voice_samples]
for sample in tqdm(samples, desc="Computing AR conditioning latents..."):
auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate, cond_length=132300))
else:
samples = [resampler_22K(sample) for sample in voice_samples]
concat = torch.cat(samples, dim=-1)
chunk_size = concat.shape[-1]
if slices == 0:
slices = 1
elif max_chunk_size is not None and chunk_size > max_chunk_size:
slices = 1
while int(chunk_size / slices) > max_chunk_size:
slices = slices + 1
chunks = torch.chunk(concat, slices, dim=1)
chunk_size = chunks[0].shape[-1]
for chunk in tqdm(chunks, desc="Computing AR conditioning latents..."):
auto_conds.append(format_conditioning(chunk, device=device, sampling_rate=self.input_sample_rate, cond_length=chunk_size))
if original_diffusion:
samples = [resampler_24K(sample) for sample in voice_samples]
for sample in tqdm(samples, desc="Computing diffusion conditioning latents..."):
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(migrate_to_device(sample, device), do_normalization=False, device=self.device)
diffusion_conds.append(cond_mel)
else:
samples = [resampler_24K(sample) for sample in voice_samples]
for chunk in tqdm(chunks, desc="Computing diffusion conditioning latents..."):
check_for_kill_signal()
chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
diffusion_conds.append(cond_mel)
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = migrate_to_device( self.autoregressive, device ) self.autoregressive = migrate_to_device( self.autoregressive, device )
auto_latent = self.autoregressive.get_conditioning(auto_conds) auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' ) self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
diffusion_conds = []
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing diffusion conditioning latents..."):
check_for_kill_signal()
chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = migrate_to_device( self.diffusion, device ) self.diffusion = migrate_to_device( self.diffusion, device )
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' ) self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
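
The reworked get_conditioning_latents adds original_ar and original_diffusion flags so callers can fall back to the upstream per-clip latent computation instead of the chunk-for-all method (the "added flags to revert to the default method of latent generation" commit above). A hedged usage sketch; TextToSpeech and load_voice are the repo's own entry points, while the voice name and flag combination are illustrative:

from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_voice

tts = TextToSpeech()
voice_samples, _ = load_voice('my_voice')  # hypothetical voice folder under voices/
conditioning_latents = tts.get_conditioning_latents(
    voice_samples,
    original_ar=True,         # per-clip AR conditioning at a fixed 132300-sample length
    original_diffusion=True,  # per-clip diffusion conditioning, padded/truncated to 102400 samples
)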
@ -561,6 +621,7 @@ class TextToSpeech:
settings.update(kwargs) # allow overriding of preset settings with kwargs settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.tts(text, **settings) return self.tts(text, **settings)
@torch.inference_mode()
def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None, def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
return_deterministic_state=False, return_deterministic_state=False,
# autoregressive generation parameters follow # autoregressive generation parameters follow
@ -576,7 +637,6 @@ class TextToSpeech:
diffusion_sampler="P", diffusion_sampler="P",
breathing_room=8, breathing_room=8,
half_p=False, half_p=False,
progress=None,
**hf_generate_kwargs): **hf_generate_kwargs):
""" """
Produces an audio clip of the given text being spoken with the given reference voice. Produces an audio clip of the given text being spoken with the given reference voice.
@ -681,7 +741,7 @@ class TextToSpeech:
text_tokens = migrate_to_device( text_tokens, self.device ) text_tokens = migrate_to_device( text_tokens, self.device )
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p): with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"): for b in tqdm(range(num_batches), desc="Generating autoregressive samples"):
check_for_kill_signal() check_for_kill_signal()
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True, do_sample=True,
@ -730,7 +790,7 @@ class TextToSpeech:
desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%" desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc): for batch in tqdm(samples, desc=desc):
check_for_kill_signal() check_for_kill_signal()
for i in range(batch.shape[0]): for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
@ -755,7 +815,10 @@ class TextToSpeech:
clip_results = torch.cat(clip_results, dim=0) clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0) samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices] if k < num_autoregressive_samples:
best_results = samples[torch.topk(clip_results, k=k).indices]
else:
best_results = samples
if not self.preloaded_tensors: if not self.preloaded_tensors:
self.clvp = migrate_to_device( self.clvp, 'cpu' ) self.clvp = migrate_to_device( self.clvp, 'cpu' )
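
The new guard around torch.topk implements the "possible fix for when candidates >= samples" commit at the top of the list: topk raises a RuntimeError when k exceeds the number of scored samples, so in that case every sample is kept. A hedged sketch of the selection logic with illustrative tensors:

import torch

def select_best(samples: torch.Tensor, clip_results: torch.Tensor, k: int) -> torch.Tensor:
    # torch.topk(clip_results, k=k) errors out when k > clip_results.numel(),
    # so only rank the candidates when there are more samples than requested.
    if k < samples.shape[0]:
        return samples[torch.topk(clip_results, k=k).indices]
    return samples  # candidates >= samples: keep everything

scores = torch.tensor([0.2, 0.9, 0.5])
batch = torch.arange(3).unsqueeze(1)     # three dummy "samples"
print(select_best(batch, scores, k=5))   # all three rows, no topk call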
@ -815,7 +878,7 @@ class TextToSpeech:
break break
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning, mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning,
temperature=diffusion_temperature, verbose=verbose, progress=progress, desc="Transforming autoregressive outputs into audio..", sampler=diffusion_sampler, temperature=diffusion_temperature, desc="Transforming autoregressive outputs into audio..", sampler=diffusion_sampler,
input_sample_rate=self.input_sample_rate, output_sample_rate=self.output_sample_rate) input_sample_rate=self.input_sample_rate, output_sample_rate=self.output_sample_rate)
wav = self.vocoder.inference(mel) wav = self.vocoder.inference(mel)

View File

@ -14,6 +14,7 @@ if __name__ == '__main__':
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random') 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=True)
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/') parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
'should only be specified if you have custom checkpoints.', default=MODELS_DIR) 'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
@ -37,8 +38,8 @@ if __name__ == '__main__':
os.makedirs(args.output_path, exist_ok=True) os.makedirs(args.output_path, exist_ok=True)
#print(f'use_deepspeed do_tts_debug {use_deepspeed}')
tts = TextToSpeech(models_dir=args.model_dir) tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed)
selected_voices = args.voice.split(',') selected_voices = args.voice.split(',')
for k, selected_voice in enumerate(selected_voices): for k, selected_voice in enumerate(selected_voices):
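
With the flag threaded through, enabling DeepSpeed inference from Python only needs the new constructor argument; the model falls back to a plain eval() module when CUDA is unavailable, as in the post_init_gpt2_config change below. A hedged sketch with illustrative values (note that argparse's type=bool treats any non-empty string, including "False", as True):

from tortoise.api import TextToSpeech

tts = TextToSpeech(use_deepspeed=True)  # prints 'use_deepspeed api_debug True' and wires the flag into the AR model
# rough CLI equivalent (assuming do_tts.py's existing --text argument):
#   python tortoise/do_tts.py --text "hello" --use_deepspeed True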

View File

@ -283,9 +283,9 @@ class MelEncoder(nn.Module):
class UnifiedVoice(nn.Module): class UnifiedVoice(nn.Module):
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_prompt_tokens=2, max_mel_tokens=250, max_conditioning_inputs=1,
mel_length_compression=1024, number_text_tokens=256, mel_length_compression=1024, number_text_tokens=256,
start_text_token=None, number_mel_codes=8194, start_mel_token=8192, start_text_token=None, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True, types=1): checkpointing=True, types=1):
""" """
@ -295,6 +295,7 @@ class UnifiedVoice(nn.Module):
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
max_text_tokens: Maximum number of text tokens that will be encountered by model. max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_prompt_tokens: compat set to 2, 70 for XTTS
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length. mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens: number_text_tokens:
@ -311,7 +312,7 @@ class UnifiedVoice(nn.Module):
self.number_text_tokens = number_text_tokens self.number_text_tokens = number_text_tokens
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
self.stop_text_token = 0 self.stop_text_token = stop_text_token
self.number_mel_codes = number_mel_codes self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token self.stop_mel_token = stop_mel_token
@ -319,6 +320,7 @@ class UnifiedVoice(nn.Module):
self.heads = heads self.heads = heads
self.max_mel_tokens = max_mel_tokens self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens self.max_text_tokens = max_text_tokens
self.max_prompt_tokens = max_prompt_tokens
self.model_dim = model_dim self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression self.mel_length_compression = mel_length_compression
@ -352,8 +354,8 @@ class UnifiedVoice(nn.Module):
for module in embeddings: for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02) module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, kv_cache=False): def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2 seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens, gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
n_positions=seq_length, n_positions=seq_length,
n_ctx=seq_length, n_ctx=seq_length,
@ -363,6 +365,17 @@ class UnifiedVoice(nn.Module):
gradient_checkpointing=False, gradient_checkpointing=False,
use_cache=True) use_cache=True)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache) self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head, kv_cache=kv_cache)
#print(f'use_deepspeed autoregressive_debug {use_deepspeed}')
if use_deepspeed and torch.cuda.is_available():
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
dtype=torch.float32)
self.inference_model = self.ds_engine.module.eval()
else:
self.inference_model = self.inference_model.eval()
self.gpt.wte = self.mel_embedding self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token): def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
@ -483,7 +496,7 @@ class UnifiedVoice(nn.Module):
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2 seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
if not hasattr(self, 'inference_model'): if not hasattr(self, 'inference_model'):
self.post_init_gpt2_config(kv_cache=self.kv_cache) self.post_init_gpt2_config(kv_cache=self.kv_cache)
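
The GPT-2 context length is now derived from max_prompt_tokens instead of a hard-coded +2, which is what lets the XTTS checkpoint dimensionality from api.py (max_mel_tokens=605, max_text_tokens=402, max_prompt_tokens=70) share this code path with stock tortoise (604, 402, 2). A hedged arithmetic sketch of the derivation used by post_init_gpt2_config and inference_speech:

def gpt2_seq_length(max_mel_tokens: int, max_text_tokens: int, max_prompt_tokens: int) -> int:
    # mirrors seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
    return max_mel_tokens + max_text_tokens + max_prompt_tokens

print(gpt2_seq_length(604, 402, 2))    # 1008, identical to the old max_mel + max_text + 2
print(gpt2_seq_length(605, 402, 70))   # 1077 for the XTTS-compatible dimensionality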

View File

@ -17,6 +17,7 @@ if __name__ == '__main__':
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat') 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=True)
parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None) parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1) parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1)
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
@ -25,7 +26,7 @@ if __name__ == '__main__':
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True) parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
args = parser.parse_args() args = parser.parse_args()
tts = TextToSpeech(models_dir=args.model_dir) tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed)
outpath = args.output_path outpath = args.output_path
selected_voices = args.voice.split(',') selected_voices = args.voice.split(',')

View File

@ -2,6 +2,7 @@ import os
from glob import glob from glob import glob
import librosa import librosa
import soundfile as sf
import torch import torch
import torchaudio import torchaudio
import numpy as np import numpy as np
@ -24,6 +25,9 @@ def load_audio(audiopath, sampling_rate):
elif audiopath[-4:] == '.mp3': elif audiopath[-4:] == '.mp3':
audio, lsr = librosa.load(audiopath, sr=sampling_rate) audio, lsr = librosa.load(audiopath, sr=sampling_rate)
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
elif audiopath[-5:] == '.flac':
audio, lsr = sf.read(audiopath)
audio = torch.FloatTensor(audio)
else: else:
assert False, f"Unsupported audio format provided: {audiopath[-4:]}" assert False, f"Unsupported audio format provided: {audiopath[-4:]}"
@ -85,17 +89,77 @@ def get_voices(extra_voice_dirs=[], load_latents=True):
for sub in subs: for sub in subs:
subj = os.path.join(d, sub) subj = os.path.join(d, sub)
if os.path.isdir(subj): if os.path.isdir(subj):
voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) voices[sub] = list(glob(f'{subj}/*.wav')) + list(glob(f'{subj}/*.mp3')) + list(glob(f'{subj}/*.flac'))
if load_latents: if load_latents:
voices[sub] = voices[sub] + list(glob(f'{subj}/*.pth')) voices[sub] = voices[sub] + list(glob(f'{subj}/*.pth'))
return voices return voices
def get_voice( name, dir=get_voice_dir(), load_latents=True, extensions=["wav", "mp3", "flac"] ):
subj = f'{dir}/{name}/'
if not os.path.isdir(subj):
return
files = os.listdir(subj)
if load_latents:
extensions.append("pth")
voice = []
for file in files:
ext = os.path.splitext(file)[-1][1:]
if ext not in extensions:
continue
voice.append(f'{subj}/{file}')
return sorted( voice )
def get_voice_list(dir=get_voice_dir(), append_defaults=False, load_latents=True, extensions=["wav", "mp3", "flac"]):
defaults = [ "random", "microphone" ]
os.makedirs(dir, exist_ok=True)
#res = sorted([d for d in os.listdir(dir) if d not in defaults and os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
res = []
for name in os.listdir(dir):
if name in defaults:
continue
if not os.path.isdir(f'{dir}/{name}'):
continue
if len(os.listdir(os.path.join(dir, name))) == 0:
continue
files = get_voice( name, dir=dir, extensions=extensions, load_latents=load_latents )
if len(files) > 0:
res.append(name)
else:
for subdir in os.listdir(f'{dir}/{name}'):
if not os.path.isdir(f'{dir}/{name}/{subdir}'):
continue
files = get_voice( f'{name}/{subdir}', dir=dir, extensions=extensions, load_latents=load_latents )
if len(files) == 0:
continue
res.append(f'{name}/{subdir}')
res = sorted(res)
if append_defaults:
res = res + defaults
return res
def _get_voices( dirs=[get_voice_dir()], load_latents=True ):
voices = {}
for dir in dirs:
voice_list = get_voice_list(dir=dir)
voices |= { name: get_voice(name=name, dir=dir, load_latents=load_latents) for name in voice_list }
return voices
def load_voice(voice, extra_voice_dirs=[], load_latents=True, sample_rate=22050, device='cpu', model_hash=None): def load_voice(voice, extra_voice_dirs=[], load_latents=True, sample_rate=22050, device='cpu', model_hash=None):
if voice == 'random': if voice == 'random':
return None, None return None, None
voices = get_voices(extra_voice_dirs=extra_voice_dirs, load_latents=load_latents) voices = _get_voices(dirs=[get_voice_dir()] + extra_voice_dirs, load_latents=load_latents)
paths = voices[voice] paths = voices[voice]
mtime = 0 mtime = 0

View File

@ -1,127 +1,130 @@
import torch import torch
import psutil import psutil
import importlib import importlib
DEVICE_OVERRIDE = None DEVICE_OVERRIDE = None
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
from inspect import currentframe, getframeinfo from inspect import currentframe, getframeinfo
import gc import gc
def do_gc(): def do_gc():
gc.collect() gc.collect()
try: try:
torch.cuda.empty_cache() torch.cuda.empty_cache()
except Exception as e: except Exception as e:
pass pass
def print_stats(collect=False): def print_stats(collect=False):
cf = currentframe().f_back cf = currentframe().f_back
msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}' msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
if collect: if collect:
do_gc() do_gc()
tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
res = torch.cuda.memory_reserved(0) / (1024 ** 3) res = torch.cuda.memory_reserved(0) / (1024 ** 3)
alloc = torch.cuda.memory_allocated(0) / (1024 ** 3) alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res )) print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
def has_dml(): def has_dml():
loader = importlib.find_loader('torch_directml') loader = importlib.find_loader('torch_directml')
if loader is None: if loader is None:
return False return False
import torch_directml import torch_directml
return torch_directml.is_available() return torch_directml.is_available()
def set_device_name(name): def set_device_name(name):
global DEVICE_OVERRIDE global DEVICE_OVERRIDE
DEVICE_OVERRIDE = name DEVICE_OVERRIDE = name
def get_device_name(attempt_gc=True): def get_device_name(attempt_gc=True):
global DEVICE_OVERRIDE global DEVICE_OVERRIDE
if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "": if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "":
return DEVICE_OVERRIDE return DEVICE_OVERRIDE
name = 'cpu' name = 'cpu'
if torch.cuda.is_available(): if torch.cuda.is_available():
name = 'cuda' name = 'cuda'
if attempt_gc: if attempt_gc:
torch.cuda.empty_cache() # may have performance implications torch.cuda.empty_cache() # may have performance implications
elif has_dml(): elif has_dml():
name = 'dml' name = 'dml'
return name return name
def get_device(verbose=False): def get_device(verbose=False):
name = get_device_name() name = get_device_name()
if verbose: if verbose:
if name == 'cpu': if name == 'cpu':
print("No hardware acceleration is available, falling back to CPU...") print("No hardware acceleration is available, falling back to CPU...")
else: else:
print(f"Hardware acceleration found: {name}") print(f"Hardware acceleration found: {name}")
if name == "dml": if name == "dml":
import torch_directml import torch_directml
return torch_directml.device() return torch_directml.device()
return torch.device(name) return torch.device(name)
def get_device_vram( name=get_device_name() ): def get_device_vram( name=get_device_name() ):
available = 1 available = 1
if name == "cuda": if name == "cuda":
_, available = torch.cuda.mem_get_info() _, available = torch.cuda.mem_get_info()
elif name == "cpu": elif name == "cpu":
available = psutil.virtual_memory()[4] available = psutil.virtual_memory()[4]
return available / (1024 ** 3) return available / (1024 ** 3)
def get_device_batch_size(name=None): def get_device_batch_size(name=get_device_name()):
vram = get_device_vram(name) vram = get_device_vram(name)
if vram > 14: if vram > 14:
return 16 return 16
elif vram > 10: elif vram > 10:
return 8 return 8
elif vram > 7: elif vram > 7:
return 4 return 4
""" """
for k, v in DEVICE_BATCH_SIZE_MAP: for k, v in DEVICE_BATCH_SIZE_MAP:
if vram > k: if vram > k:
return v return v
""" """
return 1 return 1
def get_device_count(name=get_device_name()): def get_device_count(name=get_device_name()):
if name == "cuda": if name == "cuda":
return torch.cuda.device_count() return torch.cuda.device_count()
if name == "dml": if name == "dml":
import torch_directml import torch_directml
return torch_directml.device_count() return torch_directml.device_count()
return 1 return 1
if has_dml(): # if you're getting errors make sure you've updated your torch-directml, and if you're still getting errors then you can uncomment the below block
_cumsum = torch.cumsum """
_repeat_interleave = torch.repeat_interleave if has_dml():
_multinomial = torch.multinomial _cumsum = torch.cumsum
_repeat_interleave = torch.repeat_interleave
_Tensor_new = torch.Tensor.new _multinomial = torch.multinomial
_Tensor_cumsum = torch.Tensor.cumsum
_Tensor_repeat_interleave = torch.Tensor.repeat_interleave _Tensor_new = torch.Tensor.new
_Tensor_multinomial = torch.Tensor.multinomial _Tensor_cumsum = torch.Tensor.cumsum
_Tensor_repeat_interleave = torch.Tensor.repeat_interleave
torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) ) _Tensor_multinomial = torch.Tensor.multinomial
torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) ) torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) ) torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) ) torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
"""

View File

@ -13,15 +13,7 @@ import math
import numpy as np import numpy as np
import torch import torch
import torch as th import torch as th
from tqdm import tqdm from tqdm.auto import tqdm
def tqdm_override(arr, verbose=False, progress=None, desc=None):
if verbose and desc is not None:
print(desc)
if progress is None:
return tqdm(arr, disable=not verbose)
return progress.tqdm(arr, desc=f'{progress.msg_prefix} {desc}' if hasattr(progress, 'msg_prefix') else desc, track_tqdm=True)
def normal_kl(mean1, logvar1, mean2, logvar2): def normal_kl(mean1, logvar1, mean2, logvar2):
""" """
@ -556,7 +548,6 @@ class GaussianDiffusion:
model_kwargs=None, model_kwargs=None,
device=None, device=None,
verbose=False, verbose=False,
progress=None,
desc=None desc=None
): ):
""" """
@ -589,7 +580,6 @@ class GaussianDiffusion:
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
device=device, device=device,
verbose=verbose, verbose=verbose,
progress=progress,
desc=desc desc=desc
): ):
final = sample final = sample
@ -606,7 +596,6 @@ class GaussianDiffusion:
model_kwargs=None, model_kwargs=None,
device=None, device=None,
verbose=False, verbose=False,
progress=None,
desc=None desc=None
): ):
""" """
@ -626,7 +615,7 @@ class GaussianDiffusion:
img = th.randn(*shape, device=device) img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1] indices = list(range(self.num_timesteps))[::-1]
for i in tqdm_override(indices, verbose=verbose, desc=desc, progress=progress): for i in tqdm(indices, desc=desc):
t = th.tensor([i] * shape[0], device=device) t = th.tensor([i] * shape[0], device=device)
with th.no_grad(): with th.no_grad():
out = self.p_sample( out = self.p_sample(
@ -741,7 +730,6 @@ class GaussianDiffusion:
device=None, device=None,
verbose=False, verbose=False,
eta=0.0, eta=0.0,
progress=None,
desc=None, desc=None,
): ):
""" """
@ -761,7 +749,6 @@ class GaussianDiffusion:
device=device, device=device,
verbose=verbose, verbose=verbose,
eta=eta, eta=eta,
progress=progress,
desc=desc desc=desc
): ):
final = sample final = sample
@ -779,7 +766,6 @@ class GaussianDiffusion:
device=None, device=None,
verbose=False, verbose=False,
eta=0.0, eta=0.0,
progress=None,
desc=None, desc=None,
): ):
""" """
@ -798,10 +784,7 @@ class GaussianDiffusion:
indices = list(range(self.num_timesteps))[::-1] indices = list(range(self.num_timesteps))[::-1]
if verbose: if verbose:
# Lazy import so that we don't depend on tqdm. indices = tqdm(indices, desc=desc)
from tqdm.auto import tqdm
indices = tqdm_override(indices, verbose=verbose, desc=desc, progress=progress)
for i in indices: for i in indices:
t = th.tensor([i] * shape[0], device=device) t = th.tensor([i] * shape[0], device=device)
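
Dropping the tqdm_override/progress plumbing relies on gradio being able to mirror tqdm bars directly, as the "removed kludgy wrappers for passing progress" commit notes. A hedged sketch of that pattern on the web-UI side, assuming the gradio 3.x gr.Progress(track_tqdm=True) API; the handler is illustrative:

import gradio as gr
from tqdm import tqdm

def generate(text, progress=gr.Progress(track_tqdm=True)):
    # any tqdm bar raised inside this handler (e.g. by tts.tts) is shown in the UI
    for _ in tqdm(range(10), desc="Generating autoregressive samples"):
        pass
    return text

demo = gr.Interface(fn=generate, inputs="text", outputs="text")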

View File

@ -22,17 +22,19 @@ import os
USE_STABLE_EMBEDDING = False USE_STABLE_EMBEDDING = False
try: try:
import bitsandbytes as bnb
OVERRIDE_LINEAR = False OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = True OVERRIDE_EMBEDDING = False
OVERRIDE_ADAM = True OVERRIDE_ADAM = False
OVERRIDE_ADAMW = True OVERRIDE_ADAMW = False
USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1' USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1'
OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1' OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1'
OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1' OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1'
OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1' OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1'
OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1' OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1'
if OVERRIDE_LINEAR or OVERRIDE_EMBEDDING or OVERRIDE_ADAM or OVERRIDE_ADAMW:
import bitsandbytes as bnb
except Exception as e: except Exception as e:
OVERRIDE_LINEAR = False OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = False OVERRIDE_EMBEDDING = False
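
The bitsandbytes overrides now default to off, and the library is only imported when one of them is switched back on, so forcing them requires setting the corresponding environment variable before the wrapper module is imported. A hedged sketch using the variable names from the diff (the wrapper module's import path is not shown on this page):

import os

# set these before importing the wrapper module above
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1'       # re-enable the 8-bit Adam override
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1'  # re-enable the embedding override
os.environ['BITSANDBYTES_USE_STABLE_EMBEDDING'] = '1'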

View File

@ -144,7 +144,7 @@ class Wav2VecAlignment:
non_redacted_intervals = [] non_redacted_intervals = []
last_point = 0 last_point = 0
for i in range(len(fully_split)): for i in range(len(fully_split)):
if i % 2 == 0: if i % 2 == 0 and fully_split[i] != "": # Check for empty string fixes index error
end_interval = max(0, last_point + len(fully_split[i]) - 1) end_interval = max(0, last_point + len(fully_split[i]) - 1)
non_redacted_intervals.append((last_point, end_interval)) non_redacted_intervals.append((last_point, end_interval))
last_point += len(fully_split[i]) last_point += len(fully_split[i])
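
The added empty-string check covers the case where a redacted span sits at the very start or end of the text, which produces empty even-index segments and previously led to a degenerate interval and an index error. A hedged reconstruction with an illustrative split; the last_point accumulation is kept at loop level as in the surrounding helper:

# illustrative: even indices are the kept segments, odd indices the redacted ones
fully_split = ['Hello there ', 'secret', '']   # redaction at the end of the text

non_redacted_intervals = []
last_point = 0
for i in range(len(fully_split)):
    if i % 2 == 0 and fully_split[i] != "":    # the new guard skips the empty tail segment
        end_interval = max(0, last_point + len(fully_split[i]) - 1)
        non_redacted_intervals.append((last_point, end_interval))
    last_point += len(fully_split[i])

print(non_redacted_intervals)   # [(0, 11)]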