commit 32cb51ae38
parent 27d818674f
Author: shumingma
Date: 2023-03-04 01:11:34 -08:00

2 changed files with 3 additions and 2 deletions

setup.py

@@ -7,7 +7,7 @@ from setuptools import find_packages, setup
 setup(
     name="torchscale",
-    version="0.1.1",
+    version="0.1.2",
     author="TorchScale Team",
     author_email="Shuming.Ma@microsoft.com",
     description="Transformers at any scale",
@@ -15,7 +15,7 @@ setup(
     long_description_content_type="text/markdown",
     keywords="Transformers at any scale",
     license="MIT",
-    url="https://github.com/msranlp/torchscale",
+    url="https://github.com/microsoft/torchscale",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
     install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
     python_requires=">=3.8.0",

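Since setup.py pins python_requires=">=3.8.0", the standard-library metadata lookup is available to confirm which release is installed. A hypothetical usage example, assuming the package is installed under the distribution name "torchscale":

import importlib.metadata

# Should print "0.1.2" once the release from this commit is installed.
print(importlib.metadata.version("torchscale"))
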
torchscale/component/relative_position_bias.py

@@ -63,6 +63,7 @@ class RelativePositionBias(nn.Module):
             relative_position,  # shape (qlen, klen)
             bidirectional=self.bidirectional,
             num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
         )
         rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
         values = self.relative_attention_bias(
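
The added max_distance=self.max_distance forwards the module's configured horizon into the bucketing call; before this fix, the call silently fell back to the helper's default, so a non-default max_distance set on the module had no effect. Below is a minimal sketch of T5-style relative position bucketing for illustration only (the helper name, signature, and the max_distance=128 default are assumptions, not copied from this repository), showing where the parameter enters:

import math

import torch


def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    # Sketch of T5-style bucketing: small offsets map to their own exact
    # buckets, larger offsets share logarithmically spaced buckets that
    # saturate at max_distance. If a caller omits max_distance, the default
    # (assumed 128 here) is used -- the omission the diff above fixes.
    ret = 0
    n = -relative_position
    if bidirectional:
        num_buckets //= 2
        # Reserve half the buckets for each direction.
        ret += (n < 0).to(torch.long) * num_buckets
        n = torch.abs(n)
    else:
        n = torch.max(n, torch.zeros_like(n))
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # Offsets beyond max_exact are binned on a log scale up to max_distance.
    val_if_large = max_exact + (
        torch.log(n.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    val_if_large = torch.min(
        val_if_large, torch.full_like(val_if_large, num_buckets - 1)
    )
    # is_small picks the exact bucket, so the log-scale value (which is
    # undefined at n == 0) is discarded for small offsets.
    ret += torch.where(is_small, n, val_if_large)
    return ret

Because max_distance only affects the log-scale branch, dropping it changes which bucket distant query-key offsets land in; the learned bias table itself is unchanged, which is why this one-line fix alters behavior without touching any parameters.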