Skip to content

Commit

Permalink
Merge pull request #50 from jorenham/deterministic-main-return-type
Browse files Browse the repository at this point in the history
deterministic `@main` return type
  • Loading branch information
jorenham authored Sep 20, 2024
2 parents 8d574ee + 0f453ed commit ee7850c
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 109 deletions.
135 changes: 63 additions & 72 deletions mainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,22 @@
from __future__ import annotations

import asyncio
import os
import sys
from collections.abc import Coroutine
from typing import (
TYPE_CHECKING,
Any,
Callable,
Protocol,
TypeVar,
cast,
final,
overload,
)
from typing import TYPE_CHECKING, Callable, Protocol, TypeVar, cast, overload


if TYPE_CHECKING:
import contextvars
from collections.abc import Coroutine

if sys.version_info < (3, 11):
from typing_extensions import Never, Protocol, TypeAlias, TypeGuard
else:
from typing import Never, Protocol, TypeAlias, TypeGuard
from typing_extensions import Never, TypeGuard

__all__ = ('main',)

_R = TypeVar('_R')
_R_co = TypeVar('_R_co', covariant=True)
_F = TypeVar('_F', bound=Callable[[], object])

_SFunc: TypeAlias = Callable[[], _R]
_AFunc: TypeAlias = Callable[[], Coroutine[Any, Any, _R]]


def _infer_debug() -> bool:
flags = sys.flags
Expand Down Expand Up @@ -100,63 +84,75 @@ def _unwrap_click(func: _HasCallbackFunction[object] | _F, /) -> _F | object:
return func


def _is_main_func(func: Callable[..., object]) -> bool:
def _is_main_func(func: Callable[..., object], /) -> bool:
return _unwrap_click(func).__module__ == '__main__'


@final
class _MainDecorator(Protocol):
@overload
def __call__(self, func: _AFunc[_R], /) -> _AFunc[_R] | _R: ...
@overload
def __call__(self, func: _SFunc[_R], /) -> _SFunc[_R] | _R: ...


@overload
def main(
func: None = ...,
def _run_async(
coro: Coroutine[object, object, object],
/,
*,
debug: bool | None = ...,
is_async: bool | None = ...,
use_uvloop: bool | None = ...,
context: contextvars.Context | None = ...,
) -> _MainDecorator: ...
debug: bool,
use_uvloop: bool,
context: contextvars.Context | None,
) -> object:
import asyncio

if sys.version_info >= (3, 11):
loop_factory = None
if use_uvloop:
import uvloop

loop_factory = uvloop.new_event_loop

with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
result = runner.run(coro, context=context)
else:
if use_uvloop:
import uvloop

uvloop.install()

result = asyncio.run(coro, debug=debug)

return result


@overload
def main(
func: _AFunc[_R],
func: None = None,
/,
*,
debug: bool | None = ...,
is_async: bool | None = ...,
use_uvloop: bool | None = ...,
context: contextvars.Context | None = ...,
) -> _AFunc[_R] | _R: ...
debug: bool | None = None,
is_async: bool | None = None,
use_uvloop: bool | None = None,
context: contextvars.Context | None = None,
) -> Callable[[_F], _F]: ...
@overload
def main(
func: _SFunc[_R],
func: _F,
/,
*,
debug: bool | None = ...,
is_async: bool | None = ...,
use_uvloop: bool | None = ...,
context: contextvars.Context | None = ...,
) -> _SFunc[_R] | _R: ...
debug: bool | None = None,
is_async: bool | None = None,
use_uvloop: bool | None = None,
context: contextvars.Context | None = None,
) -> _F: ...
def main(
func: _AFunc[_R] | _SFunc[_R] | None = None,
func: _F | None = None,
/,
*,
debug: bool | None = None,
is_async: bool | None = None,
use_uvloop: bool | None = None,
context: contextvars.Context | None = None,
) -> _MainDecorator | _AFunc[_R] | _SFunc[_R] | _R:
) -> Callable[[_F], _F] | _F:
"""
Decorate a function to be the main entrypoint.
"""
if func is None:

def _main(_func: _F, /) -> _F | object:
def _main(_func: _F, /) -> _F:
return main(
_func,
debug=debug,
Expand All @@ -165,7 +161,7 @@ def _main(_func: _F, /) -> _F | object:
context=context,
)

return cast(_MainDecorator, _main)
return _main

if not callable(func):
errmsg = f'expected a callable, got {type(func)}'
Expand All @@ -178,34 +174,29 @@ def _main(_func: _F, /) -> _F | object:
if not frame or frame.f_globals.get('__name__') != '__main__':
return func

if debug or (debug is None and _infer_debug()):
if debug is None:
debug = _infer_debug()
if debug:
_enable_debug()

if is_async is False or (
is_async is None
and not asyncio.iscoroutinefunction(func)
): # fmt: skip
return cast(_R, func())
result = func()

if use_uvloop is None:
use_uvloop = _infer_uvloop()
if is_async is False:
return func
if is_async is None:
import asyncio

if sys.version_info < (3, 11):
if use_uvloop:
import uvloop

uvloop.install()
if not asyncio.iscoroutine(result):
return func

return asyncio.run(cast(Coroutine[Any, Any, _R], func()), debug=debug)
coro = cast('Coroutine[object, object, object]', result)

loop_factory = None
if use_uvloop:
import uvloop
if use_uvloop is None:
use_uvloop = _infer_uvloop()

loop_factory = uvloop.new_event_loop
_ = _run_async(coro, debug=debug, use_uvloop=use_uvloop, context=context)

with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(func(), context=context)
return func


@main
Expand Down
15 changes: 3 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3.13",
]
requires-python = ">=3.9"
dependencies = [
"typing-extensions>=4.1; python_version < '3.11'",
]
dependencies = []

[project.urls]
Repository = "https://github.com/jorenham/mainpy"
Expand Down Expand Up @@ -66,7 +64,7 @@ exclude = [

[tool.uv]
dev-dependencies = [
"typing-extensions",
"typing_extensions>=4.12.2",
"uvloop; sys_platform != 'win32'",

"basedpyright>=1.17.5,<2",
Expand Down Expand Up @@ -104,14 +102,7 @@ stubPath = "."
typeCheckingMode = "all"
venv = ".venv"
venvPath = "."

reportAny = false
reportInvalidCast = false
reportUnusedCallResult = false
# this appears to be broken since pyright 1.1.359
reportUntypedFunctionDecorator = false
# because of `sys.version_info()` conditionals
reportUnreachable = false
reportUnreachable = false # unavoidable with `sys.version_info` conditionals


[tool.repo-review]
Expand Down
12 changes: 8 additions & 4 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,29 @@
PY_SYNC = """
import mainpy
result = [None]
@mainpy.main
def sync_main():
print({!r})
return 42
result[0] = 42
assert sync_main == 42
assert result[0] == 42
""".strip()

# language=python
PY_ASYNC = """
import asyncio
import mainpy
result = [None]
@mainpy.main
async def async_main():
print({!r})
return await asyncio.sleep(1e-6, 42)
result[0] = await asyncio.sleep(1e-6, 42)
assert async_main == 42
assert result[0] == 42
""".strip()


Expand Down
41 changes: 28 additions & 13 deletions tests/test_mainpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,63 +49,78 @@ def app():


def test_sync(monkeypatch: pytest.MonkeyPatch):
result: list[object] = [None]

@mp.main(is_async=False)
@_patch_module(monkeypatch, '__main__')
def app():
return 'spam'
result[0] = 'spam'

assert app == 'spam'
assert result[0] == 'spam'
assert callable(app)


def test_sync_implicit(monkeypatch: pytest.MonkeyPatch):
result: list[object] = [None]

@mp.main
@_patch_module(monkeypatch, '__main__')
def app():
return 'spam'
result[0] = 'spam'

assert app == 'spam'
assert result[0] == 'spam'
assert callable(app)


def test_async(monkeypatch: pytest.MonkeyPatch):
result: list[object] = [None]

@mp.main(is_async=True, use_uvloop=False)
@_patch_module(monkeypatch, '__main__')
async def app():
return await asyncio.sleep(0, 'spam')
result[0] = await asyncio.sleep(0, 'spam')

assert app == 'spam'
assert result[0] == 'spam'
assert callable(app)
assert isinstance(
asyncio.get_event_loop_policy(),
asyncio.DefaultEventLoopPolicy,
)


def test_async_implicit(no_uvloop: None, monkeypatch: pytest.MonkeyPatch): # pyright: ignore[reportUnusedParameter]
result: list[object] = [None]

@mp.main
@_patch_module(monkeypatch, '__main__')
async def app():
return await asyncio.sleep(0, 'spam')
result[0] = await asyncio.sleep(0, 'spam')

assert app == 'spam'
assert result[0] == 'spam'
assert callable(app)
assert isinstance(
asyncio.get_event_loop_policy(),
asyncio.DefaultEventLoopPolicy,
)


def test_async_implicit_uvloop(monkeypatch: pytest.MonkeyPatch):
result: list[object] = [None]

@mp.main
@_patch_module(monkeypatch, '__main__')
async def loop_module():
await asyncio.sleep(0)
loop = asyncio.get_running_loop()
return loop.__module__.split('.')[0]
result[0] = loop.__module__.split('.')[0]

assert loop_module
assert isinstance(loop_module, str)
assert result[0]
assert isinstance(result[0], str)
assert callable(loop_module)

if importlib.util.find_spec('uvloop') is None:
assert not mp._infer_uvloop()
assert loop_module == 'asyncio'
assert result[0] == 'asyncio'
else:
assert mp._infer_uvloop()
assert loop_module == 'uvloop'
assert result[0] == 'uvloop'
Loading

0 comments on commit ee7850c

Please sign in to comment.