maybe solved that odd VRAM spike when doing the clvp pass

This commit is contained in:
mrq 2023-03-12 12:48:29 -05:00
parent fec0685405
commit 97cd58e7eb
3 changed files with 108 additions and 75 deletions

View File

@ -29,7 +29,7 @@ from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named
from tortoise.utils.tokenizer import VoiceBpeTokenizer from tortoise.utils.tokenizer import VoiceBpeTokenizer
from tortoise.utils.wav2vec_alignment import Wav2VecAlignment from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
from tortoise.utils.device import get_device, get_device_name, get_device_batch_size from tortoise.utils.device import get_device, get_device_name, get_device_batch_size, print_stats, do_gc
pbar = None pbar = None
STOP_SIGNAL = False STOP_SIGNAL = False
@ -172,8 +172,8 @@ def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=2
rand_start = random.randint(0, gap) rand_start = random.randint(0, gap)
clip = clip[:, rand_start:rand_start + cond_length] clip = clip[:, rand_start:rand_start + cond_length]
mel_clip = TorchMelSpectrogram(sampling_rate=sampling_rate)(clip.unsqueeze(0)).squeeze(0) mel_clip = TorchMelSpectrogram(sampling_rate=sampling_rate)(clip.unsqueeze(0)).squeeze(0)
return mel_clip.unsqueeze(0).to(device) mel_clip = mel_clip.unsqueeze(0)
return migrate_to_device(mel_clip, device)
def fix_autoregressive_output(codes, stop_token, complain=True): def fix_autoregressive_output(codes, stop_token, complain=True):
""" """
@ -241,6 +241,25 @@ def classify_audio_clip(clip):
results = F.softmax(classifier(clip), dim=-1) results = F.softmax(classifier(clip), dim=-1)
return results[0][0] return results[0][0]
def migrate_to_device( t, device ):
if t is None:
return t
if not hasattr(t, 'device'):
t.device = device
t.manually_track_device = True
elif t.device == device:
return t
if hasattr(t, 'manually_track_device') and t.manually_track_device:
t.device = device
t = t.to(device)
do_gc()
return t
class TextToSpeech: class TextToSpeech:
""" """
Main entry point into Tortoise. Main entry point into Tortoise.
@ -315,10 +334,11 @@ class TextToSpeech:
self.rlg_diffusion = None self.rlg_diffusion = None
if self.preloaded_tensors: if self.preloaded_tensors:
self.autoregressive = self.autoregressive.to(self.device) self.autoregressive = migrate_to_device( self.autoregressive, self.device )
self.diffusion = self.diffusion.to(self.device) self.diffusion = migrate_to_device( self.diffusion, self.device )
self.clvp = self.clvp.to(self.device) self.clvp = migrate_to_device( self.clvp, self.device )
self.vocoder = self.vocoder.to(self.device) self.vocoder = migrate_to_device( self.vocoder, self.device )
self.loading = False self.loading = False
def load_autoregressive_model(self, autoregressive_model_path): def load_autoregressive_model(self, autoregressive_model_path):
@ -341,7 +361,7 @@ class TextToSpeech:
self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path)) self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache) self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
if self.preloaded_tensors: if self.preloaded_tensors:
self.autoregressive = self.autoregressive.to(self.device) self.autoregressive = migrate_to_device( self.autoregressive, self.device )
self.loading = False self.loading = False
print(f"Loaded autoregressive model") print(f"Loaded autoregressive model")
@ -382,7 +402,7 @@ class TextToSpeech:
self.vocoder.eval(inference=True) self.vocoder.eval(inference=True)
if self.preloaded_tensors: if self.preloaded_tensors:
self.vocoder = self.vocoder.to(self.device) self.vocoder = migrate_to_device( self.vocoder, self.device )
self.loading = False self.loading = False
print(f"Loaded vocoder model") print(f"Loaded vocoder model")
@ -393,7 +413,7 @@ class TextToSpeech:
self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir))) self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
if self.preloaded_tensors: if self.preloaded_tensors:
self.cvvp = self.cvvp.to(self.device) self.cvvp = migrate_to_device( self.cvvp, self.device )
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, slices=1, max_chunk_size=None, force_cpu=False): def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, slices=1, max_chunk_size=None, force_cpu=False):
""" """
@ -411,7 +431,7 @@ class TextToSpeech:
if not isinstance(voice_samples, list): if not isinstance(voice_samples, list):
voice_samples = [voice_samples] voice_samples = [voice_samples]
voice_samples = [v.to(device) for v in voice_samples] voice_samples = [migrate_to_device(v, device) for v in voice_samples]
resampler = torchaudio.transforms.Resample( resampler = torchaudio.transforms.Resample(
self.input_sample_rate, self.input_sample_rate,
@ -420,24 +440,19 @@ class TextToSpeech:
rolloff=0.85, rolloff=0.85,
resampling_method="kaiser_window", resampling_method="kaiser_window",
beta=8.555504641634386, beta=8.555504641634386,
) ).to(device)
samples = [] samples = []
auto_conds = [] auto_conds = []
for sample in voice_samples: for sample in voice_samples:
auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate)) auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate))
samples.append(resampler(sample.cpu()).to(device)) # icky no good, easier to do the resampling on CPU than figure out how to do it on GPU samples.append(resampler(sample))
auto_conds = torch.stack(auto_conds, dim=1) auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = migrate_to_device( self.autoregressive, device )
self.autoregressive = self.autoregressive.to(device)
auto_latent = self.autoregressive.get_conditioning(auto_conds) auto_latent = self.autoregressive.get_conditioning(auto_conds)
if self.preloaded_tensors: self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
self.autoregressive = self.autoregressive.to(self.device)
else:
self.autoregressive = self.autoregressive.cpu()
diffusion_conds = [] diffusion_conds = []
chunks = [] chunks = []
@ -460,21 +475,14 @@ class TextToSpeech:
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."): for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
check_for_kill_signal() check_for_kill_signal()
chunk = pad_or_truncate(chunk, chunk_size) chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device) cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
diffusion_conds.append(cond_mel) diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1) diffusion_conds = torch.stack(diffusion_conds, dim=1)
self.diffusion = self.diffusion.to(device) self.diffusion = migrate_to_device( self.diffusion, device )
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
if self.preloaded_tensors:
self.diffusion = self.diffusion.to(self.device)
else:
self.diffusion = self.diffusion.cpu()
if return_mels: if return_mels:
return auto_latent, diffusion_latent, auto_conds, diffusion_conds return auto_latent, diffusion_latent, auto_conds, diffusion_conds
@ -587,7 +595,9 @@ class TextToSpeech:
elif autoregressive_model != self.autoregressive_model_path: elif autoregressive_model != self.autoregressive_model_path:
self.load_autoregressive_model(autoregressive_model) self.load_autoregressive_model(autoregressive_model)
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)
text_tokens = migrate_to_device( text_tokens, self.device )
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
@ -615,9 +625,9 @@ class TextToSpeech:
stop_mel_token = self.autoregressive.stop_mel_token stop_mel_token = self.autoregressive.stop_mel_token
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
self.autoregressive = self.autoregressive.to(self.device) self.autoregressive = migrate_to_device( self.autoregressive, self.device )
auto_conditioning = auto_conditioning.to(self.device) auto_conditioning = migrate_to_device( auto_conditioning, self.device )
text_tokens = text_tokens.to(self.device) text_tokens = migrate_to_device( text_tokens, self.device )
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p): with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"): for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
@ -636,24 +646,24 @@ class TextToSpeech:
samples.append(codes) samples.append(codes)
if not self.preloaded_tensors: if not self.preloaded_tensors:
self.autoregressive = self.autoregressive.cpu() self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
auto_conditioning = auto_conditioning.cpu()
clip_results = [] clip_results = []
if auto_conds is not None: if auto_conds is not None:
auto_conds = auto_conds.to(self.device) auto_conditioning = migrate_to_device( auto_conditioning, self.device )
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p): with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
if not self.minor_optimizations: if not self.preloaded_tensors:
self.autoregressive = self.autoregressive.cpu() self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
self.clvp = self.clvp.to(self.device) self.clvp = migrate_to_device( self.clvp, self.device )
if cvvp_amount > 0: if cvvp_amount > 0:
if self.cvvp is None: if self.cvvp is None:
self.load_cvvp() self.load_cvvp()
if not self.minor_optimizations:
self.cvvp = self.cvvp.to(self.device) if not self.preloaded_tensors:
self.cvvp = migrate_to_device( self.cvvp, self.device )
desc="Computing best candidates" desc="Computing best candidates"
if verbose: if verbose:
@ -662,6 +672,7 @@ class TextToSpeech:
else: else:
desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%" desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc): for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
check_for_kill_signal() check_for_kill_signal()
for i in range(batch.shape[0]): for i in range(batch.shape[0]):
@ -683,30 +694,28 @@ class TextToSpeech:
clip_results.append(clvp) clip_results.append(clvp)
if not self.preloaded_tensors and auto_conds is not None: if not self.preloaded_tensors and auto_conds is not None:
auto_conds = auto_conds.cpu() auto_conds = migrate_to_device( auto_conds, 'cpu' )
clip_results = torch.cat(clip_results, dim=0) clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0) samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices] best_results = samples[torch.topk(clip_results, k=k).indices]
if not self.preloaded_tensors: if not self.preloaded_tensors:
self.clvp = self.clvp.cpu() self.clvp = migrate_to_device( self.clvp, 'cpu' )
if self.cvvp is not None: self.cvvp = migrate_to_device( self.cvvp, 'cpu' )
self.cvvp = self.cvvp.cpu()
del samples
if get_device_name() == "dml": if get_device_name() == "dml":
text_tokens = text_tokens.cpu() text_tokens = migrate_to_device( text_tokens, 'cpu' )
best_results = best_results.cpu() best_results = migrate_to_device( best_results, 'cpu' )
auto_conditioning = auto_conditioning.cpu() auto_conditioning = migrate_to_device( auto_conditioning, 'cpu' )
self.autoregressive = self.autoregressive.cpu() self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
else: else:
#text_tokens = text_tokens.to(self.device)
#best_results = best_results.to(self.device)
auto_conditioning = auto_conditioning.to(self.device) auto_conditioning = auto_conditioning.to(self.device)
self.autoregressive = self.autoregressive.to(self.device) self.autoregressive = self.autoregressive.to(self.device)
del samples
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage. # results, but will increase memory usage.
@ -715,21 +724,19 @@ class TextToSpeech:
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False) return_latent=True, clip_inputs=False)
diffusion_conditioning = diffusion_conditioning.to(self.device) diffusion_conditioning = migrate_to_device( diffusion_conditioning, self.device )
if get_device_name() == "dml": if get_device_name() == "dml":
self.autoregressive = self.autoregressive.to(self.device) self.autoregressive = migrate_to_device( self.autoregressive, self.device )
best_results = best_results.to(self.device) best_results = migrate_to_device( best_results, self.device )
best_latents = best_latents.to(self.device) best_latents = migrate_to_device( best_latents, self.device )
self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
self.vocoder = self.vocoder.cpu()
else: else:
if not self.preloaded_tensors: if not self.preloaded_tensors:
self.autoregressive = self.autoregressive.cpu() self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
self.diffusion = self.diffusion.to(self.device)
self.vocoder = self.vocoder.to(self.device)
self.diffusion = migrate_to_device( self.diffusion, self.device )
self.vocoder = migrate_to_device( self.vocoder, self.device )
del text_tokens del text_tokens
del auto_conditioning del auto_conditioning
@ -758,12 +765,14 @@ class TextToSpeech:
wav_candidates.append(wav) wav_candidates.append(wav)
if not self.preloaded_tensors: if not self.preloaded_tensors:
self.diffusion = self.diffusion.cpu() self.diffusion = migrate_to_device( self.diffusion, 'cpu' )
self.vocoder = self.vocoder.cpu() self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
def potentially_redact(clip, text): def potentially_redact(clip, text):
if self.enable_redaction: if self.enable_redaction:
return self.aligner.redact(clip.squeeze(1).to('cpu' if get_device_name() == "dml" else self.device), text, self.output_sample_rate).unsqueeze(1) t = clip.squeeze(1)
t = migrate_to_device( t, 'cpu' if get_device_name() == "dml" else self.device)
return self.aligner.redact(t, text, self.output_sample_rate).unsqueeze(1)
return clip return clip
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
@ -772,7 +781,7 @@ class TextToSpeech:
else: else:
res = wav_candidates[0] res = wav_candidates[0]
gc.collect() do_gc()
if return_deterministic_state: if return_deterministic_state:
return res, (deterministic_seed, text, voice_samples, conditioning_latents) return res, (deterministic_seed, text, voice_samples, conditioning_latents)

View File

@ -9,6 +9,8 @@ from tortoise.models.xtransformers import Encoder
import tortoise.utils.torch_intermediary as ml import tortoise.utils.torch_intermediary as ml
from tortoise.utils.device import print_stats, do_gc
def exists(val): def exists(val):
return val is not None return val is not None
@ -124,14 +126,13 @@ class CLVP(nn.Module):
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
enc_text = self.text_transformer(text_emb, mask=text_mask)
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
text_latents = masked_mean(enc_text, text_mask, dim=1) text_latents = self.to_text_latent(masked_mean(self.text_transformer(text_emb, mask=text_mask), text_mask, dim=1))
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
text_latents = self.to_text_latent(text_latents) # on ROCm at least, allocated VRAM spikes here
speech_latents = self.to_speech_latent(speech_latents) do_gc()
speech_latents = self.to_speech_latent(masked_mean(self.speech_transformer(speech_emb, mask=voice_mask), voice_mask, dim=1))
do_gc()
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))

View File

@ -5,6 +5,29 @@ import importlib
DEVICE_OVERRIDE = None DEVICE_OVERRIDE = None
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
from inspect import currentframe, getframeinfo
import gc
def do_gc():
gc.collect()
try:
torch.cuda.empty_cache()
except Exception as e:
pass
def print_stats(collect=False):
cf = currentframe().f_back
msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
if collect:
do_gc()
tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
res = torch.cuda.memory_reserved(0) / (1024 ** 3)
alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
def has_dml(): def has_dml():
loader = importlib.find_loader('torch_directml') loader = importlib.find_loader('torch_directml')
if loader is None: if loader is None: