Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support text-based spaces to support LLM stuff #19

Merged
merged 11 commits into from
Mar 25, 2024
13 changes: 12 additions & 1 deletion cogment_lab/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ async def on_message(self, messages: list):
"""Handle received messages."""
pass

async def on_ending(self, observation, rendered_frame):
"""Handle trial ending."""
pass

async def end(self):
"""Clean up when done."""
pass
Expand All @@ -250,10 +254,17 @@ async def impl(self, actor_session: ActorSession):
async for event in actor_session.all_events():
event: RecvEvent
self.current_event = event
if event.type != cogment.EventType.ACTIVE:
if event.type not in (cogment.EventType.ACTIVE, cogment.EventType.ENDING):
logging.info(f"Skipping event of type {event.type}")
continue

if event.type == cogment.EventType.ENDING:
observation = self.session_helper.get_observation(event)
await self.on_ending(observation.value, observation.rendered_frame)
continue

# type = ACTIVE

if event.observation:
observation = self.session_helper.get_observation(event)
if observation is None:
Expand Down
54 changes: 26 additions & 28 deletions cogment_lab/generated/data_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: data.proto
"""Generated protocol buffer code."""
Expand All @@ -31,34 +30,33 @@
import cogment_lab.generated.spaces_pb2 as spaces__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05"\r\n\x0bTrialConfig"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto\"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t\"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type\"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01\"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05\"\r\n\x0bTrialConfig\"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame\":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "data_pb2", globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_ENVIRONMENTCONFIG_RESETARGSENTRY._options = None
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_options = b"8\001"
_ENVIRONMENTSPECS._serialized_start = 57
_ENVIRONMENTSPECS._serialized_end = 272
_AGENTSPECS._serialized_start = 274
_AGENTSPECS._serialized_end = 389
_VALUE._serialized_start = 391
_VALUE._serialized_end = 480
_ENVIRONMENTCONFIG._serialized_start = 483
_ENVIRONMENTCONFIG._serialized_end = 724
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_start = 656
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_end = 724
_HFHUBMODEL._serialized_start = 726
_HFHUBMODEL._serialized_end = 773
_AGENTCONFIG._serialized_start = 776
_AGENTCONFIG._serialized_end = 940
_TRIALCONFIG._serialized_start = 942
_TRIALCONFIG._serialized_end = 955
_OBSERVATION._serialized_start = 958
_OBSERVATION._serialized_end = 1094
_PLAYERACTION._serialized_start = 1096
_PLAYERACTION._serialized_end = 1154

DESCRIPTOR._options = None
_ENVIRONMENTCONFIG_RESETARGSENTRY._options = None
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_options = b'8\001'
_ENVIRONMENTSPECS._serialized_start=57
_ENVIRONMENTSPECS._serialized_end=272
_AGENTSPECS._serialized_start=274
_AGENTSPECS._serialized_end=389
_VALUE._serialized_start=391
_VALUE._serialized_end=480
_ENVIRONMENTCONFIG._serialized_start=483
_ENVIRONMENTCONFIG._serialized_end=724
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_start=656
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_end=724
_HFHUBMODEL._serialized_start=726
_HFHUBMODEL._serialized_end=773
_AGENTCONFIG._serialized_start=776
_AGENTCONFIG._serialized_end=940
_TRIALCONFIG._serialized_start=942
_TRIALCONFIG._serialized_end=955
_OBSERVATION._serialized_start=958
_OBSERVATION._serialized_end=1094
_PLAYERACTION._serialized_start=1096
_PLAYERACTION._serialized_end=1154
# @@protoc_insertion_point(module_scope)
20 changes: 10 additions & 10 deletions cogment_lab/generated/ndarray_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: ndarray.proto
"""Generated protocol buffer code."""
Expand All @@ -27,16 +26,17 @@
_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array"\xb8\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r*\x83\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x62\x06proto3'
)


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array\"\xcd\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r\x12\x13\n\x0bstring_data\x18\t \x03(\t*\x95\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x12\x10\n\x0c\x44TYPE_STRING\x10\x07\x62\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "ndarray_pb2", globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ndarray_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_DTYPE._serialized_start = 227
_DTYPE._serialized_end = 358
_ARRAY._serialized_start = 40
_ARRAY._serialized_end = 224

DESCRIPTOR._options = None
_DTYPE._serialized_start=248
_DTYPE._serialized_end=397
_ARRAY._serialized_start=40
_ARRAY._serialized_end=245
# @@protoc_insertion_point(module_scope)
40 changes: 20 additions & 20 deletions cogment_lab/generated/spaces_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: spaces.proto
"""Generated protocol buffer code."""
Expand All @@ -30,26 +29,27 @@
import cogment_lab.generated.ndarray_pb2 as ndarray__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space"\x89\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x42\x06\n\x04kindb\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xb3\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12(\n\x04text\x18\x06 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spaces_pb2", globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'spaces_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_DISCRETE._serialized_start = 51
_DISCRETE._serialized_end = 87
_BOX._serialized_start = 89
_BOX._serialized_end = 179
_MULTIBINARY._serialized_start = 181
_MULTIBINARY._serialized_end = 234
_MULTIDISCRETE._serialized_start = 236
_MULTIDISCRETE._serialized_end = 294
_DICT._serialized_start = 296
_DICT._serialized_end = 420
_DICT_SUBSPACE._serialized_start = 355
_DICT_SUBSPACE._serialized_end = 420
_SPACE._serialized_start = 423
_SPACE._serialized_end = 688

DESCRIPTOR._options = None
_DISCRETE._serialized_start=51
_DISCRETE._serialized_end=87
_BOX._serialized_start=89
_BOX._serialized_end=179
_MULTIBINARY._serialized_start=181
_MULTIBINARY._serialized_end=234
_MULTIDISCRETE._serialized_start=236
_MULTIDISCRETE._serialized_end=294
_DICT._serialized_start=296
_DICT._serialized_end=420
_DICT_SUBSPACE._serialized_start=355
_DICT_SUBSPACE._serialized_end=420
_TEXT._serialized_start=422
_TEXT._serialized_end=485
_SPACE._serialized_start=488
_SPACE._serialized_end=795
# @@protoc_insertion_point(module_scope)
160 changes: 160 additions & 0 deletions cogment_lab/humans/gradio_actor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2024 AI Redefined Inc. <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import json
import logging
import multiprocessing as mp
import signal
from typing import Any, Callable

import cogment
import numpy as np

from cogment_lab.core import CogmentActor
from cogment_lab.generated import cog_settings
from cogment_lab.utils.runners import setup_logging


def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]:
if isinstance(obs, np.ndarray):
obs = obs.tolist()
elif isinstance(obs, dict):
obs = {k: obs_to_msg(v) for k, v in obs.items()}
elif isinstance(obs, np.integer):
obs = int(obs)
elif isinstance(obs, np.floating):
obs = float(obs)
return obs


def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int:
if isinstance(action_map, list):
action_map = {action: i for i, action in enumerate(action_map)}
if data.startswith("{"):
action = json.loads(data)
elif data not in action_map:
action = action_map["no-op"]
else:
action = action_map[data]
logging.info(f"Processed action {action} from {data} with action_map {action_map}")
return action


class GradioActor(CogmentActor):
def __init__(self, send_queue: mp.Queue, recv_queue: mp.Queue):
super().__init__(send_queue, recv_queue)
self.send_queue = send_queue
self.recv_queue = recv_queue

async def act(self, observation: Any, rendered_frame: np.ndarray | None = None) -> int:
# logging.info(f"Received observation {observation} and frame inside gradio actor")
obs_data = obs_to_msg(observation)
self.send_queue.put((obs_data, rendered_frame))
# logging.info(f"Sent observation {obs_data} and frame inside gradio actor")
action = self.recv_queue.get()
# logging.info(f"Received action {action} inside gradio actor")
return action

async def on_ending(self, observation, rendered_frame):
obs_data = obs_to_msg(observation)
self.send_queue.put((obs_data, rendered_frame))


async def run_cogment_actor(port: int, send_queue: asyncio.Queue, recv_queue: asyncio.Queue, signal_queue: mp.Queue):
context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab")
gradio_actor = GradioActor(send_queue, recv_queue)

logging.info("Registering actor")
context.register_actor(impl=gradio_actor.impl, impl_name="gradio", actor_classes=["player"])

logging.info("Serving actor")
serve = context.serve_all_registered(cogment.ServedEndpoint(port=port))

signal_queue.put(True)
await serve


async def shutdown():
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
asyncio.get_event_loop().stop()


def signal_handler(sig, frame):
asyncio.create_task(shutdown())


async def gradio_actor_main(
cogment_port: int,
gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None],
signal_queue: mp.Queue,
log_file: str | None = None,
):
gradio_to_actor = mp.Queue()
actor_to_gradio = mp.Queue()

logging.info("Starting gradio interface")
process = mp.Process(target=gradio_app_fn, args=(gradio_to_actor, actor_to_gradio, log_file))
process.start()

try:
logging.info("Starting cogment actor")
cogment_task = asyncio.create_task(
run_cogment_actor(
port=cogment_port,
send_queue=actor_to_gradio,
recv_queue=gradio_to_actor,
signal_queue=signal_queue,
)
)

logging.info("Waiting for cogment actor to finish")

await cogment_task
finally:
process.terminate()
process.join()


def gradio_actor_runner(
cogment_port: int,
gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None],
signal_queue: mp.Queue,
log_file: str | None = None,
):
if log_file:
setup_logging(log_file)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
for sig in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(sig, lambda s=sig, frame=None: signal_handler(s, frame))

try:
loop.run_until_complete(
gradio_actor_main(
cogment_port=cogment_port,
gradio_app_fn=gradio_app_fn,
signal_queue=signal_queue,
log_file=log_file,
)
)
finally:
loop.run_until_complete(shutdown())
loop.close()
Loading
Loading