v0.1.2
This commit is contained in:
parent
27d818674f
commit
32cb51ae38
4
setup.py
4
setup.py
|
@ -7,7 +7,7 @@ from setuptools import find_packages, setup
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="torchscale",
|
name="torchscale",
|
||||||
version="0.1.1",
|
version="0.1.2",
|
||||||
author="TorchScale Team",
|
author="TorchScale Team",
|
||||||
author_email="Shuming.Ma@microsoft.com",
|
author_email="Shuming.Ma@microsoft.com",
|
||||||
description="Transformers at any scale",
|
description="Transformers at any scale",
|
||||||
|
@ -15,7 +15,7 @@ setup(
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
keywords="Transformers at any scale",
|
keywords="Transformers at any scale",
|
||||||
license="MIT",
|
license="MIT",
|
||||||
url="https://github.com/msranlp/torchscale",
|
url="https://github.com/microsoft/torchscale",
|
||||||
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
|
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
|
||||||
install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
|
install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
|
||||||
python_requires=">=3.8.0",
|
python_requires=">=3.8.0",
|
||||||
|
|
|
@ -63,6 +63,7 @@ class RelativePositionBias(nn.Module):
|
||||||
relative_position, # shape (qlen, klen)
|
relative_position, # shape (qlen, klen)
|
||||||
bidirectional=self.bidirectional,
|
bidirectional=self.bidirectional,
|
||||||
num_buckets=self.num_buckets,
|
num_buckets=self.num_buckets,
|
||||||
|
max_distance=self.max_distance,
|
||||||
)
|
)
|
||||||
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
|
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
|
||||||
values = self.relative_attention_bias(
|
values = self.relative_attention_bias(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user