From e35ce72f1363840d7de1892c82ce278e58b23d69 Mon Sep 17 00:00:00 2001
From: Philip de Nier <philipn@rd.bbc.co.uk>
Date: Fri, 16 Feb 2024 15:05:23 +0000
Subject: [PATCH] Check None packet when setting time_base after decode

---
 av/codec/context.pyx |  3 ++-
 tests/test_decode.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/av/codec/context.pyx b/av/codec/context.pyx
index 5a28cd69e..09cf93b0c 100644
--- a/av/codec/context.pyx
+++ b/av/codec/context.pyx
@@ -527,7 +527,8 @@ cdef class CodecContext:
         # is carrying around.
         # TODO: Somehow get this from the stream so we can not pass the
         # packet here (because flushing packets are bogus).
-        frame._time_base = packet._time_base
+        if packet is not None:
+            frame._time_base = packet._time_base
 
         frame.index = self.ptr.frame_number - 1
 
diff --git a/tests/test_decode.py b/tests/test_decode.py
index bc9c96e58..87a84ba12 100644
--- a/tests/test_decode.py
+++ b/tests/test_decode.py
@@ -124,3 +124,35 @@ def test_decode_close_then_use(self):
                     getattr(container, attr)
                 except AssertionError:
                     pass
+
+    def test_flush_decoded_video_frame_count(self):
+        container = av.open(fate_suite("h264/interlaced_crop.mp4"))
+        video_stream = next(s for s in container.streams if s.type == "video")
+
+        self.assertIs(video_stream, container.streams.video[0])
+
+        # Decode the first GOP, which requires a flush to get all frames
+        have_keyframe = False
+        input_count = 0
+        output_count = 0
+
+        for packet in container.demux(video_stream):
+            if packet.is_keyframe:
+                if have_keyframe:
+                    break
+                have_keyframe = True
+
+            input_count += 1
+
+            for frame in video_stream.decode(packet):
+                output_count += 1
+
+        # Check the test works as expected and requires a flush
+        self.assertLess(output_count, input_count)
+
+        for frame in video_stream.decode(None):
+            # The Frame._time_base is not set by PyAV
+            self.assertIsNone(frame.time_base)
+            output_count += 1
+
+        self.assertEqual(output_count, input_count)