Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rip out TaskLocal objects and fix system task context inheritance. #499

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 2 additions & 130 deletions trio/_core/_local.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,7 @@
# Task- and Run-local storage
from .._deprecate import deprecated
# Run-local storage
from . import _run

__all__ = ["TaskLocal", "RunLocal", "RunVar"]

# Our public API is intentionally almost identical to that of threading.local:
# the user allocates a trio.{Task,Run}Local() object, and then can attach
# arbitrary attributes to it. Reading one of these attributes later will
# return the last value that was assigned to this attribute *by code running
# inside the same task or run*.


# This is conceptually a method on _LocalBase, but given the way we're playing
# with attribute access making it a free-standing function is simpler:
def _local_dict(local_obj):
locals_type = object.__getattribute__(local_obj, "_locals_key")
try:
refobj = getattr(_run.GLOBAL_RUN_CONTEXT, locals_type)
except AttributeError:
raise RuntimeError("must be called from async context") from None
try:
return refobj._locals[local_obj]
except KeyError:
new_dict = dict(object.__getattribute__(local_obj, "_defaults"))
refobj._locals[local_obj] = new_dict
return new_dict


# Ughhh subclassing I feel so dirty
class _LocalBase:
__slots__ = ("_defaults",)

def __init__(self, **kwargs):
object.__setattr__(self, "_defaults", kwargs)

def __getattribute__(self, name):
ld = _local_dict(self)
if name == "__dict__":
return ld
try:
return ld[name]
except KeyError:
raise AttributeError(name) from None

def __setattr__(self, name, value):
_local_dict(self)[name] = value

def __delattr__(self, name):
try:
del _local_dict(self)[name]
except KeyError:
raise AttributeError(name) from None

def __dir__(self):
return list(_local_dict(self))


class TaskLocal(_LocalBase):
"""Task-local storage.

Instances of this class have no particular attributes or methods. Instead,
they serve as a blank slate to which you can add whatever attributes you
like. Modifications made within one task will only be visible to that task
– with one exception: when you start a new task, then any
:class:`TaskLocal` attributes that are visible in the task that called
``start`` or ``start_soon`` will be inherited by the child. This
inheritance takes the form of a shallow copy: further changes in the
parent will *not* affect the child, and changes in the child will not
affect the parent. (If you're familiar with how environment variables are
inherited across processes, then :class:`TaskLocal` inheritance is
somewhat similar.)

If you're familiar with :class:`threading.local`, then
:class:`trio.TaskLocal` is very similar, except adapted to work with tasks
instead of threads, and with the added feature that values are
automatically inherited across tasks.

When creating a :class:`TaskLocal` object, you can provide default values
as keyword arguments::

local = trio.TaskLocal(a=1)

async def main():
# The first time we access the TaskLocal object, the 'a' attribute
# is already present:
assert local.a == 1

The default values are like the default values to functions: they're only
evaluated once, when the object is created. So you shouldn't use mutable
objects as defaults -- they'll be shared not just across tasks, but even
across entirely unrelated runs! For example::

# Don't do this!!
local = trio.TaskLocal(a=[])

async def main():
assert local.a == []
local.a.append(1)

# First time, everything seems to work
trio.run(main)

# Second time, the assertion fails, because the first time modified
# the list object.
trio.run(main)

"""
__slots__ = ()
_locals_key = "task"

@deprecated("0.4.0", issue=420, instead="contextvars.ContextVar")
def __init__(self, **kwargs):
super().__init__(**kwargs)
__all__ = ["RunVar"]


class _RunVarToken(object):
Expand Down Expand Up @@ -206,21 +96,3 @@ def reset(self, token):

def __repr__(self):
return ("<RunVar name={!r}>".format(self._name))


class RunLocal(_LocalBase):
"""Run-local storage.

:class:`RunLocal` objects are very similar to :class:`trio.TaskLocal`
objects, except that attributes are shared across all the tasks within a
single call to :func:`trio.run`. They're also very similar to
:class:`threading.local` objects, except that :class:`RunLocal` objects
are automatically wiped clean when :func:`trio.run` returns.

"""
__slots__ = ()
_locals_key = "runner"

@deprecated("0.4.0", issue=478, instead=RunVar)
def __init__(self, **kwargs):
super().__init__(**kwargs)
27 changes: 13 additions & 14 deletions trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,6 @@ class Task:
# For introspection and nursery.start()
_child_nurseries = attr.ib(default=attr.Factory(list))

# Task-local values, see _local.py
_locals = attr.ib(default=attr.Factory(dict))

# these are counts of how many cancel/schedule points this task has
# executed, for assert{_no,}_yields
# XX maybe these should be exposed as part of a statistics() method?
Expand Down Expand Up @@ -591,6 +588,8 @@ class Runner:

# Run-local values, see _local.py
_locals = attr.ib(default=attr.Factory(dict))
# Used for system tasks
_system_context = attr.ib(default=attr.Factory(copy_context))

runq = attr.ib(default=attr.Factory(deque))
tasks = attr.ib(default=attr.Factory(set))
Expand Down Expand Up @@ -714,7 +713,7 @@ def reschedule(self, task, next_send=Value(None)):
self.instrument("task_scheduled", task)

def spawn_impl(
self, async_fn, args, nursery, name, *, ki_protection_enabled=False
self, async_fn, args, nursery, name, *, system_task=False,
):

######
Expand Down Expand Up @@ -818,28 +817,28 @@ def _return_value_looks_like_wrong_library(value):
name = "{}.{}".format(name.__module__, name.__qualname__)
except AttributeError:
name = repr(name)

if system_task:
context = self._system_context
else:
context = copy_context()

task = Task(
coro=coro,
parent_nursery=nursery,
runner=self,
name=name,
context=copy_context(),
context=context,
)
self.tasks.add(task)

if nursery is not None:
nursery._children.add(task)
for scope in nursery._cancel_stack:
scope._add_task(task)
coro.cr_frame.f_locals.setdefault(
LOCALS_KEY_KI_PROTECTION_ENABLED, ki_protection_enabled
LOCALS_KEY_KI_PROTECTION_ENABLED, system_task
)
if nursery is not None:
# Task locals are inherited from the spawning task, not the
# nursery task. The 'if nursery' check is just used as a guard to
# make sure we don't try to do this to the root task.
parent_task = current_task()
for local, values in parent_task._locals.items():
task._locals[local] = dict(values)
self.instrument("task_spawned", task)
# Special case: normally next_send should be a Result, but for the
# very first send we have to send a literal unboxed None.
Expand Down Expand Up @@ -924,7 +923,7 @@ def excfilter(exc):
system_task_wrapper, (async_fn, args),
self.system_nursery,
name,
ki_protection_enabled=True
system_task=True,
)

async def init(self, async_fn, args):
Expand Down
169 changes: 0 additions & 169 deletions trio/_core/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,175 +7,6 @@
from ...testing import Sequencer


async def test_local_smoketest(recwarn):
for cls in _core.TaskLocal, _core.RunLocal:
local = cls()

assert local.__dict__ == {}
assert vars(local) == {}
assert dir(local) == []
assert not hasattr(local, "a")

local.a = 1
assert local.a == 1
assert local.__dict__ == {"a": 1}
assert vars(local) == {"a": 1}
assert dir(local) == ["a"]
assert hasattr(local, "a")

del local.a

with pytest.raises(AttributeError):
local.a
with pytest.raises(AttributeError):
del local.a

assert local.__dict__ == {}
assert vars(local) == {}

local.__dict__["b"] = 2
assert local.b == 2

async def child():
assert local.b == 2

async with _core.open_nursery() as nursery:
nursery.start_soon(child)


async def test_local_isolation(recwarn):
tlocal = _core.TaskLocal()
rlocal = _core.RunLocal()

tlocal.a = "task root"
rlocal.a = "run root"

seq = Sequencer()

async def child1():
async with seq(0):
assert tlocal.a == "task root"
assert rlocal.a == "run root"

tlocal.a = "task child1"
rlocal.a = "run child1"

async with seq(2):
assert tlocal.a == "task child1"
assert rlocal.a == "run child2"

async def child2():
async with seq(1):
assert tlocal.a == "task root"
assert rlocal.a == "run child1"

tlocal.a = "task child2"
rlocal.a = "run child2"

async with _core.open_nursery() as nursery:
nursery.start_soon(child1)
nursery.start_soon(child2)

assert tlocal.a == "task root"
assert rlocal.a == "run child2"


def test_run_local_multiple_runs(recwarn):
r = _core.RunLocal()

async def main(x):
assert not hasattr(r, "attr")
r.attr = x
assert hasattr(r, "attr")
assert r.attr == x

# Nothing spills over from one run to the next
_core.run(main, 1)
_core.run(main, 2)


def test_run_local_simultaneous_runs(recwarn):
r = _core.RunLocal()

result_q = queue.Queue()

async def main(x, in_q, out_q):
in_q.get()
assert not hasattr(r, "attr")
r.attr = x
assert hasattr(r, "attr")
assert r.attr == x
out_q.put(None)
in_q.get()
assert r.attr == x

def harness(x, in_q, out_q):
result_q.put(_core.Result.capture(_core.run, main, x, in_q, out_q))

in_q1 = queue.Queue()
out_q1 = queue.Queue()
t1 = threading.Thread(target=harness, args=(1, in_q1, out_q1))
t1.start()

in_q2 = queue.Queue()
out_q2 = queue.Queue()
t2 = threading.Thread(target=harness, args=(2, in_q2, out_q2))
t2.start()

in_q1.put(None)
out_q1.get()

in_q2.put(None)
out_q2.get()

in_q1.put(None)
in_q2.put(None)
t1.join()
t2.join()
result_q.get().unwrap()
result_q.get().unwrap()

with pytest.raises(RuntimeError):
r.attr


def test_local_outside_run(recwarn):
for cls in _core.RunLocal, _core.TaskLocal:
local = cls()

with pytest.raises(RuntimeError):
local.a = 1

with pytest.raises(RuntimeError):
dir(local)


async def test_local_inheritance_from_spawner_not_supervisor(recwarn):
t = _core.TaskLocal()

t.x = "supervisor"

async def spawner(nursery):
t.x = "spawner"
nursery.start_soon(child)

async def child():
assert t.x == "spawner"

async with _core.open_nursery() as nursery:
nursery.start_soon(spawner, nursery)


async def test_local_defaults(recwarn):
for cls in _core.TaskLocal, _core.RunLocal:
local = cls(default1=123, default2="abc")
assert local.default1 == 123
assert local.default2 == "abc"
del local.default1
assert not hasattr(local, "default1")


# scary runvar tests
def test_runvar_smoketest():
t1 = _core.RunVar("test1")
t2 = _core.RunVar("test2", default="catfish")
Expand Down
Loading