Skip to content

Commit 7f7f6f2

Browse files
authored
Merge pull request #47 from fcollonval/ft/improve-tests
Improve config handling
2 parents 02e2cbd + 67a0657 commit 7f7f6f2

File tree

9 files changed

+163
-150
lines changed

9 files changed

+163
-150
lines changed

jupyterlab_pullrequests/handlers.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import logging
66
import traceback
7+
from typing import Optional
78
from http import HTTPStatus
89

910
import tornado
@@ -198,26 +199,33 @@ def get_body_value(handler):
198199
]
199200

200201

201-
def setup_handlers(web_app: "NotebookWebApplication", config: PRConfig):
202+
def setup_handlers(web_app: tornado.web.Application, config: PRConfig, log: Optional[logging.Logger]=None):
202203
host_pattern = ".*$"
203204
base_url = url_path_join(web_app.settings["base_url"], NAMESPACE)
204205

205-
logger = get_logger()
206+
log = log or logging.getLogger(__name__)
206207

207208
manager_class = MANAGERS.get(config.provider)
208209
if manager_class is None:
209-
logger.error(f"No manager defined for provider '{config.provider}'.")
210+
log.error(f"PR Manager: No manager defined for provider '{config.provider}'.")
210211
raise NotImplementedError()
211-
manager = manager_class(config.api_base_url, config.access_token)
212-
213-
web_app.add_handlers(
214-
host_pattern,
215-
[
216-
(
217-
url_path_join(base_url, pat),
218-
handler,
219-
{"logger": logger, "manager": manager},
220-
)
221-
for pat, handler in default_handlers
222-
],
223-
)
212+
log.info(f"PR Manager Class {manager_class}")
213+
try:
214+
manager = manager_class(config)
215+
except Exception as err:
216+
import traceback
217+
logging.error("PR Manager Exception", exc_info=1)
218+
raise err
219+
220+
handlers = [
221+
(
222+
url_path_join(base_url, pat),
223+
handler,
224+
{"logger": log, "manager": manager},
225+
)
226+
for pat, handler in default_handlers
227+
]
228+
229+
log.debug(f"PR Handlers: {handlers}")
230+
231+
web_app.add_handlers(host_pattern, handlers)

jupyterlab_pullrequests/managers/github.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,20 @@
66
from tornado.httputil import url_concat
77
from tornado.web import HTTPError
88

9-
from ..base import CommentReply, NewComment
9+
from ..base import CommentReply, NewComment, PRConfig
1010
from .manager import PullRequestsManager
1111

1212

1313
class GitHubManager(PullRequestsManager):
1414
"""Pull request manager for GitHub."""
1515

16-
def __init__(
17-
self, base_api_url: str = "https://api.github.com", access_token: str = ""
18-
) -> None:
19-
"""
20-
Args:
21-
base_api_url: Base REST API url for the versioning service
22-
access_token: Versioning service access token
23-
"""
24-
super().__init__(base_api_url=base_api_url, access_token=access_token)
25-
self._pull_requests_cache = {} # Dict[str, Dict]
16+
def __init__(self, config: PRConfig) -> None:
17+
super().__init__(config)
18+
self._pull_requests_cache = {}
19+
20+
@property
21+
def base_api_url(self):
22+
return self._config.api_base_url or "https://api.github.com"
2623

2724
@property
2825
def per_page_argument(self) -> Optional[Tuple[str, int]]:
@@ -40,7 +37,7 @@ async def get_current_user(self) -> Dict[str, str]:
4037
Returns:
4138
JSON description of the user matching the access token
4239
"""
43-
git_url = url_path_join(self._base_api_url, "user")
40+
git_url = url_path_join(self.base_api_url, "user")
4441
data = await self._call_github(git_url, has_pagination=False)
4542

4643
return {"username": data["login"]}
@@ -186,7 +183,7 @@ async def list_prs(self, username: str, pr_filter: str) -> List[Dict[str, str]]:
186183

187184
# Use search API to find matching pull requests and return
188185
git_url = url_path_join(
189-
self._base_api_url, "/search/issues?q=+state:open+type:pr" + search_filter
186+
self.base_api_url, "/search/issues?q=+state:open+type:pr" + search_filter
190187
)
191188

192189
results = await self._call_github(git_url)
@@ -273,7 +270,7 @@ async def _call_github(
273270
"""
274271
headers = {
275272
"Accept": media_type,
276-
"Authorization": f"token {self._access_token}",
273+
"Authorization": f"token {self._config.access_token}",
277274
}
278275

279276
return await super()._call_provider(

jupyterlab_pullrequests/managers/gitlab.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from tornado.httputil import url_concat
1414
from tornado.web import HTTPError
1515

16-
from ..base import CommentReply, NewComment
16+
from ..base import CommentReply, NewComment, PRConfig
1717
from ..log import get_logger
1818
from .manager import PullRequestsManager
1919

@@ -25,22 +25,20 @@ class GitLabManager(PullRequestsManager):
2525

2626
MINIMAL_VERSION = "13.1" # Due to pagination https://docs.gitlab.com/ee/api/README.html#pagination
2727

28-
def __init__(
29-
self, base_api_url: str = "https://gitlab.com/api/v4/", access_token: str = ""
30-
) -> None:
31-
"""
32-
Args:
33-
base_api_url: Base REST API url for the versioning service
34-
access_token: Versioning service access token
35-
"""
36-
super().__init__(base_api_url=base_api_url, access_token=access_token)
28+
def __init__(self, config: PRConfig) -> None:
29+
super().__init__(config)
30+
3731
# Creating new file discussion required some commit sha's so we will cache them
3832
self._merge_requests_cache = {} # Dict[str, Dict]
3933
# Creating discussion on unmodified line requires to figure out the line number
4034
# in the diff file for the original and the new file using Myers algorithm. So
4135
# we cache the diff to speed up the process.
4236
self._file_diff_cache = {} # Dict[Tuple[str, str], List[difflib.Match]]
4337

38+
@property
39+
def base_api_url(self):
40+
return self._config.api_base_url or "https://gitlab.com/api/v4/"
41+
4442
@property
4543
def per_page_argument(self) -> Optional[Tuple[str, int]]:
4644
"""Returns query argument to set number of items per page.
@@ -57,7 +55,7 @@ async def check_server_version(self) -> bool:
5755
Returns:
5856
Whether the server version is higher than the minimal supported version.
5957
"""
60-
url = url_path_join(self._base_api_url, "version")
58+
url = url_path_join(self.base_api_url, "version")
6159
data = await self._call_gitlab(url, has_pagination=False)
6260
server_version = data.get("version", "")
6361
is_valid = True
@@ -79,7 +77,7 @@ async def get_current_user(self) -> Dict[str, str]:
7977
# Check server compatibility
8078
await self.check_server_version()
8179

82-
git_url = url_path_join(self._base_api_url, "user")
80+
git_url = url_path_join(self.base_api_url, "user")
8381
data = await self._call_gitlab(git_url, has_pagination=False)
8482

8583
return {"username": data["username"]}
@@ -227,15 +225,15 @@ async def list_prs(self, username: str, pr_filter: str) -> List[Dict[str, str]]:
227225

228226
# Use search API to find matching pull requests and return
229227
git_url = url_path_join(
230-
self._base_api_url, "/merge_requests?state=opened&" + search_filter
228+
self.base_api_url, "/merge_requests?state=opened&" + search_filter
231229
)
232230

233231
results = await self._call_gitlab(git_url)
234232

235233
data = []
236234
for result in results:
237235
url = url_path_join(
238-
self._base_api_url,
236+
self.base_api_url,
239237
"projects",
240238
str(result["project_id"]),
241239
"merge_requests",
@@ -374,7 +372,7 @@ async def _call_gitlab(
374372
"""
375373

376374
headers = {
377-
"Authorization": f"Bearer {self._access_token}",
375+
"Authorization": f"Bearer {self._config.access_token}",
378376
"Accept": "application/json",
379377
}
380378
return await super()._call_provider(
@@ -481,7 +479,7 @@ def _response_to_comment(result: Dict[str, str]) -> Dict[str, str]:
481479
async def __get_content(self, project_id: int, filename: str, sha: str) -> str:
482480
url = url_concat(
483481
url_path_join(
484-
self._base_api_url,
482+
self.base_api_url,
485483
"projects",
486484
str(project_id),
487485
"repository/files",

jupyterlab_pullrequests/managers/manager.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,19 @@
1010

1111
from .._version import __version__
1212
from ..log import get_logger
13-
13+
from ..base import PRConfig
1414

1515
class PullRequestsManager(abc.ABC):
1616
"""Abstract base class for pull requests manager."""
1717

18-
def __init__(self, base_api_url: str = "", access_token: str = "") -> None:
19-
"""
20-
Args:
21-
base_api_url: Base REST API url for the versioning service
22-
access_token: Versioning service access token
23-
"""
18+
def __init__(self, config: PRConfig) -> None:
19+
self._config = config
2420
self._client = tornado.httpclient.AsyncHTTPClient()
25-
self._base_api_url = base_api_url
26-
self._access_token = access_token
2721

2822
@property
2923
def base_api_url(self) -> str:
3024
"""The provider base REST API URL"""
31-
return self._base_api_url
25+
return self._config.api_base_url
3226

3327
@property
3428
def log(self) -> logging.Logger:
@@ -142,7 +136,7 @@ async def _call_provider(
142136
List or Dict: Create from JSON response body if load_json is True
143137
str: Raw response body if load_json is False
144138
"""
145-
if not self._access_token:
139+
if not self._config.access_token:
146140
raise tornado.web.HTTPError(
147141
status_code=http.HTTPStatus.BAD_REQUEST,
148142
reason="No access token specified. Please set PRConfig.access_token in your user jupyter_server_config file.",
@@ -154,8 +148,8 @@ async def _call_provider(
154148
headers["Content-Type"] = "application/json"
155149
body = tornado.escape.json_encode(body)
156150

157-
if not url.startswith(self._base_api_url):
158-
url = url_path_join(self._base_api_url, url)
151+
if not url.startswith(self.base_api_url):
152+
url = url_path_join(self.base_api_url, url)
159153

160154
with_pagination = False
161155
if (
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,40 @@
1+
import pytest
2+
3+
from ..base import PRConfig
4+
5+
# the preferred method for loading jupyter_server (because entry_points)
16
pytest_plugins = ["jupyter_server.pytest_plugin"]
7+
8+
9+
@pytest.fixture
10+
def pr_base_config():
11+
return PRConfig()
12+
13+
14+
@pytest.fixture
15+
def pr_github_config(pr_base_config):
16+
return pr_base_config()
17+
18+
19+
@pytest.fixture
20+
def pr_github_manager(pr_base_config):
21+
from ..managers.github import GitHubManager
22+
return GitHubManager(pr_base_config)
23+
24+
25+
@pytest.fixture
26+
def pr_valid_github_manager(pr_github_manager):
27+
pr_github_manager._config.access_token = "valid"
28+
return pr_github_manager
29+
30+
31+
@pytest.fixture
32+
def pr_gitlab_manger(pr_base_config):
33+
from ..managers.gitlab import GitLabManager
34+
return GitLabManager(pr_base_config)
35+
36+
37+
@pytest.fixture
38+
def pr_valid_gitlab_manager(pr_gitlab_manger):
39+
pr_gitlab_manger._config.access_token = "valid"
40+
return pr_gitlab_manger

0 commit comments

Comments
 (0)