Skip to content

Commit

Permalink
Merge branch 'main' into sackfield/add-fast-pusher
Browse files Browse the repository at this point in the history
  • Loading branch information
8W9aG authored Jan 17, 2025
2 parents 539cafc + 85b85bf commit ef294cf
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 47 deletions.
2 changes: 1 addition & 1 deletion pkg/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func newRunCommand() *cobra.Command {
addFastFlag(cmd)

flags := cmd.Flags()
// Flags after first argment are considered args and passed to command
// Flags after first argument are considered args and passed to command

// This is called `publish` for consistency with `docker run`
cmd.Flags().StringArrayVarP(&runPorts, "publish", "p", []string{}, "Publish a container's port to the host, e.g. -p 8000")
Expand Down
27 changes: 20 additions & 7 deletions python/cog/server/scope.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import warnings
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Callable, Generator, Optional, Union
from typing import Any, Callable, Generator, Optional, Union

from attrs import evolve, frozen

from ..types import ExperimentalFeatureWarning


@frozen
class Scope:
def __init__(
self,
*,
record_metric: Callable[[str, Union[float, int]], None],
) -> None:
self.record_metric = record_metric
record_metric: Callable[[str, Union[float, int]], None]
_tag: Optional[str] = None


_current_scope: ContextVar[Optional[Scope]] = ContextVar("scope", default=None)
Expand All @@ -24,6 +23,10 @@ def current_scope() -> Scope:
category=ExperimentalFeatureWarning,
stacklevel=1,
)
return _get_current_scope()


def _get_current_scope() -> Scope:
s = _current_scope.get()
if s is None:
raise RuntimeError("No scope available")
Expand All @@ -39,6 +42,16 @@ def scope(sc: Scope) -> Generator[None, None, None]:
_current_scope.reset(s)


@contextmanager
def evolve_scope(**kwargs: Any) -> Generator[None, None, None]:
new_scope = evolve(_get_current_scope(), **kwargs)
s = _current_scope.set(new_scope)
try:
yield
finally:
_current_scope.reset(s)


def emit_metric(name: str, value: Union[float, int]) -> None:
"""
DEPRECATED: This function will be removed in a future version of cog.
Expand Down
18 changes: 5 additions & 13 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import contextlib
import contextvars
import inspect
import multiprocessing
import os
Expand Down Expand Up @@ -58,15 +57,12 @@
InvalidStateException,
)
from .helpers import SimpleStreamRedirector, StreamRedirector
from .scope import Scope, scope
from .scope import Scope, _get_current_scope, evolve_scope, scope

if PYDANTIC_V2:
from .helpers import unwrap_pydantic_serialization_iterators

_spawn = multiprocessing.get_context("spawn")
_tag_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"tag", default=None
)

_PublicEventType = Union[Done, Log, PredictionOutput, PredictionOutputType]

Expand Down Expand Up @@ -407,7 +403,7 @@ def __init__(
self._cancelable = False
self._max_concurrency = max_concurrency

# for synchronous predictors only! async predictors use _tag_var instead
# for synchronous predictors only! async predictors use current_scope()._tag instead
self._sync_tag: Optional[str] = None
self._has_async_predictor = is_async

Expand Down Expand Up @@ -483,10 +479,8 @@ def record_metric(self, name: str, value: Union[float, int]) -> None:

@property
def _current_tag(self) -> Optional[str]:
# if _tag_var is set, use that (only applies within _apredict())
tag = _tag_var.get()
if tag:
return tag
if self._has_async_predictor:
return _get_current_scope()._tag
return self._sync_tag

def _load_predictor(self) -> Optional[BasePredictor]:
Expand Down Expand Up @@ -687,9 +681,7 @@ async def _apredict(
predict: Callable[..., Any],
redirector: SimpleStreamRedirector,
) -> None:
_tag_var.set(tag)

with self._handle_predict_error(redirector, tag=tag):
with evolve_scope(tag=tag), self._handle_predict_error(redirector, tag=tag):
future_result = predict(**payload)

if future_result:
Expand Down
48 changes: 24 additions & 24 deletions test-integration/test_integration/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def assert_versions_match(semver_version: str, pep440_version: str):
)

# Check base release version
assert (
semver_release == pep440_groups["release"]
), f"Release versions do not match: {semver_release} != {pep440_groups['release']}"
assert semver_release == pep440_groups["release"], (
f"Release versions do not match: {semver_release} != {pep440_groups['release']}"
)

# Check prerelease status
semver_pre = semver_groups["prerelease"]
Expand All @@ -67,35 +67,35 @@ def assert_versions_match(semver_version: str, pep440_version: str):

if semver_pre:
if semver_pre.startswith("alpha"):
assert (
pep440_groups["pre_l"] == "a"
), "Alpha pre-release status does not match"
assert not pep440_groups[
"dev"
], "Semver pre-release cannot also be a PEP440 dev build"
assert pep440_groups["pre_l"] == "a", (
"Alpha pre-release status does not match"
)
assert not pep440_groups["dev"], (
"Semver pre-release cannot also be a PEP440 dev build"
)

if semver_pre.startswith("beta"):
assert (
pep440_groups["pre_l"] == "b"
), "Beta pre-release status does not match"
assert not pep440_groups[
"dev"
], "Semver pre-release cannot also be a PEP440 dev build"
assert pep440_groups["pre_l"] == "b", (
"Beta pre-release status does not match"
)
assert not pep440_groups["dev"], (
"Semver pre-release cannot also be a PEP440 dev build"
)

if semver_pre.startswith("rc"):
assert (
pep440_groups["pre_l"] == "rc"
), "Release candidate pre-release status does not match"
assert not pep440_groups[
"dev"
], "Semver pre-release cannot also be a PEP440 dev build"
assert pep440_groups["pre_l"] == "rc", (
"Release candidate pre-release status does not match"
)
assert not pep440_groups["dev"], (
"Semver pre-release cannot also be a PEP440 dev build"
)

if semver_pre.startswith("dev"):
assert pep440_groups["dev_l"] == "dev", "Dev build status does not match"

assert (
semver_groups["buildmetadata"] == pep440_groups["local"]
), f"Local/build metadata component does not match: {semver_groups['buildmetadata']} != {pep440_groups['local']}"
assert semver_groups["buildmetadata"] == pep440_groups["local"], (
f"Local/build metadata component does not match: {semver_groups['buildmetadata']} != {pep440_groups['local']}"
)


def random_string(length):
Expand Down
2 changes: 1 addition & 1 deletion tools/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ main() {
esac
fi

# Check the users sudo priviledges
# Check the users sudo privileges
if [ -z "${SUDO+set}" ]; then
SUDO="sudo"
fi
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ commands =
[testenv:lint]
base_python = python3.12
skip_install = true
deps = ruff
deps = ruff==0.9.1
commands =
ruff check python/cog
ruff format --check python
Expand Down

0 comments on commit ef294cf

Please sign in to comment.