Skip to content

Commit

Permalink
Merge pull request #99 from d0c-s4vage/hotfix/98-typedefd_structs_wit…
Browse files Browse the repository at this point in the history
…h_params

Adds tests for typedef'd parameterized structs
  • Loading branch information
d0c-s4vage authored Jan 4, 2020
2 parents 4ec0ae2 + 68fd244 commit 150222e
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 17 deletions.
59 changes: 45 additions & 14 deletions pfp/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def _pfp__init(self, stream):
self._pfp__node.args, scope, self, None
)
param_list = params.instantiate(scope, struct_args, self._pfp__interp)
super(self.__class__, self)._pfp__init(stream)

if hasattr(super(self.__class__, self), "_pfp__init"):
super(self.__class__, self)._pfp__init(stream)

new_class = type(
struct_cls.__name__ + "_", (struct_cls,), {"_pfp__init": _pfp__init}
Expand Down Expand Up @@ -84,13 +86,28 @@ def StructUnionTypeRef(curr_scope, typedef_name, refd_name, interp, node):
elif isinstance(node, AST.Union):
cls = fields.Union

def __new__(self, *args, **kwargs):
def __new__(cls_, *args, **kwargs):
refd_type = curr_scope.get_type(refd_name)
if refd_type is None:
refd_node = node
else:
refd_node = refd_type._pfp__node
return StructUnionDef(typedef_name, interp, refd_node)(*args, **kwargs)

def merged_init(self, stream):
if six.PY3:
cls_._pfp__init(self, stream)
else:
cls_._pfp__init.__func__(self, stream)
self._pfp__init_orig(stream)

overrides = {}
if hasattr(cls_, "_pfp__init"):
overrides["_pfp__init"] = merged_init

res = base_cls = StructUnionDef(
typedef_name, interp, refd_node, overrides=overrides,
)
return res(*args, **kwargs)

new_class = type(
typedef_name,
Expand All @@ -102,13 +119,16 @@ def __new__(self, *args, **kwargs):
return new_class



def StructUnionDef(typedef_name, interp, node):
def StructUnionDef(typedef_name, interp, node, overrides=None, cls=None):
if overrides is None:
overrides = {}
if isinstance(node, AST.Struct):
cls = fields.Struct
if cls is None:
cls = fields.Struct
decls = StructDecls(node.decls, node.coord)
elif isinstance(node, AST.Union):
cls = fields.Union
if cls is None:
cls = fields.Union
decls = UnionDecls(node.decls, node.coord)

# this is so that we can have all nested structs added to
Expand All @@ -117,23 +137,34 @@ def StructUnionDef(typedef_name, interp, node):
# the new struct to not be added to its parent, and the user would
# not be able to see how far the script got
def __init__(self, stream=None, metadata_processor=None, do_init=True):
cls.__init__(self, stream, metadata_processor=metadata_processor)
cls.__init__(
self,
stream,
metadata_processor=metadata_processor,
)

if do_init:
self._pfp__init(stream)

def _pfp__init(self, stream):
self._pfp__interp._handle_node(decls, ctxt=self, stream=stream)

cls_members = {
"__init__": __init__,
"_pfp__init": _pfp__init,
"_pfp__node": node,
"_pfp__interp": interp,
}

for k, v in six.iteritems(overrides or {}):
if k in cls_members:
cls_members[k + "_orig"] = cls_members[k]
cls_members[k] = v

new_class = type(
typedef_name,
(cls,),
{
"__init__": __init__,
"_pfp__init": _pfp__init,
"_pfp__node": node,
"_pfp__interp": interp,
},
cls_members,
)
return new_class

Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
py010parser>=0.1.15
py010parser>=0.1.17
six>=1.10.0,<2.0.0
intervaltree>=3.0.2,<4.0.0
intervaltree>=3.0.2,<4.0.0
26 changes: 25 additions & 1 deletion tests/test_struct_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_struct_vit9696_5(self):
LittleEndian();
ME s;
""",
debug=True,
debug=False,
)
assert dom.s.magic == "\x00\x01\x02\x03"
assert dom.s.filesize == 0x03020100
Expand Down Expand Up @@ -239,6 +239,30 @@ def test_struct_with_parameters3(self):
self.assertEqual(dom.l.c[1], 2)
self.assertEqual(dom.l.c[2], 3)

def test_typedefd_struct_with_parameters(self):
dom = self._test_parse_build(
"\x01\x02\x03\x04\x01\x02\x03",
"""
struct TEST_STRUCT(int arraySize, int arraySize2)
{
uchar b[arraySize];
uchar c[arraySize2];
};
local int bytes = 4;
typedef struct TEST_STRUCT NEW_STRUCT;
NEW_STRUCT l(bytes, 3);
""",
)
self.assertEqual(len(dom.l.b), 4)
self.assertEqual(dom.l.b[0], 1)
self.assertEqual(dom.l.b[1], 2)
self.assertEqual(dom.l.b[2], 3)
self.assertEqual(dom.l.b[3], 4)
self.assertEqual(len(dom.l.c), 3)
self.assertEqual(dom.l.c[0], 1)
self.assertEqual(dom.l.c[1], 2)
self.assertEqual(dom.l.c[2], 3)

def test_struct_decl_with_struct_keyword(self):
dom = self._test_parse_build(
"ABCD",
Expand Down

0 comments on commit 150222e

Please sign in to comment.