Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix fad calculation for newer versions of scipy #25

Merged
merged 1 commit into from
Jan 25, 2024

Conversation

zhvng
Copy link
Contributor

@zhvng zhvng commented Jan 24, 2024

fixes fad calculation and unpins numpy and scipy dependencies

#19 pinned the numpy and scipy dependencies to fix an imaginary component error. The reason for this error is a change in scipy's linalg.sqrtm function. We can avoid this altogether by first converting the input to complex before calling sqrtm

tested these changes with the following script:

np.random.seed(42)

SAMPLE_RATE = 16000

frechet = FrechetAudioDistance(
    ckpt_dir="/home/ubuntu/cache/metrics/checkpoints",
    model_name="vggish",
    # submodel_name="630k-audioset", # for CLAP only
    sample_rate=SAMPLE_RATE,
    use_pca=False, # for VGGish only
    use_activation=False, # for VGGish only
    verbose=True,
    audio_load_worker=8,
    # enable_fusion=False, # for CLAP only
)

for traget, count, param in [("background", 10, None), ("test1", 5, 0.0001), ("test2", 5, 0.00001)]:
    os.makedirs(traget, exist_ok=True)
    frequencies = np.linspace(100, 1000, count).tolist()
    for freq in frequencies:
        samples = gen_sine_wave(freq, param=param)
        filename = os.path.join(traget, "sin_%.0f.wav" % freq)
        print("Creating: %s with %i samples." % (filename, samples.shape[0]))
        sf.write(filename, samples, SAMPLE_RATE, "PCM_24")

fad_score = frechet.score("background", "test1")
print("FAD score test 1: %.8f" % fad_score)

fad_score = frechet.score("background", "test2")
print("FAD score test 2: %.8f" % fad_score)

shutil.rmtree("background")
shutil.rmtree("test1")
shutil.rmtree("test2")

Results:

  • Before this change, with numpy==1.23.5, scipy==1.11.1
FAD score test 1: -1.00000000
FAD score test 2: -1.00000000
  • Before this change, numpy==1.23.4, scipy==1.10.1:
FAD score test 1: 11.65261423
FAD score test 2: 5.48465044
  • After the change, numpy==1.23.5, scipy==1.11.3:
FAD score test 1: 11.65261424
FAD score test 2: 5.48465045

@gudgud96
Copy link
Owner

@zhvng awesome, thanks for your contribution!

@gudgud96 gudgud96 merged commit 99ab594 into gudgud96:main Jan 25, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants