Skip to content

Commit

Permalink
Set time_base for AudioResampler
Browse files Browse the repository at this point in the history
  • Loading branch information
daveisfera authored Dec 2, 2023
1 parent 8aa5fe7 commit 298b3b4
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 4 deletions.
6 changes: 5 additions & 1 deletion av/audio/resampler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,14 @@ cdef class AudioResampler:
# handle resampling with aformat filter
# (similar to configure_output_audio_filter from ffmpeg)
self.graph = av.filter.Graph()
extra_args = {}
if frame.time_base is not None:
extra_args["time_base"] = str(frame.time_base)
abuffer = self.graph.add("abuffer",
sample_rate=str(frame.sample_rate),
sample_fmt=AudioFormat(frame.format).name,
channel_layout=frame.layout.name)
channel_layout=frame.layout.name,
**extra_args)
aformat = self.graph.add("aformat",
sample_rates=str(self.rate),
sample_fmts=self.format.name,
Expand Down
98 changes: 95 additions & 3 deletions tests/test_audioresampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_pts_assertion_same_rate(self):
oframes = resampler.resample(None)
self.assertEqual(len(oframes), 0)

def test_pts_assertion_new_rate(self):
def test_pts_assertion_new_rate_up(self):
resampler = AudioResampler("s16", "mono", 44100)

# resample one frame
Expand All @@ -131,15 +131,107 @@ def test_pts_assertion_new_rate(self):
self.assertEqual(oframe.sample_rate, 44100)
self.assertEqual(oframe.samples, 925)

iframe = AudioFrame("s16", "stereo", 1024)
iframe.sample_rate = 48000
iframe.time_base = "1/48000"
iframe.pts = 1024

oframes = resampler.resample(iframe)
self.assertEqual(len(oframes), 1)

oframe = oframes[0]
self.assertEqual(oframe.pts, 925)
self.assertEqual(oframe.time_base, Fraction(1, 44100))
self.assertEqual(oframe.sample_rate, 44100)
self.assertEqual(oframe.samples, 941)

# flush
oframes = resampler.resample(None)
self.assertEqual(len(oframes), 1)

oframe = oframes[0]
self.assertEqual(oframe.pts, 925)
self.assertEqual(oframe.pts, 941 + 925)
self.assertEqual(oframe.time_base, Fraction(1, 44100))
self.assertEqual(oframe.sample_rate, 44100)
self.assertEqual(oframe.samples, 16)
self.assertEqual(oframe.samples, 15)

def test_pts_assertion_new_rate_down(self):
resampler = AudioResampler("s16", "mono", 48000)

# resample one frame
iframe = AudioFrame("s16", "stereo", 1024)
iframe.sample_rate = 44100
iframe.time_base = "1/44100"
iframe.pts = 0

oframes = resampler.resample(iframe)
self.assertEqual(len(oframes), 1)

oframe = oframes[0]
self.assertEqual(oframe.pts, 0)
self.assertEqual(oframe.time_base, Fraction(1, 48000))
self.assertEqual(oframe.sample_rate, 48000)
self.assertEqual(oframe.samples, 1098)

iframe = AudioFrame("s16", "stereo", 1024)
iframe.sample_rate = 44100
iframe.time_base = "1/44100"
iframe.pts = 1024

oframes = resampler.resample(iframe)
self.assertEqual(len(oframes), 1)

oframe = oframes[0]
self.assertEqual(oframe.pts, 1098)
self.assertEqual(oframe.time_base, Fraction(1, 48000))
self.assertEqual(oframe.sample_rate, 48000)
self.assertEqual(oframe.samples, 1114)

# flush
oframes = resampler.resample(None)
self.assertEqual(len(oframes), 1)

oframe = oframes[0]
self.assertEqual(oframe.pts, 1114 + 1098)
self.assertEqual(oframe.time_base, Fraction(1, 48000))
self.assertEqual(oframe.sample_rate, 48000)
self.assertEqual(oframe.samples, 18)

def test_pts_assertion_new_rate_fltp(self):
resampler = AudioResampler("fltp", "mono", 8000, 1024)

# resample one frame
iframe = AudioFrame("s16", "mono", 1024)
iframe.sample_rate = 8000
iframe.time_base = "1/1000"
iframe.pts = 0

oframes = resampler.resample(iframe)
self.assertEqual(len(oframes), 1)

oframe = oframes[0]
self.assertEqual(oframe.pts, 0)
self.assertEqual(oframe.time_base, Fraction(1, 8000))
self.assertEqual(oframe.sample_rate, 8000)
self.assertEqual(oframe.samples, 1024)

iframe = AudioFrame("s16", "mono", 1024)
iframe.sample_rate = 8000
iframe.time_base = "1/1000"
iframe.pts = 8192

oframes = resampler.resample(iframe)
self.assertEqual(len(oframes), 1)

oframe = oframes[0]
self.assertEqual(oframe.pts, 65536)
self.assertEqual(oframe.time_base, Fraction(1, 8000))
self.assertEqual(oframe.sample_rate, 8000)
self.assertEqual(oframe.samples, 1024)

# flush
oframes = resampler.resample(None)
self.assertEqual(len(oframes), 0)

def test_pts_missing_time_base(self):
resampler = AudioResampler("s16", "mono", 44100)
Expand Down
108 changes: 108 additions & 0 deletions tests/test_codec_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ def test_encoding_aac(self):
def test_encoding_mp2(self):
self.audio_encoding("mp2")

maxDiff = None

def audio_encoding(self, codec_name):
try:
codec = Codec(codec_name, "w")
Expand Down Expand Up @@ -350,17 +352,123 @@ def audio_encoding(self, codec_name):
samples = 0
packet_sizes = []

pts_expected = [
0,
1098,
2212,
3327,
4441,
5556,
6670,
7785,
8900,
10014,
11129,
12243,
13358,
14472,
15587,
16701,
17816,
18931,
20045,
21160,
22274,
23389,
24503,
25618,
26732,
27847,
28962,
30076,
31191,
32305,
33420,
34534,
35649,
36763,
37878,
38993,
40107,
41222,
42336,
43451,
44565,
45680,
46795,
47909,
49024,
50138,
51253,
52367,
53482,
54596,
55711,
56826,
57940,
59055,
60169,
61284,
62398,
63513,
64627,
65742,
66857,
67971,
69086,
70200,
71315,
72429,
73544,
74658,
75773,
76888,
78002,
79117,
80231,
81346,
82460,
83575,
84689,
85804,
86919,
88033,
89148,
90262,
91377,
92491,
93606,
94720,
95835,
96950,
98064,
99179,
100293,
101408,
]
if codec_name == "aac":
pts_expected_encoded = list((-1024 + n * 1024 for n in range(101)))
elif codec_name == "mp2":
pts_expected_encoded = list((-481 + n * 1152 for n in range(89)))
else:
pts_expected_encoded = pts_expected.copy()
with open(path, "wb") as f:
for frame in iter_frames(container, audio_stream):
resampled_frames = resampler.resample(frame)
for resampled_frame in resampled_frames:
self.assertEqual(resampled_frame.pts, pts_expected.pop(0))
self.assertEqual(resampled_frame.time_base, Fraction(1, 48000))
samples += resampled_frame.samples

for packet in ctx.encode(resampled_frame):
self.assertEqual(packet.pts, pts_expected_encoded.pop(0))
self.assertEqual(packet.time_base, Fraction(1, 48000))
packet_sizes.append(packet.size)
f.write(packet)

for packet in ctx.encode(None):
self.assertEqual(packet.pts, pts_expected_encoded.pop(0))
self.assertEqual(packet.time_base, Fraction(1, 48000))
packet_sizes.append(packet.size)
f.write(packet)

Expand Down
31 changes: 31 additions & 0 deletions tests/test_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,37 @@ def test_stream_index(self):
self.assertIs(apacket.stream, astream)
self.assertEqual(apacket.stream_index, 1)

def test_stream_audio_resample(self):
with av.open(self.sandboxed("output.mov"), "w") as output:
vstream = output.add_stream("mpeg4", 24)
vstream.pix_fmt = "yuv420p"
vstream.width = 320
vstream.height = 240

astream = output.add_stream("aac", sample_rate=8000, layout="mono")
frame_size = 512

pts_expected = [-1024, 0, 512, 1024, 1536, 2048, 2560]
pts = 0
for i in range(15):
aframe = AudioFrame("s16", "mono", samples=frame_size)
aframe.sample_rate = 8000
aframe.time_base = Fraction(1, 1000)
aframe.pts = pts
aframe.dts = pts
pts += 32
apackets = astream.encode(aframe)
if apackets:
apacket = apackets[0]
self.assertEqual(apacket.pts, pts_expected.pop(0))
self.assertEqual(apacket.time_base, Fraction(1, 8000))

apackets = astream.encode(None)
if apackets:
apacket = apackets[0]
self.assertEqual(apacket.pts, pts_expected.pop(0))
self.assertEqual(apacket.time_base, Fraction(1, 8000))

def test_set_id_and_time_base(self):
with av.open(self.sandboxed("output.mov"), "w") as output:
stream = output.add_stream("mp2")
Expand Down

0 comments on commit 298b3b4

Please sign in to comment.