forked from mrq/DL-Art-School
update environment and fix a bunch of deps
This commit is contained in:
parent
45afefabed
commit
f8108cfdb2
.idea
codes
2
.idea/.gitignore
vendored
2
.idea/.gitignore
vendored
|
@ -1,2 +0,0 @@
|
|||
# Default ignored files
|
||||
/workspace.xml
|
8
.idea/dlas.iml
Normal file
8
.idea/dlas.iml
Normal file
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.9 (torch12)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
19
.idea/inspectionProfiles/Project_Default.xml
Normal file
19
.idea/inspectionProfiles/Project_Default.xml
Normal file
|
@ -0,0 +1,19 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="6">
|
||||
<item index="0" class="java.lang.String" itemvalue="numba" />
|
||||
<item index="1" class="java.lang.String" itemvalue="jupyter" />
|
||||
<item index="2" class="java.lang.String" itemvalue="nltk" />
|
||||
<item index="3" class="java.lang.String" itemvalue="pytorch-pretrained-biggan" />
|
||||
<item index="4" class="java.lang.String" itemvalue="gunicorn" />
|
||||
<item index="5" class="java.lang.String" itemvalue="scipy" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
|
@ -1,7 +1,4 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (torch)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (torch12)" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -1,19 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/codes" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/codes/temp" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/datasets" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/experiments" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/results" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/tb_logger" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.9 (torch)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="GOOGLE" />
|
||||
<option name="myDocStringFormat" value="Google" />
|
||||
</component>
|
||||
</module>
|
|
@ -2,7 +2,7 @@
|
|||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/mmsr.iml" filepath="$PROJECT_DIR$/.idea/mmsr.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/dlas.iml" filepath="$PROJECT_DIR$/.idea/dlas.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -1,6 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PySciProjectComponent">
|
||||
<option name="PY_MATPLOTLIB_IN_TOOLWINDOW" value="false" />
|
||||
</component>
|
||||
</project>
|
|
@ -1,9 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/codes/models/flownet2" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/codes/models/switched_conv" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/codes/switched_conv" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
150
.idea/workspace.xml
Normal file
150
.idea/workspace.xml
Normal file
|
@ -0,0 +1,150 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AutoImportSettings">
|
||||
<option name="autoReloadType" value="SELECTIVE" />
|
||||
</component>
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="c96d3871-3547-4d31-96a9-6f06ea0717ab" name="Changes" comment="">
|
||||
<change beforePath="$PROJECT_DIR$/.idea/.gitignore" beforeDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/mmsr.iml" beforeDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/modules.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/other.xml" beforeDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/vcs.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/codes/models/arch_util.py" beforeDir="false" afterPath="$PROJECT_DIR$/codes/models/arch_util.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/codes/models/audio/music/unet_diffusion_music_codes.py" beforeDir="false" afterPath="$PROJECT_DIR$/codes/models/audio/music/unet_diffusion_music_codes.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/codes/models/audio/tts/diffusion_encoder.py" beforeDir="false" afterPath="$PROJECT_DIR$/codes/models/audio/tts/diffusion_encoder.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/codes/models/diffusion/unet_diffusion.py" beforeDir="false" afterPath="$PROJECT_DIR$/codes/models/diffusion/unet_diffusion.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/codes/models/lucidrains/x_transformers.py" beforeDir="false" afterPath="$PROJECT_DIR$/codes/models/lucidrains/x_transformers.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/codes/requirements.txt" beforeDir="false" afterPath="$PROJECT_DIR$/codes/requirements.txt" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/codes/trainer/eval/music_diffusion_fid.py" beforeDir="false" afterPath="$PROJECT_DIR$/codes/trainer/eval/music_diffusion_fid.py" afterDir="false" />
|
||||
</list>
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
||||
<option name="LAST_RESOLUTION" value="IGNORE" />
|
||||
</component>
|
||||
<component name="Git.Settings">
|
||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||
</component>
|
||||
<component name="MarkdownSettingsMigration">
|
||||
<option name="stateVersion" value="1" />
|
||||
</component>
|
||||
<component name="ProjectId" id="2CQCsrwBRaEk3sGS5kZRpywk8nK" />
|
||||
<component name="ProjectViewState">
|
||||
<option name="hideEmptyMiddlePackages" value="true" />
|
||||
<option name="showLibraryContents" value="true" />
|
||||
</component>
|
||||
<component name="PropertiesComponent"><![CDATA[{
|
||||
"keyToString": {
|
||||
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
||||
"RunOnceActivity.ShowReadmeOnStart": "true",
|
||||
"WebServerToolWindowFactoryState": "false",
|
||||
"last_opened_file_path": "D:/scraper",
|
||||
"node.js.detected.package.eslint": "true",
|
||||
"node.js.detected.package.tslint": "true",
|
||||
"node.js.selected.package.eslint": "(autodetect)",
|
||||
"node.js.selected.package.tslint": "(autodetect)",
|
||||
"nodejs_package_manager_path": "npm"
|
||||
}
|
||||
}]]></component>
|
||||
<component name="RunManager" selected="Python.music_diffusion_fid">
|
||||
<configuration name="arch_util" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="dlas" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/codes/" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/codes/models/arch_util.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="music_diffusion_fid" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="dlas" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/codes" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/codes/trainer/eval/music_diffusion_fid.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="Doctests in arch_util" type="tests" factoryName="Doctests" temporary="true" nameIsGenerated="true">
|
||||
<module name="dlas" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/codes/models" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/codes/models/arch_util.py" />
|
||||
<option name="CLASS_NAME" value="" />
|
||||
<option name="METHOD_NAME" value="" />
|
||||
<option name="FOLDER_NAME" value="" />
|
||||
<option name="TEST_TYPE" value="TEST_SCRIPT" />
|
||||
<option name="PATTERN" value="" />
|
||||
<option name="USE_PATTERN" value="false" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<list>
|
||||
<item itemvalue="Python.music_diffusion_fid" />
|
||||
<item itemvalue="Python.arch_util" />
|
||||
<item itemvalue="Python tests.Doctests in arch_util" />
|
||||
</list>
|
||||
<recent_temporary>
|
||||
<list>
|
||||
<item itemvalue="Python.music_diffusion_fid" />
|
||||
<item itemvalue="Python.arch_util" />
|
||||
<item itemvalue="Python tests.Doctests in arch_util" />
|
||||
</list>
|
||||
</recent_temporary>
|
||||
</component>
|
||||
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
||||
<component name="TaskManager">
|
||||
<task active="true" id="Default" summary="Default task">
|
||||
<changelist id="c96d3871-3547-4d31-96a9-6f06ea0717ab" name="Changes" comment="" />
|
||||
<created>1658725632364</created>
|
||||
<option name="number" value="Default" />
|
||||
<option name="presentableId" value="Default" />
|
||||
<updated>1658725632364</updated>
|
||||
<workItem from="1658725634517" duration="27000" />
|
||||
<workItem from="1658726334149" duration="1376000" />
|
||||
</task>
|
||||
<servers />
|
||||
</component>
|
||||
<component name="TypeScriptGeneratedFilesManager">
|
||||
<option name="version" value="3" />
|
||||
</component>
|
||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||
<SUITE FILE_PATH="coverage/dlas$music_diffusion_fid.coverage" NAME="music_diffusion_fid Coverage Results" MODIFIED="1658727714920" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/codes" />
|
||||
<SUITE FILE_PATH="coverage/dlas$arch_util.coverage" NAME="arch_util Coverage Results" MODIFIED="1658726906198" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/codes/" />
|
||||
<SUITE FILE_PATH="coverage/dlas$Doctests_in_arch_util.coverage" NAME="Doctests in arch_util Coverage Results" MODIFIED="1658726625256" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/codes/models" />
|
||||
</component>
|
||||
</project>
|
|
@ -483,16 +483,24 @@ class RelativeQKBias(nn.Module):
|
|||
"""
|
||||
Very simple relative position bias scheme which should be directly added to QK matrix. This bias simply applies to
|
||||
the distance from the given element.
|
||||
|
||||
If symmetric=False, a different bias is applied to each side of the input element, otherwise the bias is symmetric.
|
||||
"""
|
||||
def __init__(self, l, max_positions=4000):
|
||||
def __init__(self, l, max_positions=4000, symmetric=True):
|
||||
super().__init__()
|
||||
self.emb = nn.Parameter(torch.randn(l+1) * .01)
|
||||
o = torch.arange(0,max_positions)
|
||||
c = o.unsqueeze(-1).repeat(1,max_positions)
|
||||
r = o.unsqueeze(0).repeat(max_positions,1)
|
||||
M = ((-(r-c).abs())+l).clamp(0,l)
|
||||
if symmetric:
|
||||
self.emb = nn.Parameter(torch.randn(l+1) * .01)
|
||||
o = torch.arange(0,max_positions)
|
||||
c = o.unsqueeze(-1).repeat(1,max_positions)
|
||||
r = o.unsqueeze(0).repeat(max_positions,1)
|
||||
M = ((-(r-c).abs())+l).clamp(0,l)
|
||||
else:
|
||||
self.emb = nn.Parameter(torch.randn(l*2+2) * .01)
|
||||
a = torch.arange(0,max_positions)
|
||||
c = a.unsqueeze(-1) - a
|
||||
m = (c >= -l).logical_and(c <= l)
|
||||
M = (l+c+1)*m
|
||||
self.register_buffer('M', M, persistent=False)
|
||||
self.initted = False
|
||||
|
||||
def forward(self, n):
|
||||
# Ideally, I'd return this:
|
||||
|
|
|
@ -8,7 +8,6 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision # For debugging, not actually used.
|
||||
from x_transformers.x_transformers import RelativePositionBias
|
||||
|
||||
from models.audio.music.gpt_music import GptMusicLower
|
||||
from models.audio.music.music_quantizer import MusicQuantizer
|
||||
|
@ -291,10 +290,6 @@ class AttentionBlock(nn.Module):
|
|||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
if relative_pos_embeddings:
|
||||
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
|
||||
else:
|
||||
self.relative_pos_embeddings = None
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if self.do_checkpoint:
|
||||
|
@ -306,7 +301,7 @@ class AttentionBlock(nn.Module):
|
|||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv, mask, self.relative_pos_embeddings)
|
||||
h = self.attention(qkv, mask)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from functools import partial
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
|
||||
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, Rezero, \
|
||||
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, \
|
||||
exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
|
||||
AttentionLayers, not_equals
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision # For debugging, not actually used.
|
||||
from x_transformers.x_transformers import RelativePositionBias
|
||||
|
||||
from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32
|
||||
from models.diffusion.nn import (
|
||||
|
@ -298,7 +297,6 @@ class AttentionBlock(nn.Module):
|
|||
num_head_channels=-1,
|
||||
use_new_attention_order=False,
|
||||
do_checkpoint=True,
|
||||
relative_pos_embeddings=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
@ -320,10 +318,6 @@ class AttentionBlock(nn.Module):
|
|||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
if relative_pos_embeddings:
|
||||
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
|
||||
else:
|
||||
self.relative_pos_embeddings = None
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if self.do_checkpoint:
|
||||
|
@ -335,7 +329,7 @@ class AttentionBlock(nn.Module):
|
|||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv, mask, self.relative_pos_embeddings)
|
||||
h = self.attention(qkv, mask)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
|
|
@ -10,11 +10,8 @@ from collections import namedtuple
|
|||
from einops import rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from entmax import entmax15
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
Intermediates = namedtuple('Intermediates', [
|
||||
|
@ -1274,51 +1271,3 @@ class ContinuousTransformerWrapper(nn.Module):
|
|||
if len(res) > 1:
|
||||
return tuple(res)
|
||||
return res[0]
|
||||
|
||||
|
||||
class XTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
tie_token_emb=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
|
||||
dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
|
||||
|
||||
assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
|
||||
enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
|
||||
enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
|
||||
enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
|
||||
enc_transformer_kwargs['use_pos_emb'] = enc_kwargs.pop('use_pos_emb', True)
|
||||
|
||||
dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
|
||||
dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
|
||||
dec_transformer_kwargs['use_pos_emb'] = dec_kwargs.pop('use_pos_emb', True)
|
||||
|
||||
self.encoder = TransformerWrapper(
|
||||
**enc_transformer_kwargs,
|
||||
attn_layers=Encoder(dim=dim, **enc_kwargs)
|
||||
)
|
||||
|
||||
self.decoder = TransformerWrapper(
|
||||
**dec_transformer_kwargs,
|
||||
attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
|
||||
)
|
||||
|
||||
if tie_token_emb:
|
||||
self.decoder.token_emb = self.encoder.token_emb
|
||||
|
||||
self.decoder = AutoregressiveWrapper(self.decoder)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs):
|
||||
encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
|
||||
return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs)
|
||||
|
||||
def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
|
||||
enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
|
||||
out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
|
||||
return out
|
||||
|
|
|
@ -30,11 +30,13 @@ Unidecode==1.0.22
|
|||
tgt == 1.4.4
|
||||
pyworld == 0.2.10
|
||||
audio2numpy
|
||||
SoundFile
|
||||
|
||||
# For text stuff
|
||||
transformers
|
||||
tokenizers
|
||||
jiwer # calculating WER
|
||||
omegaconf
|
||||
|
||||
# lucidrains stuff
|
||||
vector_quantize_pytorch
|
||||
|
@ -42,4 +44,5 @@ linear_attention_transformer
|
|||
rotary-embedding-torch
|
||||
axial_positional_embedding
|
||||
g-mlp-pytorch
|
||||
x-clip
|
||||
x-clip
|
||||
x_transformers
|
|
@ -315,14 +315,15 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
if __name__ == '__main__':
|
||||
# For multilevel SR:
|
||||
"""
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr.yml', 'generator',
|
||||
also_load_savepoint=False, strict_load=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\18000_generator.pth'
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_archived_prev2\\models\\18000_generator.pth'
|
||||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
||||
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
||||
'diffusion_steps': 256, # basis: 192
|
||||
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': True,
|
||||
'diffusion_steps': 64, # basis: 192
|
||||
'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
|
||||
'diffusion_schedule': 'cosine', 'diffusion_type': 'chained_sr',
|
||||
}
|
||||
|
||||
|
@ -334,13 +335,12 @@ if __name__ == '__main__':
|
|||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
|
||||
#'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
|
||||
'diffusion_steps': 256, # basis: 192
|
||||
'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': True,
|
||||
'diffusion_steps': 64, # basis: 192
|
||||
'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': True, 'clip_audio': True,
|
||||
'diffusion_schedule': 'cosine', 'diffusion_type': 'from_codes_quant',
|
||||
}
|
||||
"""
|
||||
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 11, 'device': 'cuda', 'opt': {}}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 12, 'device': 'cuda', 'opt': {}}
|
||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||
fds = []
|
||||
for i in range(2):
|
||||
|
|
Loading…
Reference in New Issue
Block a user