forked from mrq/tortoise-tts

commit 97cd58e7eb
parent fec0685405

    maybe solved that odd VRAM spike when doing the clvp pass

tortoise/api.py | 143 changed lines
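Annotation: the spike this commit targets comes from the CLVP forward pass holding both transformer encodings alive at once; the rewrite (last hunk of tortoise/models/clvp.py below) consumes each encoding inside a single expression and runs a gc/empty_cache pass between them. A toy sketch of that reasoning, with stand-in nn.Linear encoders and random embeddings; nothing here is the real CLVP:

import gc
import torch
import torch.nn as nn

def do_gc():
    # same shape as the helper the commit adds: collect, then release cached blocks
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass

device = 'cuda' if torch.cuda.is_available() else 'cpu'
text_transformer = nn.Linear(512, 512).to(device)    # stand-ins for the two
speech_transformer = nn.Linear(512, 512).to(device)  # CLVP transformers
text_emb = torch.randn(16, 1024, 512, device=device)
speech_emb = torch.randn(16, 1024, 512, device=device)

# Before: both big encodings are alive at the same time, so the allocator's
# high-water mark is roughly their sum.
enc_text = text_transformer(text_emb)
enc_speech = speech_transformer(speech_emb)
text_latents = enc_text.mean(dim=1)
speech_latents = enc_speech.mean(dim=1)
del enc_text, enc_speech

# After: each encoding is consumed within one expression and is collectible
# before the other pass starts; do_gc() hands cached blocks back in between.
text_latents = text_transformer(text_emb).mean(dim=1)
do_gc()
speech_latents = speech_transformer(speech_emb).mean(dim=1)
do_gc()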
tortoise/api.py

@@ -29,7 +29,7 @@ from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named
 from tortoise.utils.tokenizer import VoiceBpeTokenizer
 from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
 
-from tortoise.utils.device import get_device, get_device_name, get_device_batch_size
+from tortoise.utils.device import get_device, get_device_name, get_device_batch_size, print_stats, do_gc
 
 pbar = None
 STOP_SIGNAL = False
@@ -172,8 +172,8 @@ def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=2
     rand_start = random.randint(0, gap)
     clip = clip[:, rand_start:rand_start + cond_length]
     mel_clip = TorchMelSpectrogram(sampling_rate=sampling_rate)(clip.unsqueeze(0)).squeeze(0)
-    return mel_clip.unsqueeze(0).to(device)
+    mel_clip = mel_clip.unsqueeze(0)
+    return migrate_to_device(mel_clip, device)
 
 def fix_autoregressive_output(codes, stop_token, complain=True):
     """
@@ -241,6 +241,25 @@ def classify_audio_clip(clip):
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
 
+def migrate_to_device( t, device ):
+    if t is None:
+        return t
+
+    if not hasattr(t, 'device'):
+        t.device = device
+        t.manually_track_device = True
+    elif t.device == device:
+        return t
+
+    if hasattr(t, 'manually_track_device') and t.manually_track_device:
+        t.device = device
+
+    t = t.to(device)
+
+    do_gc()
+
+    return t
+
 class TextToSpeech:
     """
     Main entry point into Tortoise.
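A couple of illustrative calls against the new helper (toy objects, not real Tortoise state); it is a drop-in for `.to(device)` that also tolerates None and objects without a `.device` attribute:

import torch

t = torch.zeros(4)
t = migrate_to_device(t, 'cpu')                # already there: returned as-is, no gc pass
assert migrate_to_device(None, 'cpu') is None  # None passes straight through

# Objects with no .device attribute get one pinned onto them and are flagged
# manually_track_device, so later calls can still compare devices.
mod = torch.nn.Linear(2, 2)
mod = migrate_to_device(mod, 'cpu')
print(mod.device, mod.manually_track_device)   # cpu True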
@@ -315,10 +334,11 @@ class TextToSpeech:
         self.rlg_diffusion = None
 
         if self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.to(self.device)
-            self.diffusion = self.diffusion.to(self.device)
-            self.clvp = self.clvp.to(self.device)
-            self.vocoder = self.vocoder.to(self.device)
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+            self.diffusion = migrate_to_device( self.diffusion, self.device )
+            self.clvp = migrate_to_device( self.clvp, self.device )
+            self.vocoder = migrate_to_device( self.vocoder, self.device )
 
         self.loading = False
 
     def load_autoregressive_model(self, autoregressive_model_path):
@@ -341,7 +361,7 @@ class TextToSpeech:
         self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
         self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
         if self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.to(self.device)
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
 
         self.loading = False
         print(f"Loaded autoregressive model")
@@ -382,7 +402,7 @@ class TextToSpeech:
 
         self.vocoder.eval(inference=True)
         if self.preloaded_tensors:
-            self.vocoder = self.vocoder.to(self.device)
+            self.vocoder = migrate_to_device( self.vocoder, self.device )
         self.loading = False
         print(f"Loaded vocoder model")
 
@@ -393,7 +413,7 @@ class TextToSpeech:
         self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir)))
 
         if self.preloaded_tensors:
-            self.cvvp = self.cvvp.to(self.device)
+            self.cvvp = migrate_to_device( self.cvvp, self.device )
 
     def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, slices=1, max_chunk_size=None, force_cpu=False):
         """
@@ -411,7 +431,7 @@ class TextToSpeech:
         if not isinstance(voice_samples, list):
             voice_samples = [voice_samples]
 
-        voice_samples = [v.to(device) for v in voice_samples]
+        voice_samples = [migrate_to_device(v, device) for v in voice_samples]
 
         resampler = torchaudio.transforms.Resample(
             self.input_sample_rate,
@@ -420,24 +440,19 @@ class TextToSpeech:
             rolloff=0.85,
             resampling_method="kaiser_window",
             beta=8.555504641634386,
-        )
+        ).to(device)
 
         samples = []
         auto_conds = []
         for sample in voice_samples:
             auto_conds.append(format_conditioning(sample, device=device, sampling_rate=self.input_sample_rate))
-            samples.append(resampler(sample.cpu()).to(device)) # icky no good, easier to do the resampling on CPU than figure out how to do it on GPU
+            samples.append(resampler(sample))
 
         auto_conds = torch.stack(auto_conds, dim=1)
 
-        self.autoregressive = self.autoregressive.to(device)
+        self.autoregressive = migrate_to_device( self.autoregressive, device )
         auto_latent = self.autoregressive.get_conditioning(auto_conds)
-        if self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.to(self.device)
-        else:
-            self.autoregressive = self.autoregressive.cpu()
+        self.autoregressive = migrate_to_device( self.autoregressive, self.device if self.preloaded_tensors else 'cpu' )
 
 
         diffusion_conds = []
         chunks = []
@@ -460,21 +475,14 @@ class TextToSpeech:
         for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
             check_for_kill_signal()
             chunk = pad_or_truncate(chunk, chunk_size)
-            cond_mel = wav_to_univnet_mel(chunk.to(device), do_normalization=False, device=device)
+            cond_mel = wav_to_univnet_mel(migrate_to_device( chunk, device ), do_normalization=False, device=device)
             diffusion_conds.append(cond_mel)
 
         diffusion_conds = torch.stack(diffusion_conds, dim=1)
 
-        self.diffusion = self.diffusion.to(device)
+        self.diffusion = migrate_to_device( self.diffusion, device )
 
         diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
-        if self.preloaded_tensors:
-            self.diffusion = self.diffusion.to(self.device)
-        else:
-            self.diffusion = self.diffusion.cpu()
-
+        self.diffusion = migrate_to_device( self.diffusion, self.device if self.preloaded_tensors else 'cpu' )
 
         if return_mels:
             return auto_latent, diffusion_latent, auto_conds, diffusion_conds
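The two hunks above collapse the repeated load/compute/offload if-else into single migrate_to_device lines. The same pattern as a hypothetical wrapper (run_offloaded is not in the commit; it only names the shape):

def run_offloaded(model, fn, device, keep_resident, resident_device):
    # Pull the model onto the compute device, run the work, then either keep
    # it resident or evict it to cpu, mirroring the hunks above.
    model = migrate_to_device(model, device)
    out = fn(model)
    model = migrate_to_device(model, resident_device if keep_resident else 'cpu')
    return out, model

# e.g. auto_latent, self.autoregressive = run_offloaded(
#     self.autoregressive, lambda m: m.get_conditioning(auto_conds),
#     device, self.preloaded_tensors, self.device)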
@@ -587,7 +595,9 @@ class TextToSpeech:
         elif autoregressive_model != self.autoregressive_model_path:
             self.load_autoregressive_model(autoregressive_model)
 
-        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
+        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)
+        text_tokens = migrate_to_device( text_tokens, self.device )
+
 
         text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
         assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
 
@@ -615,9 +625,9 @@ class TextToSpeech:
         stop_mel_token = self.autoregressive.stop_mel_token
         calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
 
-        self.autoregressive = self.autoregressive.to(self.device)
-        auto_conditioning = auto_conditioning.to(self.device)
-        text_tokens = text_tokens.to(self.device)
+        self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+        auto_conditioning = migrate_to_device( auto_conditioning, self.device )
+        text_tokens = migrate_to_device( text_tokens, self.device )
 
         with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
             for b in tqdm_override(range(num_batches), verbose=verbose, progress=progress, desc="Generating autoregressive samples"):
@@ -636,24 +646,24 @@ class TextToSpeech:
                 samples.append(codes)
 
         if not self.preloaded_tensors:
-            self.autoregressive = self.autoregressive.cpu()
-            auto_conditioning = auto_conditioning.cpu()
+            self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
 
         clip_results = []
 
         if auto_conds is not None:
-            auto_conds = auto_conds.to(self.device)
+            auto_conditioning = migrate_to_device( auto_conditioning, self.device )
 
         with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=half_p):
-            if not self.minor_optimizations:
-                self.autoregressive = self.autoregressive.cpu()
-                self.clvp = self.clvp.to(self.device)
+            if not self.preloaded_tensors:
+                self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
+                self.clvp = migrate_to_device( self.clvp, self.device )
 
             if cvvp_amount > 0:
                 if self.cvvp is None:
                     self.load_cvvp()
-                if not self.minor_optimizations:
-                    self.cvvp = self.cvvp.to(self.device)
+                if not self.preloaded_tensors:
+                    self.cvvp = migrate_to_device( self.cvvp, self.device )
 
             desc="Computing best candidates"
             if verbose:
@@ -662,6 +672,7 @@ class TextToSpeech:
             else:
                 desc = f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%"
 
+
             for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
                 check_for_kill_signal()
                 for i in range(batch.shape[0]):
@@ -683,30 +694,28 @@ class TextToSpeech:
                     clip_results.append(clvp)
 
             if not self.preloaded_tensors and auto_conds is not None:
-                auto_conds = auto_conds.cpu()
+                auto_conds = migrate_to_device( auto_conds, 'cpu' )
 
             clip_results = torch.cat(clip_results, dim=0)
             samples = torch.cat(samples, dim=0)
             best_results = samples[torch.topk(clip_results, k=k).indices]
 
             if not self.preloaded_tensors:
-                self.clvp = self.clvp.cpu()
-                if self.cvvp is not None:
-                    self.cvvp = self.cvvp.cpu()
+                self.clvp = migrate_to_device( self.clvp, 'cpu' )
+                self.cvvp = migrate_to_device( self.cvvp, 'cpu' )
 
-            del samples
-
             if get_device_name() == "dml":
-                text_tokens = text_tokens.cpu()
-                best_results = best_results.cpu()
-                auto_conditioning = auto_conditioning.cpu()
-                self.autoregressive = self.autoregressive.cpu()
+                text_tokens = migrate_to_device( text_tokens, 'cpu' )
+                best_results = migrate_to_device( best_results, 'cpu' )
+                auto_conditioning = migrate_to_device( auto_conditioning, 'cpu' )
+                self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
             else:
-                #text_tokens = text_tokens.to(self.device)
-                #best_results = best_results.to(self.device)
                 auto_conditioning = auto_conditioning.to(self.device)
                 self.autoregressive = self.autoregressive.to(self.device)
 
+            del samples
+
         # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
         # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
         # results, but will increase memory usage.
@@ -715,21 +724,19 @@ class TextToSpeech:
             torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
             return_latent=True, clip_inputs=False)
 
-        diffusion_conditioning = diffusion_conditioning.to(self.device)
+        diffusion_conditioning = migrate_to_device( diffusion_conditioning, self.device )
 
         if get_device_name() == "dml":
-            self.autoregressive = self.autoregressive.to(self.device)
-            best_results = best_results.to(self.device)
-            best_latents = best_latents.to(self.device)
-            self.vocoder = self.vocoder.cpu()
+            self.autoregressive = migrate_to_device( self.autoregressive, self.device )
+            best_results = migrate_to_device( best_results, self.device )
+            best_latents = migrate_to_device( best_latents, self.device )
+            self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
         else:
             if not self.preloaded_tensors:
-                self.autoregressive = self.autoregressive.cpu()
+                self.autoregressive = migrate_to_device( self.autoregressive, 'cpu' )
 
-            self.diffusion = self.diffusion.to(self.device)
-            self.vocoder = self.vocoder.to(self.device)
+            self.diffusion = migrate_to_device( self.diffusion, self.device )
+            self.vocoder = migrate_to_device( self.vocoder, self.device )
 
         del text_tokens
         del auto_conditioning
@@ -758,12 +765,14 @@ class TextToSpeech:
             wav_candidates.append(wav)
 
         if not self.preloaded_tensors:
-            self.diffusion = self.diffusion.cpu()
-            self.vocoder = self.vocoder.cpu()
+            self.diffusion = migrate_to_device( self.diffusion, 'cpu' )
+            self.vocoder = migrate_to_device( self.vocoder, 'cpu' )
 
         def potentially_redact(clip, text):
             if self.enable_redaction:
-                return self.aligner.redact(clip.squeeze(1).to('cpu' if get_device_name() == "dml" else self.device), text, self.output_sample_rate).unsqueeze(1)
+                t = clip.squeeze(1)
+                t = migrate_to_device( t, 'cpu' if get_device_name() == "dml" else self.device)
+                return self.aligner.redact(t, text, self.output_sample_rate).unsqueeze(1)
             return clip
         wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
 
@@ -772,7 +781,7 @@ class TextToSpeech:
         else:
             res = wav_candidates[0]
 
-        gc.collect()
+        do_gc()
 
         if return_deterministic_state:
             return res, (deterministic_seed, text, voice_samples, conditioning_latents)
tortoise/models/clvp.py

@@ -9,6 +9,8 @@ from tortoise.models.xtransformers import Encoder
 
 import tortoise.utils.torch_intermediary as ml
 
+from tortoise.utils.device import print_stats, do_gc
+
 def exists(val):
     return val is not None
 
@@ -124,14 +126,13 @@ class CLVP(nn.Module):
         text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
         speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
 
-        enc_text = self.text_transformer(text_emb, mask=text_mask)
-        enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
-
-        text_latents = masked_mean(enc_text, text_mask, dim=1)
-        speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
-
-        text_latents = self.to_text_latent(text_latents)
-        speech_latents = self.to_speech_latent(speech_latents)
+        text_latents = self.to_text_latent(masked_mean(self.text_transformer(text_emb, mask=text_mask), text_mask, dim=1))
+        # on ROCm at least, allocated VRAM spikes here
+        do_gc()
+        speech_latents = self.to_speech_latent(masked_mean(self.speech_transformer(speech_emb, mask=voice_mask), voice_mask, dim=1))
+        do_gc()
 
         text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
 
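One way to check whether the reordering and the two do_gc() calls actually lower the high-water mark (a measurement snippet, not part of the commit; the torch.cuda API covers ROCm builds as well):

import torch

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    # ... run the clvp scoring pass here, e.g. one tts() call ...
    peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"peak allocated: {peak_gib:.3f} GiB")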
tortoise/utils/device.py

@@ -5,6 +5,29 @@ import importlib
 DEVICE_OVERRIDE = None
 DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
 
+from inspect import currentframe, getframeinfo
+import gc
+
+def do_gc():
+    gc.collect()
+    try:
+        torch.cuda.empty_cache()
+    except Exception as e:
+        pass
+
+def print_stats(collect=False):
+    cf = currentframe().f_back
+    msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
+
+    if collect:
+        do_gc()
+
+    tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
+    res = torch.cuda.memory_reserved(0) / (1024 ** 3)
+    alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
+    print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
+
+
 def has_dml():
     loader = importlib.find_loader('torch_directml')
     if loader is None:
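The new print_stats() tags each report with the caller's file:line via currentframe(), so bracketing a suspect region with two calls localizes a spike. A usage sketch (the file and line in the comments are illustrative):

from tortoise.utils.device import print_stats, do_gc

print_stats()               # [tortoise/api.py:650] Total: ... | Reserved: ... | Allocated: ... | Free: ...
# ... suspect allocation ...
print_stats(collect=True)   # runs do_gc() first, then reports what survived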