diff --git a/openwisp_radius/api/serializers.py b/openwisp_radius/api/serializers.py index b8f24bcb..34dafe74 100644 --- a/openwisp_radius/api/serializers.py +++ b/openwisp_radius/api/serializers.py @@ -151,13 +151,13 @@ class RadiusAccountingSerializer(serializers.ModelSerializer): framed_ipv6_address = serializers.IPAddressField( required=False, allow_blank=True, protocol='IPv6' ) - session_time = serializers.IntegerField(required=False, default=0) + session_time = serializers.IntegerField(required=False) stop_time = serializers.DateTimeField(required=False) update_time = serializers.DateTimeField(required=False) - input_octets = serializers.IntegerField(required=False, default=0) - output_octets = serializers.IntegerField(required=False, default=0) - # this is needed otherwise serialize will ignore status_type from accounting packet - # as it's not a model field + input_octets = serializers.IntegerField(required=False) + output_octets = serializers.IntegerField(required=False) + # this is needed otherwise serializer will ignore status_type + # from the accounting request because it's not a model field status_type = serializers.ChoiceField( write_only=True, required=True, choices=STATUS_TYPE_CHOICES ) @@ -191,9 +191,24 @@ def is_valid(self, raise_exception=False): raise error def run_validation(self, data): + """ + Custom validation to handle empty strings in: + - session_time + - input_octets + - output_octets + """ for field in ['session_time', 'input_octets', 'output_octets']: - if data.get('status_type', None) == 'Start' and data[field] == '': + # missing data in accounting start, + # let's set zero as default value + if data.get('status_type', None) == 'Start' and data.get(field) == '': data[field] = 0 + # missing data in accounting stop, + # let's remove the empty string to + # prevent the API from failing + # the existing values stored in previous + # interim-updates won't be changed + if data.get('status_type', None) == 'Stop' and data.get(field) == '': + del data[field] return super().run_validation(data) def validate(self, data): diff --git a/openwisp_radius/tests/test_api/test_freeradius_api.py b/openwisp_radius/tests/test_api/test_freeradius_api.py index a1572651..6ad66e7c 100644 --- a/openwisp_radius/tests/test_api/test_freeradius_api.py +++ b/openwisp_radius/tests/test_api/test_freeradius_api.py @@ -991,6 +991,41 @@ def test_accounting_stop_200(self): self.assertEqual(ra.start_time, start_time) self.assertAcctData(ra, data) + @freeze_time(START_DATE) + def test_accounting_stop_empty_octets(self): + self.assertEqual(RadiusAccounting.objects.count(), 0) + data = self.acct_post_data + data.update( + dict( + input_octets=9900909, + output_octets=1513075509, + ) + ) + ra = self._create_radius_accounting(**data) + ra.refresh_from_db() + start_time = ra.start_time + data = self.acct_post_data + data.update( + { + 'status_type': 'Stop', + 'terminate_cause': '', + 'input_octets': '', + 'output_octets': '', + } + ) + data = self._get_accounting_params(**data) + response = self.post_json(data) + self.assertEqual(response.status_code, 200) + self.assertIsNone(response.data) + self.assertEqual(RadiusAccounting.objects.count(), 1) + ra.refresh_from_db() + self.assertEqual(ra.update_time.timetuple(), now().timetuple()) + self.assertEqual(ra.stop_time.timetuple(), now().timetuple()) + self.assertEqual(ra.start_time, start_time) + self.assertEqual(ra.input_octets, 9900909) + self.assertEqual(ra.output_octets, 1513075509) + self.assertEqual(ra.session_time, 261) + @freeze_time(START_DATE) @capture_any_output() def test_accounting_stop_201(self):