diff --git a/requirements-test.txt b/requirements-test.txt
index a0f5aea..b98770a 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,6 +1,6 @@
Django>=1.6
djangorestframework>=2.4.3
-pytest-django==2.6
+pytest-django==2.9.1
pytest==2.5.2
pytest-cov==1.6
flake8==2.2.2
diff --git a/rest_framework_xml/parsers.py b/rest_framework_xml/parsers.py
index 5454356..7e1bd4a 100644
--- a/rest_framework_xml/parsers.py
+++ b/rest_framework_xml/parsers.py
@@ -37,6 +37,13 @@ def parse(self, stream, media_type=None, parser_context=None):
return data
+ def _check_xml_list(self, element):
+ """
+ Checks that an element has multiple tags and that they are all the same,
+ to validate that the element is a properly formatted list
+ """
+ return len(element) > 1 and len(set([child.tag for child in element])) <= 1
+
def _xml_convert(self, element):
"""
convert the xml `element` into the corresponding python object
@@ -48,7 +55,7 @@ def _xml_convert(self, element):
return self._type_convert(element.text)
else:
# if the fist child tag is list-item means all children are list-item
- if children[0].tag == "list-item":
+ if self._check_xml_list(element):
data = []
for child in children:
data.append(self._xml_convert(child))
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
index b04af4a..3e5597a 100644
--- a/tests/test_parsers.py
+++ b/tests/test_parsers.py
@@ -52,6 +52,49 @@ def setUp(self):
}
]
}
+ self._invalid_list_input = StringIO(
+ ''
+ ''
+ ''
+ '1first'
+ '2second'
+ '3third'
+ '
'
+ ''
+ )
+ self._invalid_list_output = {
+ "list": {
+ "list-item": {
+ "sub_id": 1,
+ "sub_name": "first"
+ },
+ "list-item2": {
+ "sub_id": 3,
+ "sub_name": "third"
+ }
+ }
+ }
+ self._valid_list_input = StringIO(
+ ''
+ ''
+ ''
+ '1first'
+ '2second'
+ '
'
+ ''
+ )
+ self._valid_list_output = {
+ "list": [
+ {
+ "sub_id": 1,
+ "sub_name": "first"
+ },
+ {
+ "sub_id": 2,
+ "sub_name": "second"
+ }
+ ]
+ }
@unittest.skipUnless(etree, 'defusedxml not installed')
def test_parse(self):
@@ -64,3 +107,15 @@ def test_complex_data_parse(self):
parser = XMLParser()
data = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)
+
+ @unittest.skipUnless(etree, 'defusedxml not installed')
+ def test_invalid_list_parse(self):
+ parser = XMLParser()
+ data = parser.parse(self._invalid_list_input)
+ self.assertEqual(data, self._invalid_list_output)
+
+ @unittest.skipUnless(etree, 'defusedxml not installed')
+ def test_valid_list_parse(self):
+ parser = XMLParser()
+ data = parser.parse(self._valid_list_input)
+ self.assertEqual(data, self._valid_list_output)