forked from mrq/tortoise-tts
Remove entmax dep
This commit is contained in:
parent
12acac6f77
commit
dc0390ade1
|
@ -6,6 +6,5 @@ inflect
|
||||||
progressbar
|
progressbar
|
||||||
einops
|
einops
|
||||||
unidecode
|
unidecode
|
||||||
entmax
|
|
||||||
scipy
|
scipy
|
||||||
librosa
|
librosa
|
|
@ -10,7 +10,6 @@ from collections import namedtuple
|
||||||
from einops import rearrange, repeat, reduce
|
from einops import rearrange, repeat, reduce
|
||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
|
|
||||||
from entmax import entmax15
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
DEFAULT_DIM_HEAD = 64
|
DEFAULT_DIM_HEAD = 64
|
||||||
|
@ -556,7 +555,7 @@ class Attention(nn.Module):
|
||||||
self.sparse_topk = sparse_topk
|
self.sparse_topk = sparse_topk
|
||||||
|
|
||||||
# entmax
|
# entmax
|
||||||
self.attn_fn = entmax15 if use_entmax15 else F.softmax
|
self.attn_fn = F.softmax
|
||||||
|
|
||||||
# add memory key / values
|
# add memory key / values
|
||||||
self.num_mem_kv = num_mem_kv
|
self.num_mem_kv = num_mem_kv
|
||||||
|
|
Loading…
Reference in New Issue
Block a user