maybe solved that odd VRAM spike when doing the clvp pass
This commit is contained in:
parent
fec0685405
commit
97cd58e7eb
145
tortoise/api.py
145
tortoise/api.py
|
@ -29,7 +29,7 @@ from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named
|
|||
from tortoise.utils.tokenizer import VoiceBpeTokenizer
|
||||
from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
|
||||
|
||||
from tortoise.utils.device import get_device, get_device_name, get_device_batch_size
|
||||
from tortoise.utils.device import get_device, get_device_name, get_device_batch_size, print_stats, do_gc
|
||||
|
||||
pbar = None
|
||||
STOP_SIGNAL = False
|
||||
|
@ -172,8 +172,8 @@ def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=2
|
|||
rand_start = random.randint(0, gap)
|
||||
clip = clip[:, rand_start:rand_start + cond_length]
|
||||
mel_clip = TorchMelSpectrogram(sampling_rate=sampling_rate)(clip.unsqueeze(0)).squeeze(0)
|
||||
return mel_clip.unsqueeze(0).to(device)
|
||||
|
||||
mel_clip = mel_clip.unsqueeze(0)
|
||||
return migrate_to_device(mel_clip, device)
|
||||
|
||||
def fix_autoregressive_output(codes, stop_token, complain=True):
|
||||
"""
|
||||
|
@ -241,6 +241,25 @@ def classify_audio_clip(clip):
|
|||
results = F.softmax(classifier(clip), dim=-1)
|
||||
return results[0][0]
|
||||
|
||||
def migrate_to_device( t, device ):
|
||||
if t is None:
|
||||
return t
|
||||
|
||||
if not hasattr(t, 'device'):
|
||||
t.device = device
|
||||
t.manually_track_device = True
|
||||
elif t.device == device:
|
||||
return t
|
||||
|
||||
if hasattr(t, 'manually_track_device') and t.manually_track_device:
|
||||
t.device = device
|
||||
|
||||
t = t.to(device)
|
||||
|
||||
do_gc()
|
||||
|
||||
return t
|
||||
|
||||
class TextToSpeech:
|
||||
"""
|
||||
Main entry point into Tortoise.
|
||||
|
@ -315,10 +334,11 @@ class TextToSpeech:
|
|||
self.rlg_diffusion = None
|
||||
|
||||
if self.preloaded_tensors:
|
||||
self.autoregressive = self.autoregressive.to(self.device)
|
||||
self.diffusion = self.diffusion.to(self.device)
|
||||
self.clvp = self.clvp.to(self.device)
|
||||
self.vocoder = self.vocoder.to(self.device)
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
|
||||
self.diffusion = migrate_to_device( self.diffusion, self.device )
|
||||
self.clvp = migrate_to_device( self.clvp, self.device )
|
||||
self.vocoder = migrate_to_device( self.vocoder, self.device )
|
||||
|
||||
self.loading = False
|
||||
|
||||
def load_autoregressive_model(self, autoregressive_model_path):
|
||||
|
@ -341,7 +361,7 @@ class TextToSpeech:
|
|||
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
|
||||
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
|
||||
if self.preloaded_tensors:
|
||||
self.autoregressive = self.autoregressive.to(self.device)
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
|
||||
|
||||
self.loading = False
|
||||
print(f"Loaded autoregressive model")
|
||||
|
@ -382,7 +402,7 @@ class TextToSpeech:
|
|||
|
||||
self.vocoder.eval(inference=True)
|
||||
if self.preloaded_tensors:
|
||||
self.vocoder = self.vocoder.to(self.device)
|
||||
self.vocoder = migrate_to_device( self.vocoder, self.device )
|
||||
self.loading = False
|
||||
print(f"Loaded vocoder model")
|
||||
|
||||
|
@ -393,7 +413,7 @@ class TextToSpeech:
|
|||
self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
|
||||
|
||||
if self.preloaded_tensors:
|
||||
self.cvvp = self.cvvp.to(self.device)
|
||||
self.cvvp = migrate_to_device( self.cvvp, self.device )
|
||||
|
||||
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, slices=1, max_chunk_size=None, force_cpu=False):
|
||||
"""
|
||||
|
@ -411,7 +431,7 @@ class TextToSpeech:
|
|||
if not isinstance(voice_samples, list):
|
||||
voice_samples = [voice_samples]
|
||||
|
||||
voice_samples = [v.to(device) for v in voice_samples]
|
||||
voice_samples = [migrate_to_device(v, device) for v in voice_samples]
|
||||
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
self.input_sample_rate,
|
||||
|
@ -420,24 +440,19 @@ class TextToSpeech:
|
|||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
)
|
||||
).to(device)
|
||||
|
||||
samples = []
|
||||
auto_conds = []
|
||||
for sample in voice_samples:
|
||||
auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate))
|
||||
samples.append(resampler(sample.cpu()).to(device)) # icky no good, easier to do the resampling on CPU than figure out how to do it on GPU
|
||||
samples.append(resampler(sample))
|
||||
|
||||
auto_conds = torch.stack(auto_conds, dim=1)
|
||||
|
||||
|
||||
self.autoregressive = self.autoregressive.to(device)
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, device )
|
||||
auto_latent = self.autoregressive.get_conditioning(auto_conds)
|
||||
if self.preloaded_tensors:
|
||||
self.autoregressive = self.autoregressive.to(self.device)
|
||||
else:
|
||||
self.autoregressive = self.autoregressive.cpu()
|
||||
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
|
||||
|
||||
diffusion_conds = []
|
||||
chunks = []
|
||||
|
@ -460,21 +475,14 @@ class TextToSpeech:
|
|||
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
|
||||
check_for_kill_signal()
|
||||
chunk = pad_or_truncate(chunk, chunk_size)
|
||||
cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device)
|
||||
cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
|
||||
diffusion_conds.append(cond_mel)
|
||||
|
||||
diffusion_conds = torch.stack(diffusion_conds, dim=1)
|
||||
|
||||
self.diffusion = self.diffusion.to(device)
|
||||
|
||||
self.diffusion = migrate_to_device( self.diffusion, device )
|
||||
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
|
||||
|
||||
if self.preloaded_tensors:
|
||||
self.diffusion = self.diffusion.to(self.device)
|
||||
else:
|
||||
self.diffusion = self.diffusion.cpu()
|
||||
|
||||
|
||||
self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
|
||||
|
||||
if return_mels:
|
||||
return auto_latent, diffusion_latent, auto_conds, diffusion_conds
|
||||
|
@ -587,7 +595,9 @@ class TextToSpeech:
|
|||
elif autoregressive_model != self.autoregressive_model_path:
|
||||
self.load_autoregressive_model(autoregressive_model)
|
||||
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)
|
||||
text_tokens = migrate_to_device( text_tokens, self.device )
|
||||
|
||||
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
||||
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
|
||||
|
||||
|
@ -615,9 +625,9 @@ class TextToSpeech:
|
|||
stop_mel_token = self.autoregressive.stop_mel_token
|
||||
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
|
||||
|
||||
self.autoregressive = self.autoregressive.to(self.device)
|
||||
auto_conditioning = auto_conditioning.to(self.device)
|
||||
text_tokens = text_tokens.to(self.device)
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
|
||||
auto_conditioning = migrate_to_device( auto_conditioning, self.device )
|
||||
text_tokens = migrate_to_device( text_tokens, self.device )
|
||||
|
||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
|
||||
for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
|
||||
|
@ -636,24 +646,24 @@ class TextToSpeech:
|
|||
samples.append(codes)
|
||||
|
||||
if not self.preloaded_tensors:
|
||||
self.autoregressive = self.autoregressive.cpu()
|
||||
auto_conditioning = auto_conditioning.cpu()
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
|
||||
|
||||
clip_results = []
|
||||
|
||||
if auto_conds is not None:
|
||||
auto_conds = auto_conds.to(self.device)
|
||||
auto_conditioning = migrate_to_device( auto_conditioning, self.device )
|
||||
|
||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
|
||||
if not self.minor_optimizations:
|
||||
self.autoregressive = self.autoregressive.cpu()
|
||||
self.clvp = self.clvp.to(self.device)
|
||||
if not self.preloaded_tensors:
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
|
||||
self.clvp = migrate_to_device( self.clvp, self.device )
|
||||
|
||||
if cvvp_amount > 0:
|
||||
if self.cvvp is None:
|
||||
self.load_cvvp()
|
||||
if not self.minor_optimizations:
|
||||
self.cvvp = self.cvvp.to(self.device)
|
||||
|
||||
if not self.preloaded_tensors:
|
||||
self.cvvp = migrate_to_device( self.cvvp, self.device )
|
||||
|
||||
desc="Computing best candidates"
|
||||
if verbose:
|
||||
|
@ -662,6 +672,7 @@ class TextToSpeech:
|
|||
else:
|
||||
desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
|
||||
|
||||
|
||||
for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
|
||||
check_for_kill_signal()
|
||||
for i in range(batch.shape[0]):
|
||||
|
@ -683,30 +694,28 @@ class TextToSpeech:
|
|||
clip_results.append(clvp)
|
||||
|
||||
if not self.preloaded_tensors and auto_conds is not None:
|
||||
auto_conds = auto_conds.cpu()
|
||||
auto_conds = migrate_to_device( auto_conds, 'cpu' )
|
||||
|
||||
clip_results = torch.cat(clip_results, dim=0)
|
||||
samples = torch.cat(samples, dim=0)
|
||||
best_results = samples[torch.topk(clip_results, k=k).indices]
|
||||
|
||||
if not self.preloaded_tensors:
|
||||
self.clvp = self.clvp.cpu()
|
||||
if self.cvvp is not None:
|
||||
self.cvvp = self.cvvp.cpu()
|
||||
|
||||
del samples
|
||||
self.clvp = migrate_to_device( self.clvp, 'cpu' )
|
||||
self.cvvp = migrate_to_device( self.cvvp, 'cpu' )
|
||||
|
||||
|
||||
if get_device_name() == "dml":
|
||||
text_tokens = text_tokens.cpu()
|
||||
best_results = best_results.cpu()
|
||||
auto_conditioning = auto_conditioning.cpu()
|
||||
self.autoregressive = self.autoregressive.cpu()
|
||||
text_tokens = migrate_to_device( text_tokens, 'cpu' )
|
||||
best_results = migrate_to_device( best_results, 'cpu' )
|
||||
auto_conditioning = migrate_to_device( auto_conditioning, 'cpu' )
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
|
||||
else:
|
||||
#text_tokens = text_tokens.to(self.device)
|
||||
#best_results = best_results.to(self.device)
|
||||
auto_conditioning = auto_conditioning.to(self.device)
|
||||
self.autoregressive = self.autoregressive.to(self.device)
|
||||
|
||||
del samples
|
||||
|
||||
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
|
||||
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
||||
# results, but will increase memory usage.
|
||||
|
@ -715,21 +724,19 @@ class TextToSpeech:
|
|||
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
||||
return_latent=True, clip_inputs=False)
|
||||
|
||||
diffusion_conditioning = diffusion_conditioning.to(self.device)
|
||||
diffusion_conditioning = migrate_to_device( diffusion_conditioning, self.device )
|
||||
|
||||
if get_device_name() == "dml":
|
||||
self.autoregressive = self.autoregressive.to(self.device)
|
||||
best_results = best_results.to(self.device)
|
||||
best_latents = best_latents.to(self.device)
|
||||
|
||||
self.vocoder = self.vocoder.cpu()
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, self.device )
|
||||
best_results = migrate_to_device( best_results, self.device )
|
||||
best_latents = migrate_to_device( best_latents, self.device )
|
||||
self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
|
||||
else:
|
||||
if not self.preloaded_tensors:
|
||||
self.autoregressive = self.autoregressive.cpu()
|
||||
|
||||
self.diffusion = self.diffusion.to(self.device)
|
||||
self.vocoder = self.vocoder.to(self.device)
|
||||
self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
|
||||
|
||||
self.diffusion = migrate_to_device( self.diffusion, self.device )
|
||||
self.vocoder = migrate_to_device( self.vocoder, self.device )
|
||||
|
||||
del text_tokens
|
||||
del auto_conditioning
|
||||
|
@ -758,12 +765,14 @@ class TextToSpeech:
|
|||
wav_candidates.append(wav)
|
||||
|
||||
if not self.preloaded_tensors:
|
||||
self.diffusion = self.diffusion.cpu()
|
||||
self.vocoder = self.vocoder.cpu()
|
||||
self.diffusion = migrate_to_device( self.diffusion, 'cpu' )
|
||||
self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
|
||||
|
||||
def potentially_redact(clip, text):
|
||||
if self.enable_redaction:
|
||||
return self.aligner.redact(clip.squeeze(1).to('cpu' if get_device_name() == "dml" else self.device), text, self.output_sample_rate).unsqueeze(1)
|
||||
t = clip.squeeze(1)
|
||||
t = migrate_to_device( t, 'cpu' if get_device_name() == "dml" else self.device)
|
||||
return self.aligner.redact(t, text, self.output_sample_rate).unsqueeze(1)
|
||||
return clip
|
||||
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
|
||||
|
||||
|
@ -772,7 +781,7 @@ class TextToSpeech:
|
|||
else:
|
||||
res = wav_candidates[0]
|
||||
|
||||
gc.collect()
|
||||
do_gc()
|
||||
|
||||
if return_deterministic_state:
|
||||
return res, (deterministic_seed, text, voice_samples, conditioning_latents)
|
||||
|
|
|
@ -9,6 +9,8 @@ from tortoise.models.xtransformers import Encoder
|
|||
|
||||
import tortoise.utils.torch_intermediary as ml
|
||||
|
||||
from tortoise.utils.device import print_stats, do_gc
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
@ -124,14 +126,13 @@ class CLVP(nn.Module):
|
|||
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
|
||||
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
|
||||
|
||||
enc_text = self.text_transformer(text_emb, mask=text_mask)
|
||||
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
|
||||
|
||||
text_latents = self.to_text_latent(masked_mean(self.text_transformer(text_emb, mask=text_mask), text_mask, dim=1))
|
||||
|
||||
text_latents = masked_mean(enc_text, text_mask, dim=1)
|
||||
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
|
||||
|
||||
text_latents = self.to_text_latent(text_latents)
|
||||
speech_latents = self.to_speech_latent(speech_latents)
|
||||
# on ROCm at least, allocated VRAM spikes here
|
||||
do_gc()
|
||||
speech_latents = self.to_speech_latent(masked_mean(self.speech_transformer(speech_emb, mask=voice_mask), voice_mask, dim=1))
|
||||
do_gc()
|
||||
|
||||
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||
|
||||
|
|
|
@ -5,6 +5,29 @@ import importlib
|
|||
DEVICE_OVERRIDE = None
|
||||
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
|
||||
|
||||
from inspect import currentframe, getframeinfo
|
||||
import gc
|
||||
|
||||
def do_gc():
|
||||
gc.collect()
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def print_stats(collect=False):
|
||||
cf = currentframe().f_back
|
||||
msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
|
||||
|
||||
if collect:
|
||||
do_gc()
|
||||
|
||||
tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
|
||||
res = torch.cuda.memory_reserved(0) / (1024 ** 3)
|
||||
alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
|
||||
print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
|
||||
|
||||
|
||||
def has_dml():
|
||||
loader = importlib.find_loader('torch_directml')
|
||||
if loader is None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user