diff --git a/requirements-test.txt b/requirements-test.txt index a0f5aea..b98770a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,6 +1,6 @@ Django>=1.6 djangorestframework>=2.4.3 -pytest-django==2.6 +pytest-django==2.9.1 pytest==2.5.2 pytest-cov==1.6 flake8==2.2.2 diff --git a/rest_framework_xml/parsers.py b/rest_framework_xml/parsers.py index 5454356..7e1bd4a 100644 --- a/rest_framework_xml/parsers.py +++ b/rest_framework_xml/parsers.py @@ -37,6 +37,13 @@ def parse(self, stream, media_type=None, parser_context=None): return data + def _check_xml_list(self, element): + """ + Checks that an element has multiple tags and that they are all the same, + to validate that the element is a properly formatted list + """ + return len(element) > 1 and len(set([child.tag for child in element])) <= 1 + def _xml_convert(self, element): """ convert the xml `element` into the corresponding python object @@ -48,7 +55,7 @@ def _xml_convert(self, element): return self._type_convert(element.text) else: # if the fist child tag is list-item means all children are list-item - if children[0].tag == "list-item": + if self._check_xml_list(element): data = [] for child in children: data.append(self._xml_convert(child)) diff --git a/tests/test_parsers.py b/tests/test_parsers.py index b04af4a..3e5597a 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -52,6 +52,49 @@ def setUp(self): } ] } + self._invalid_list_input = StringIO( + '' + '' + '' + '1first' + '2second' + '3third' + '' + '' + ) + self._invalid_list_output = { + "list": { + "list-item": { + "sub_id": 1, + "sub_name": "first" + }, + "list-item2": { + "sub_id": 3, + "sub_name": "third" + } + } + } + self._valid_list_input = StringIO( + '' + '' + '' + '1first' + '2second' + '' + '' + ) + self._valid_list_output = { + "list": [ + { + "sub_id": 1, + "sub_name": "first" + }, + { + "sub_id": 2, + "sub_name": "second" + } + ] + } @unittest.skipUnless(etree, 'defusedxml not installed') def test_parse(self): @@ -64,3 +107,15 @@ def test_complex_data_parse(self): parser = XMLParser() data = parser.parse(self._complex_data_input) self.assertEqual(data, self._complex_data) + + @unittest.skipUnless(etree, 'defusedxml not installed') + def test_invalid_list_parse(self): + parser = XMLParser() + data = parser.parse(self._invalid_list_input) + self.assertEqual(data, self._invalid_list_output) + + @unittest.skipUnless(etree, 'defusedxml not installed') + def test_valid_list_parse(self): + parser = XMLParser() + data = parser.parse(self._valid_list_input) + self.assertEqual(data, self._valid_list_output)