-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
40 lines (34 loc) · 1.17 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import subprocess
from pathlib import Path
import setuptools
import setuptools.command.build_ext
class make_ext(setuptools.command.build_ext.build_ext):  # type:ignore[misc]
    """``build_ext`` command that hands ``libflash_attention`` off to ``make``.

    The ``libflash_attention`` extension is not compiled by setuptools'
    compiler machinery; instead ``make`` is executed in the current working
    directory. Every other extension is built by the stock ``build_ext``.
    """

    def build_extension(self, ext: setuptools.Extension) -> None:
        """Build one extension, shelling out to ``make`` for ``libflash_attention``.

        Two variables are passed on the ``make`` command line:

        * ``OUT`` -- the location setuptools expects the built library at,
          ``<build_lib>/<extension filename from get_ext_filename>``.
        * ``OBJDIR`` -- the same path with its last suffix removed; presumably
          the Makefile uses it for intermediate object files (the Makefile is
          not visible here -- confirm).

        :raises subprocess.CalledProcessError: if ``make`` exits non-zero.
        """
        if ext.name != "libflash_attention":
            # Not the make-built library: use setuptools' normal build path.
            super().build_extension(ext)
            return

        full_name = self.get_ext_fullname(ext.name)
        # get_ext_filename adds the interpreter-specific extension suffix.
        out_path = Path(self.build_lib) / self.get_ext_filename(full_name)
        # with_suffix("") drops only the final suffix of the filename.
        obj_dir = out_path.with_suffix("")
        make_cmd = ["make", f"OUT={out_path}", f"OBJDIR={obj_dir}"]
        subprocess.check_call(make_cmd)
# Runtime dependencies are listed one per line in requirements.txt, resolved
# relative to the current working directory (expected to be the project root).
# The file is decoded as UTF-8 explicitly so the result does not depend on the
# build host's locale, and blank lines / whole-line "#" comments are skipped
# rather than being handed to setuptools as requirement strings.
_requirements_text = Path("requirements.txt").read_text(encoding="utf-8")
_install_requires = [
    line.strip()
    for line in _requirements_text.splitlines()
    if line.strip() and not line.strip().startswith("#")
]

setuptools.setup(
    name="flash-attention-ipu",
    version="0.1.0",
    description="FlashAttention for Graphcore IPUs",
    install_requires=_install_requires,
    ext_modules=[
        setuptools.Extension(
            "libflash_attention",
            # The C++ sources/headers are declared so setuptools knows the
            # extension's inputs; the build itself is performed by `make`
            # through the make_ext command class. Sorted because glob order
            # is filesystem-dependent and builds should be deterministic.
            sorted(str(p) for p in Path("flash_attention_ipu/cpp").glob("*.[ch]pp")),
        )
    ],
    cmdclass={"build_ext": make_ext},
)