From 9e27d2e02e66cce2cd4ad1db37304f387fd30ec8 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 16 Apr 2025 15:25:45 -0500 Subject: [PATCH] huggingface zerogpu cringe --- vall_e/emb/codecs/__init__.py | 0 vall_e/emb/transcribe.py | 7 ++++++- vall_e/models/base_v2.py | 13 ------------- 3 files changed, 6 insertions(+), 14 deletions(-) create mode 100644 vall_e/emb/codecs/__init__.py diff --git a/vall_e/emb/codecs/__init__.py b/vall_e/emb/codecs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vall_e/emb/transcribe.py b/vall_e/emb/transcribe.py index e496eb0..d894ec5 100644 --- a/vall_e/emb/transcribe.py +++ b/vall_e/emb/transcribe.py @@ -18,7 +18,12 @@ except Exception as e: pass """ -from transformers import pipeline +try: + from transformers import pipeline +except Exception as e: + def _kludge_cringe(): + raise e + pipeline = _kludge_cringe from functools import cache from tqdm.auto import tqdm diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 66d387e..710260d 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -123,16 +123,6 @@ class FiniteAudioEncoder(nn.Module): self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels)) if use_level_weights else None self.use_ffn = use_ffn - if use_ffn: - nn.init.xavier_uniform_(self.proj[0].weight) - nn.init.xavier_uniform_(self.proj[2].weight) - - nn.init.zeros_(self.proj[0].bias) - nn.init.zeros_(self.proj[2].bias) - elif token_dim != d_model: - nn.init.xavier_uniform_(self.proj.weight) - nn.init.zeros_(self.proj.bias) - def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None, mode = None ) -> Tensor: # empty if xi.shape[0] == 0: @@ -162,9 +152,6 @@ class FiniteAudioEncoder(nn.Module): if self.level_weights is None: x = x.sum(dim=1) - elif stability: - weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1) - x = (x.float() * weights).sum(dim=1).to(xi.dtype) else: weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1) x = (x * weights).sum(dim=1)