Skip to content

Commit c44c577

Browse files
sisppmrowla
authored andcommitted
lfs: add support for Git SSH URLs
1 parent 4819e71 commit c44c577

File tree

1 file changed

+99
-28
lines changed

1 file changed

+99
-28
lines changed

src/scmrepo/git/lfs/client.py

Lines changed: 99 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import json
12
import logging
23
import os
4+
import re
35
import shutil
6+
from abc import abstractmethod
47
from collections.abc import Iterable, Iterator
58
from contextlib import AbstractContextManager, contextmanager, suppress
69
from tempfile import NamedTemporaryFile
@@ -13,6 +16,7 @@
1316
from fsspec.implementations.http import HTTPFileSystem
1417
from funcy import cached_property
1518

19+
from scmrepo.git.backend.dulwich import _get_ssh_vendor
1620
from scmrepo.git.credentials import Credential, CredentialNotFoundError
1721

1822
from .exceptions import LFSError
@@ -35,19 +39,12 @@ class LFSClient(AbstractContextManager):
3539
_SESSION_RETRIES = 5
3640
_SESSION_BACKOFF_FACTOR = 0.1
3741

38-
def __init__(
39-
self,
40-
url: str,
41-
git_url: Optional[str] = None,
42-
headers: Optional[dict[str, str]] = None,
43-
):
42+
def __init__(self, url: str):
4443
"""
4544
Args:
4645
url: LFS server URL.
4746
"""
4847
self.url = url
49-
self.git_url = git_url
50-
self.headers: dict[str, str] = headers or {}
5148

5249
def __exit__(self, *args, **kwargs):
5350
self.close()
@@ -84,23 +81,18 @@ def loop(self):
8481

8582
@classmethod
8683
def from_git_url(cls, git_url: str) -> "LFSClient":
87-
if git_url.endswith(".git"):
88-
url = f"{git_url}/info/lfs"
89-
else:
90-
url = f"{git_url}.git/info/lfs"
91-
return cls(url, git_url=git_url)
84+
if git_url.startswith(("ssh://", "git@")):
85+
return _SSHLFSClient.from_git_url(git_url)
86+
if git_url.startswith("https://"):
87+
return _HTTPLFSClient.from_git_url(git_url)
88+
raise NotImplementedError(f"Unsupported Git URL: {git_url}")
9289

9390
def close(self):
9491
pass
9592

96-
def _get_auth(self) -> Optional[aiohttp.BasicAuth]:
97-
try:
98-
creds = Credential(url=self.git_url).fill()
99-
if creds.username and creds.password:
100-
return aiohttp.BasicAuth(creds.username, creds.password)
101-
except CredentialNotFoundError:
102-
pass
103-
return None
93+
@abstractmethod
94+
def _get_auth_header(self, *, upload: bool) -> dict:
95+
...
10496

10597
async def _batch_request(
10698
self,
@@ -120,9 +112,10 @@ async def _batch_request(
120112
if ref:
121113
body["ref"] = [{"name": ref}]
122114
session = await self._fs.set_session()
123-
headers = dict(self.headers)
124-
headers["Accept"] = self.JSON_CONTENT_TYPE
125-
headers["Content-Type"] = self.JSON_CONTENT_TYPE
115+
headers = {
116+
"Accept": self.JSON_CONTENT_TYPE,
117+
"Content-Type": self.JSON_CONTENT_TYPE,
118+
}
126119
try:
127120
async with session.post(
128121
url,
@@ -134,13 +127,12 @@ async def _batch_request(
134127
except aiohttp.ClientResponseError as exc:
135128
if exc.status != 401:
136129
raise
137-
auth = self._get_auth()
138-
if auth is None:
130+
auth_header = self._get_auth_header(upload=upload)
131+
if not auth_header:
139132
raise
140133
async with session.post(
141134
url,
142-
auth=auth,
143-
headers=headers,
135+
headers={**headers, **auth_header},
144136
json=body,
145137
raise_for_status=True,
146138
) as resp:
@@ -186,6 +178,85 @@ async def _get_one(from_path: str, to_path: str, **kwargs):
186178
download = sync_wrapper(_download)
187179

188180

181+
class _HTTPLFSClient(LFSClient):
182+
def __init__(self, url: str, git_url: str):
183+
"""
184+
Args:
185+
url: LFS server URL.
186+
git_url: Git HTTP URL.
187+
"""
188+
super().__init__(url)
189+
self.git_url = git_url
190+
191+
@classmethod
192+
def from_git_url(cls, git_url: str) -> "_HTTPLFSClient":
193+
if git_url.endswith(".git"):
194+
url = f"{git_url}/info/lfs"
195+
else:
196+
url = f"{git_url}.git/info/lfs"
197+
return cls(url, git_url=git_url)
198+
199+
def _get_auth_header(self, *, upload: bool) -> dict:
200+
try:
201+
creds = Credential(url=self.git_url).fill()
202+
if creds.username and creds.password:
203+
return {
204+
aiohttp.hdrs.AUTHORIZATION: aiohttp.BasicAuth(
205+
creds.username, creds.password
206+
).encode()
207+
}
208+
except CredentialNotFoundError:
209+
pass
210+
return {}
211+
212+
213+
class _SSHLFSClient(LFSClient):
214+
_URL_PATTERN = re.compile(
215+
r"(?:ssh://)?git@(?P<host>\S+?)(?::(?P<port>\d+))?(?:[:/])(?P<path>\S+?)\.git"
216+
)
217+
218+
def __init__(self, url: str, host: str, port: int, path: str):
219+
"""
220+
Args:
221+
url: LFS server URL.
222+
host: Git SSH server host.
223+
port: Git SSH server port.
224+
path: Git project path.
225+
"""
226+
super().__init__(url)
227+
self.host = host
228+
self.port = port
229+
self.path = path
230+
self._ssh = _get_ssh_vendor()
231+
232+
@classmethod
233+
def from_git_url(cls, git_url: str) -> "_SSHLFSClient":
234+
result = cls._URL_PATTERN.match(git_url)
235+
if not result:
236+
raise ValueError(f"Invalid Git SSH URL: {git_url}")
237+
host, port, path = result.group("host", "port", "path")
238+
url = f"https://{host}/{path}.git/info/lfs"
239+
return cls(url, host, int(port or 22), path)
240+
241+
def _get_auth_header(self, *, upload: bool) -> dict:
242+
return self._git_lfs_authenticate(
243+
self.host, self.port, f"{self.path}.git", upload=upload
244+
).get("header", {})
245+
246+
def _git_lfs_authenticate(
247+
self, host: str, port: int, path: str, *, upload: bool = False
248+
) -> dict:
249+
action = "upload" if upload else "download"
250+
return json.loads(
251+
self._ssh.run_command(
252+
command=f"git-lfs-authenticate {path} {action}",
253+
host=host,
254+
port=port,
255+
username="git",
256+
).read()
257+
)
258+
259+
189260
@contextmanager
190261
def _as_atomic(to_info: str, create_parents: bool = False) -> Iterator[str]:
191262
parent = os.path.dirname(to_info)

0 commit comments

Comments
 (0)