huggingface zerogpu cringe

This commit is contained in:
mrq 2025-04-16 15:25:45 -05:00
parent 814146a5e0
commit 9e27d2e02e
3 changed files with 6 additions and 14 deletions

View File

View File

@ -18,7 +18,12 @@ except Exception as e:
pass
"""
from transformers import pipeline
try:
from transformers import pipeline
except Exception as e:
def _kludge_cringe():
raise e
pipeline = _kludge_cringe
from functools import cache
from tqdm.auto import tqdm

View File

@ -123,16 +123,6 @@ class FiniteAudioEncoder(nn.Module):
self.level_weights = nn.Parameter(torch.ones(n_levels) / math.sqrt(n_levels)) if use_level_weights else None
self.use_ffn = use_ffn
if use_ffn:
nn.init.xavier_uniform_(self.proj[0].weight)
nn.init.xavier_uniform_(self.proj[2].weight)
nn.init.zeros_(self.proj[0].bias)
nn.init.zeros_(self.proj[2].bias)
elif token_dim != d_model:
nn.init.xavier_uniform_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None, stability = None, mode = None ) -> Tensor:
# empty
if xi.shape[0] == 0:
@ -162,9 +152,6 @@ class FiniteAudioEncoder(nn.Module):
if self.level_weights is None:
x = x.sum(dim=1)
elif stability:
weights = F.softmax(self.level_weights.float(), dim=0).view(1, -1, 1)
x = (x.float() * weights).sum(dim=1).to(xi.dtype)
else:
weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
x = (x * weights).sum(dim=1)