From e47a759ed88740850a9228a66ca988208f2a1d07 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 21 Mar 2022 17:22:35 -0600 Subject: [PATCH] ....... --- codes/models/audio/tts/unet_diffusion_tts7.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codes/models/audio/tts/unet_diffusion_tts7.py b/codes/models/audio/tts/unet_diffusion_tts7.py index 8075cef5..1894f1df 100644 --- a/codes/models/audio/tts/unet_diffusion_tts7.py +++ b/codes/models/audio/tts/unet_diffusion_tts7.py @@ -57,11 +57,13 @@ class CheckpointedXTransformerEncoder(nn.Module): Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid to channels-last that XTransformer expects. """ - def __init__(self, needs_permute=True, **xtransformer_kwargs): + def __init__(self, needs_permute=True, checkpoint=True, **xtransformer_kwargs): super().__init__() self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) self.needs_permute = needs_permute + if not checkpoint: + return for i in range(len(self.transformer.attn_layers.layers)): n, b, r = self.transformer.attn_layers.layers[i] self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])