Skip to content

Commit

Permalink
Support for Gymnasium Tuple Space (#20)
Browse files Browse the repository at this point in the history
# Description

Add serialization support for spaces of type `gymnasium.spaces.Tuple`.

## Type of change

Please delete options that are not relevant.

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update

<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |


To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->

# Checklist:

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`pre-commit run --all-files` (see `CONTRIBUTING.md` instructions to set
it up)
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] New and existing unit tests pass locally with my changes

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------

Co-authored-by: ariel <[email protected]>
  • Loading branch information
wduguay-air and RedTachyon authored Apr 10, 2024
1 parent e80c032 commit 651854a
Show file tree
Hide file tree
Showing 7 changed files with 317 additions and 21 deletions.
10 changes: 5 additions & 5 deletions cogment_lab/generated/cog_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@


_player_class = _cog.actor.ActorClass(
name="player",
config_type=data_pb.AgentConfig,
action_space=data_pb.PlayerAction,
observation_space=data_pb.Observation,
)
name="player",
config_type=data_pb.AgentConfig,
action_space=data_pb.PlayerAction,
observation_space=data_pb.Observation,
)


actor_classes = _cog.actor.ActorClassList(_player_class)
Expand Down
88 changes: 85 additions & 3 deletions cogment_lab/generated/data_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder


# @@protoc_insertion_point(imports)
Expand All @@ -32,8 +33,89 @@

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto\"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t\"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type\"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01\"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05\"\r\n\x0bTrialConfig\"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame\":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', globals())


_ENVIRONMENTSPECS = DESCRIPTOR.message_types_by_name['EnvironmentSpecs']
_AGENTSPECS = DESCRIPTOR.message_types_by_name['AgentSpecs']
_VALUE = DESCRIPTOR.message_types_by_name['Value']
_ENVIRONMENTCONFIG = DESCRIPTOR.message_types_by_name['EnvironmentConfig']
_ENVIRONMENTCONFIG_RESETARGSENTRY = _ENVIRONMENTCONFIG.nested_types_by_name['ResetArgsEntry']
_HFHUBMODEL = DESCRIPTOR.message_types_by_name['HFHubModel']
_AGENTCONFIG = DESCRIPTOR.message_types_by_name['AgentConfig']
_TRIALCONFIG = DESCRIPTOR.message_types_by_name['TrialConfig']
_OBSERVATION = DESCRIPTOR.message_types_by_name['Observation']
_PLAYERACTION = DESCRIPTOR.message_types_by_name['PlayerAction']
EnvironmentSpecs = _reflection.GeneratedProtocolMessageType('EnvironmentSpecs', (_message.Message,), {
'DESCRIPTOR' : _ENVIRONMENTSPECS,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentSpecs)
})
_sym_db.RegisterMessage(EnvironmentSpecs)

AgentSpecs = _reflection.GeneratedProtocolMessageType('AgentSpecs', (_message.Message,), {
'DESCRIPTOR' : _AGENTSPECS,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.AgentSpecs)
})
_sym_db.RegisterMessage(AgentSpecs)

Value = _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), {
'DESCRIPTOR' : _VALUE,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.Value)
})
_sym_db.RegisterMessage(Value)

EnvironmentConfig = _reflection.GeneratedProtocolMessageType('EnvironmentConfig', (_message.Message,), {

'ResetArgsEntry' : _reflection.GeneratedProtocolMessageType('ResetArgsEntry', (_message.Message,), {
'DESCRIPTOR' : _ENVIRONMENTCONFIG_RESETARGSENTRY,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentConfig.ResetArgsEntry)
})
,
'DESCRIPTOR' : _ENVIRONMENTCONFIG,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentConfig)
})
_sym_db.RegisterMessage(EnvironmentConfig)
_sym_db.RegisterMessage(EnvironmentConfig.ResetArgsEntry)

HFHubModel = _reflection.GeneratedProtocolMessageType('HFHubModel', (_message.Message,), {
'DESCRIPTOR' : _HFHUBMODEL,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.HFHubModel)
})
_sym_db.RegisterMessage(HFHubModel)

AgentConfig = _reflection.GeneratedProtocolMessageType('AgentConfig', (_message.Message,), {
'DESCRIPTOR' : _AGENTCONFIG,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.AgentConfig)
})
_sym_db.RegisterMessage(AgentConfig)

TrialConfig = _reflection.GeneratedProtocolMessageType('TrialConfig', (_message.Message,), {
'DESCRIPTOR' : _TRIALCONFIG,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.TrialConfig)
})
_sym_db.RegisterMessage(TrialConfig)

Observation = _reflection.GeneratedProtocolMessageType('Observation', (_message.Message,), {
'DESCRIPTOR' : _OBSERVATION,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.Observation)
})
_sym_db.RegisterMessage(Observation)

PlayerAction = _reflection.GeneratedProtocolMessageType('PlayerAction', (_message.Message,), {
'DESCRIPTOR' : _PLAYERACTION,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.PlayerAction)
})
_sym_db.RegisterMessage(PlayerAction)

if _descriptor._USE_C_DESCRIPTORS == False:

DESCRIPTOR._options = None
Expand Down
26 changes: 23 additions & 3 deletions cogment_lab/generated/ndarray_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
from google.protobuf.internal import enum_type_wrapper


# @@protoc_insertion_point(imports)
Expand All @@ -30,8 +32,26 @@

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array\"\xcd\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r\x12\x13\n\x0bstring_data\x18\t \x03(\t*\x95\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x12\x10\n\x0c\x44TYPE_STRING\x10\x07\x62\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ndarray_pb2', globals())
_DTYPE = DESCRIPTOR.enum_types_by_name['DType']
DType = enum_type_wrapper.EnumTypeWrapper(_DTYPE)
DTYPE_UNKNOWN = 0
DTYPE_FLOAT32 = 1
DTYPE_FLOAT64 = 2
DTYPE_INT8 = 3
DTYPE_INT32 = 4
DTYPE_INT64 = 5
DTYPE_UINT8 = 6
DTYPE_STRING = 7


_ARRAY = DESCRIPTOR.message_types_by_name['Array']
Array = _reflection.GeneratedProtocolMessageType('Array', (_message.Message,), {
'DESCRIPTOR' : _ARRAY,
'__module__' : 'ndarray_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.nd_array.Array)
})
_sym_db.RegisterMessage(Array)

if _descriptor._USE_C_DESCRIPTORS == False:

DESCRIPTOR._options = None
Expand Down
103 changes: 95 additions & 8 deletions cogment_lab/generated/spaces_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder


# @@protoc_insertion_point(imports)
Expand All @@ -29,10 +30,92 @@
import cogment_lab.generated.ndarray_pb2 as ndarray__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xb3\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12(\n\x04text\x18\x06 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"q\n\x05Tuple\x12\x32\n\x06spaces\x18\x02 \x03(\x0b\x32\".cogment_lab.spaces.Tuple.SubSpace\x1a\x34\n\x08SubSpace\x12(\n\x05space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xdf\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12*\n\x05tuple\x18\x06 \x01(\x0b\x32\x19.cogment_lab.spaces.TupleH\x00\x12(\n\x04text\x18\x07 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3')



_DISCRETE = DESCRIPTOR.message_types_by_name['Discrete']
_BOX = DESCRIPTOR.message_types_by_name['Box']
_MULTIBINARY = DESCRIPTOR.message_types_by_name['MultiBinary']
_MULTIDISCRETE = DESCRIPTOR.message_types_by_name['MultiDiscrete']
_DICT = DESCRIPTOR.message_types_by_name['Dict']
_DICT_SUBSPACE = _DICT.nested_types_by_name['SubSpace']
_TUPLE = DESCRIPTOR.message_types_by_name['Tuple']
_TUPLE_SUBSPACE = _TUPLE.nested_types_by_name['SubSpace']
_TEXT = DESCRIPTOR.message_types_by_name['Text']
_SPACE = DESCRIPTOR.message_types_by_name['Space']
Discrete = _reflection.GeneratedProtocolMessageType('Discrete', (_message.Message,), {
'DESCRIPTOR' : _DISCRETE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Discrete)
})
_sym_db.RegisterMessage(Discrete)

Box = _reflection.GeneratedProtocolMessageType('Box', (_message.Message,), {
'DESCRIPTOR' : _BOX,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Box)
})
_sym_db.RegisterMessage(Box)

MultiBinary = _reflection.GeneratedProtocolMessageType('MultiBinary', (_message.Message,), {
'DESCRIPTOR' : _MULTIBINARY,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.MultiBinary)
})
_sym_db.RegisterMessage(MultiBinary)

MultiDiscrete = _reflection.GeneratedProtocolMessageType('MultiDiscrete', (_message.Message,), {
'DESCRIPTOR' : _MULTIDISCRETE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.MultiDiscrete)
})
_sym_db.RegisterMessage(MultiDiscrete)

Dict = _reflection.GeneratedProtocolMessageType('Dict', (_message.Message,), {

'SubSpace' : _reflection.GeneratedProtocolMessageType('SubSpace', (_message.Message,), {
'DESCRIPTOR' : _DICT_SUBSPACE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Dict.SubSpace)
})
,
'DESCRIPTOR' : _DICT,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Dict)
})
_sym_db.RegisterMessage(Dict)
_sym_db.RegisterMessage(Dict.SubSpace)

Tuple = _reflection.GeneratedProtocolMessageType('Tuple', (_message.Message,), {

'SubSpace' : _reflection.GeneratedProtocolMessageType('SubSpace', (_message.Message,), {
'DESCRIPTOR' : _TUPLE_SUBSPACE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Tuple.SubSpace)
})
,
'DESCRIPTOR' : _TUPLE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Tuple)
})
_sym_db.RegisterMessage(Tuple)
_sym_db.RegisterMessage(Tuple.SubSpace)

Text = _reflection.GeneratedProtocolMessageType('Text', (_message.Message,), {
'DESCRIPTOR' : _TEXT,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Text)
})
_sym_db.RegisterMessage(Text)

Space = _reflection.GeneratedProtocolMessageType('Space', (_message.Message,), {
'DESCRIPTOR' : _SPACE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Space)
})
_sym_db.RegisterMessage(Space)

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'spaces_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

DESCRIPTOR._options = None
Expand All @@ -48,8 +131,12 @@
_DICT._serialized_end=420
_DICT_SUBSPACE._serialized_start=355
_DICT_SUBSPACE._serialized_end=420
_TEXT._serialized_start=422
_TEXT._serialized_end=485
_SPACE._serialized_start=488
_SPACE._serialized_end=795
_TUPLE._serialized_start=422
_TUPLE._serialized_end=535
_TUPLE_SUBSPACE._serialized_start=483
_TUPLE_SUBSPACE._serialized_end=535
_TEXT._serialized_start=537
_TEXT._serialized_end=600
_SPACE._serialized_start=603
_SPACE._serialized_end=954
# @@protoc_insertion_point(module_scope)
10 changes: 9 additions & 1 deletion cogment_lab/protos/spaces.proto
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ message Dict {
repeated SubSpace spaces = 1;
}

message Tuple {
message SubSpace {
Space space = 1;
}
repeated SubSpace spaces = 2;
}

message Text {
int32 max_length = 1;
int32 min_length = 2;
Expand All @@ -57,6 +64,7 @@ message Space {
Dict dict = 3;
MultiBinary multi_binary = 4;
MultiDiscrete multi_discrete = 5;
Text text = 6;
Tuple tuple = 6;
Text text = 7;
}
}
12 changes: 11 additions & 1 deletion cogment_lab/specs/spaces_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cogment_lab.generated.spaces_pb2 import MultiDiscrete # type: ignore
from cogment_lab.generated.spaces_pb2 import Space # type: ignore
from cogment_lab.generated.spaces_pb2 import Text # type: ignore
from cogment_lab.generated.spaces_pb2 import Tuple # type: ignore

from .ndarray_serialization import (
SerializationFormat,
Expand Down Expand Up @@ -66,6 +67,12 @@ def serialize_gym_space(space: gym.Space, serialization_format=SerializationForm
spaces.append(Dict.SubSpace(key=key, space=serialize_gym_space(gym_sub_space)))
return Space(dict=Dict(spaces=spaces))

if isinstance(space, gym.spaces.Tuple):
spaces = []
for gym_sub_space in space.spaces:
spaces.append(Tuple.SubSpace(space=serialize_gym_space(gym_sub_space, serialization_format)))
return Space(tuple=Tuple(spaces=spaces))

if isinstance(space, gym.spaces.Text):
return Space(text=Text(max_length=space.max_length, min_length=space.min_length, charset=space.characters))

Expand Down Expand Up @@ -101,8 +108,11 @@ def deserialize_space(pb_space: Space) -> gym.Space:
spaces = []
for sub_space in dict_space_pb.spaces:
spaces.append((sub_space.key, deserialize_space(sub_space.space)))

return gym.spaces.Dict(spaces=spaces)
if space_kind == "tuple":
tuple_space_pb = pb_space.tuple
spaces = [deserialize_space(sub_space) for sub_space in tuple_space_pb.spaces]
return gym.spaces.Tuple(spaces=spaces)
if space_kind == "text":
text_space_pb = pb_space.text
return gym.spaces.Text(
Expand Down
Loading

0 comments on commit 651854a

Please sign in to comment.