diff --git a/kobo/hub/views.py b/kobo/hub/views.py index 8e4cbd1c..f3082aa9 100644 --- a/kobo/hub/views.py +++ b/kobo/hub/views.py @@ -19,7 +19,7 @@ from django.urls import reverse else: from django.core.urlresolvers import reverse -from django.http import HttpResponse, StreamingHttpResponse, HttpResponseForbidden +from django.http import HttpResponse, HttpResponseNotFound, StreamingHttpResponse, HttpResponseForbidden from django.shortcuts import render, get_object_or_404 from django.template import RequestContext from django.views.generic import RedirectView @@ -111,13 +111,8 @@ def get_context_data(self, **kwargs): return context -def _stream_file(file_path, offset=0): +def _stream_file(f, offset=0): """Generator that returns 1M file chunks.""" - try: - f = open(file_path, "rb") - except IOError: - return - f.seek(offset) while 1: data = f.read(1024 ** 2) @@ -149,8 +144,13 @@ def _streamed_log_response(task, log_name, offset, as_attachment): except OSError: content_len = 0 + try: + f = open(file_path, "rb") + except OSError: + return HttpResponseNotFound('Cannot find file ' + log_name) + # use _stream_file() instead of passing file object in order to improve performance - response = StreamingHttpResponse(_stream_file(file_path, offset), content_type=mimetype) + response = StreamingHttpResponse(_stream_file(f, offset), content_type=mimetype) response["Content-Length"] = content_len if as_attachment: diff --git a/tests/test_view_log.py b/tests/test_view_log.py index 5f9ac5cb..9dd514c4 100644 --- a/tests/test_view_log.py +++ b/tests/test_view_log.py @@ -134,6 +134,20 @@ def setUp(self): # for more accurate memory_profiler tests gc.collect() + def test_view_log_404(self): + """Fetching of non-existent logs should yield error 404""" + response = self.get_log('missing.htm') + self.assertEqual(response.content, b'Cannot find file missing.htm') + self.assertEqual(response.status_code, 404) + + response = self.get_log('missing.html') + self.assertEqual(response.content, b'Cannot find file missing.html') + self.assertEqual(response.status_code, 404) + + response = self.get_log('missing.tar.gz', data={'format': 'raw'}) + self.assertEqual(response.content, b'Cannot find file missing.tar.gz') + self.assertEqual(response.status_code, 404) + def test_view_zipped_small_raw(self): """Fetching a small compressed log with raw format should yield the gzip-compressed content."""