1
+ import json
1
2
import logging
2
3
import os
4
+ import re
3
5
import shutil
6
+ from abc import abstractmethod
4
7
from collections .abc import Iterable , Iterator
5
8
from contextlib import AbstractContextManager , contextmanager , suppress
6
9
from tempfile import NamedTemporaryFile
13
16
from fsspec .implementations .http import HTTPFileSystem
14
17
from funcy import cached_property
15
18
19
+ from scmrepo .git .backend .dulwich import _get_ssh_vendor
16
20
from scmrepo .git .credentials import Credential , CredentialNotFoundError
17
21
18
22
from .exceptions import LFSError
@@ -35,19 +39,12 @@ class LFSClient(AbstractContextManager):
35
39
_SESSION_RETRIES = 5
36
40
_SESSION_BACKOFF_FACTOR = 0.1
37
41
38
- def __init__ (
39
- self ,
40
- url : str ,
41
- git_url : Optional [str ] = None ,
42
- headers : Optional [dict [str , str ]] = None ,
43
- ):
42
+ def __init__ (self , url : str ):
44
43
"""
45
44
Args:
46
45
url: LFS server URL.
47
46
"""
48
47
self .url = url
49
- self .git_url = git_url
50
- self .headers : dict [str , str ] = headers or {}
51
48
52
49
def __exit__ (self , * args , ** kwargs ):
53
50
self .close ()
@@ -84,23 +81,18 @@ def loop(self):
84
81
85
82
@classmethod
86
83
def from_git_url (cls , git_url : str ) -> "LFSClient" :
87
- if git_url .endswith ( ". git" ):
88
- url = f" { git_url } /info/lfs"
89
- else :
90
- url = f" { git_url } .git/info/lfs"
91
- return cls ( url , git_url = git_url )
84
+ if git_url .startswith (( "ssh://" , " git@" ) ):
85
+ return _SSHLFSClient . from_git_url ( git_url )
86
+ if git_url . startswith ( "https://" ) :
87
+ return _HTTPLFSClient . from_git_url ( git_url )
88
+ raise NotImplementedError ( f"Unsupported Git URL: { git_url } " )
92
89
93
90
def close (self ):
94
91
pass
95
92
96
- def _get_auth (self ) -> Optional [aiohttp .BasicAuth ]:
97
- try :
98
- creds = Credential (url = self .git_url ).fill ()
99
- if creds .username and creds .password :
100
- return aiohttp .BasicAuth (creds .username , creds .password )
101
- except CredentialNotFoundError :
102
- pass
103
- return None
93
+ @abstractmethod
94
+ def _get_auth_header (self , * , upload : bool ) -> dict :
95
+ ...
104
96
105
97
async def _batch_request (
106
98
self ,
@@ -120,9 +112,10 @@ async def _batch_request(
120
112
if ref :
121
113
body ["ref" ] = [{"name" : ref }]
122
114
session = await self ._fs .set_session ()
123
- headers = dict (self .headers )
124
- headers ["Accept" ] = self .JSON_CONTENT_TYPE
125
- headers ["Content-Type" ] = self .JSON_CONTENT_TYPE
115
+ headers = {
116
+ "Accept" : self .JSON_CONTENT_TYPE ,
117
+ "Content-Type" : self .JSON_CONTENT_TYPE ,
118
+ }
126
119
try :
127
120
async with session .post (
128
121
url ,
@@ -134,13 +127,12 @@ async def _batch_request(
134
127
except aiohttp .ClientResponseError as exc :
135
128
if exc .status != 401 :
136
129
raise
137
- auth = self ._get_auth ( )
138
- if auth is None :
130
+ auth_header = self ._get_auth_header ( upload = upload )
131
+ if not auth_header :
139
132
raise
140
133
async with session .post (
141
134
url ,
142
- auth = auth ,
143
- headers = headers ,
135
+ headers = {** headers , ** auth_header },
144
136
json = body ,
145
137
raise_for_status = True ,
146
138
) as resp :
@@ -186,6 +178,85 @@ async def _get_one(from_path: str, to_path: str, **kwargs):
186
178
download = sync_wrapper (_download )
187
179
188
180
181
+ class _HTTPLFSClient (LFSClient ):
182
+ def __init__ (self , url : str , git_url : str ):
183
+ """
184
+ Args:
185
+ url: LFS server URL.
186
+ git_url: Git HTTP URL.
187
+ """
188
+ super ().__init__ (url )
189
+ self .git_url = git_url
190
+
191
+ @classmethod
192
+ def from_git_url (cls , git_url : str ) -> "_HTTPLFSClient" :
193
+ if git_url .endswith (".git" ):
194
+ url = f"{ git_url } /info/lfs"
195
+ else :
196
+ url = f"{ git_url } .git/info/lfs"
197
+ return cls (url , git_url = git_url )
198
+
199
+ def _get_auth_header (self , * , upload : bool ) -> dict :
200
+ try :
201
+ creds = Credential (url = self .git_url ).fill ()
202
+ if creds .username and creds .password :
203
+ return {
204
+ aiohttp .hdrs .AUTHORIZATION : aiohttp .BasicAuth (
205
+ creds .username , creds .password
206
+ ).encode ()
207
+ }
208
+ except CredentialNotFoundError :
209
+ pass
210
+ return {}
211
+
212
+
213
+ class _SSHLFSClient (LFSClient ):
214
+ _URL_PATTERN = re .compile (
215
+ r"(?:ssh://)?git@(?P<host>\S+?)(?::(?P<port>\d+))?(?:[:/])(?P<path>\S+?)\.git"
216
+ )
217
+
218
+ def __init__ (self , url : str , host : str , port : int , path : str ):
219
+ """
220
+ Args:
221
+ url: LFS server URL.
222
+ host: Git SSH server host.
223
+ port: Git SSH server port.
224
+ path: Git project path.
225
+ """
226
+ super ().__init__ (url )
227
+ self .host = host
228
+ self .port = port
229
+ self .path = path
230
+ self ._ssh = _get_ssh_vendor ()
231
+
232
+ @classmethod
233
+ def from_git_url (cls , git_url : str ) -> "_SSHLFSClient" :
234
+ result = cls ._URL_PATTERN .match (git_url )
235
+ if not result :
236
+ raise ValueError (f"Invalid Git SSH URL: { git_url } " )
237
+ host , port , path = result .group ("host" , "port" , "path" )
238
+ url = f"https://{ host } /{ path } .git/info/lfs"
239
+ return cls (url , host , int (port or 22 ), path )
240
+
241
+ def _get_auth_header (self , * , upload : bool ) -> dict :
242
+ return self ._git_lfs_authenticate (
243
+ self .host , self .port , f"{ self .path } .git" , upload = upload
244
+ ).get ("header" , {})
245
+
246
+ def _git_lfs_authenticate (
247
+ self , host : str , port : int , path : str , * , upload : bool = False
248
+ ) -> dict :
249
+ action = "upload" if upload else "download"
250
+ return json .loads (
251
+ self ._ssh .run_command (
252
+ command = f"git-lfs-authenticate { path } { action } " ,
253
+ host = host ,
254
+ port = port ,
255
+ username = "git" ,
256
+ ).read ()
257
+ )
258
+
259
+
189
260
@contextmanager
190
261
def _as_atomic (to_info : str , create_parents : bool = False ) -> Iterator [str ]:
191
262
parent = os .path .dirname (to_info )
0 commit comments