From d2b7ce7627390c3068e88b2d8cce1156878f35d6 Mon Sep 17 00:00:00 2001 From: Maximilian Linhoff Date: Tue, 11 Jul 2023 11:21:43 +0200 Subject: [PATCH] Make ctaplot optional --- lstchain/mc/plot_utils.py | 5 +++- .../scripts/benchmarks/charge_benchmark.py | 18 +++++++++----- lstchain/scripts/lstchain_mc_sensitivity.py | 14 +++++++---- lstchain/visualization/plot_dl2.py | 24 +++++++++++++++++-- setup.py | 3 +-- 5 files changed, 49 insertions(+), 15 deletions(-) diff --git a/lstchain/mc/plot_utils.py b/lstchain/mc/plot_utils.py index a0a2b88a94..9db0e71d8c 100644 --- a/lstchain/mc/plot_utils.py +++ b/lstchain/mc/plot_utils.py @@ -2,7 +2,6 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from ctaplot.plots import plot_sensitivity_magic_performance from matplotlib.colors import LogNorm from pyirf.spectral import CRAB_MAGIC_JHEAP2015 from astropy.visualization import quantity_support @@ -304,6 +303,10 @@ def sensitivity_plot_comparison(energy, sensitivity, ax=None): ------- fig_sens: `matplotlib.pyplot.figure` Figure containing sensitivity plot """ + try: + from ctaplot.plots import plot_sensitivity_magic_performance + except ModuleNotFoundError: + raise ModuleNotFoundError("Please install ctaplot: pip install ctaplot") # Final sensitivity plot ax = plt.gca() if ax is None else ax diff --git a/lstchain/scripts/benchmarks/charge_benchmark.py b/lstchain/scripts/benchmarks/charge_benchmark.py index 9b2c39c0d5..d1c386ebb5 100644 --- a/lstchain/scripts/benchmarks/charge_benchmark.py +++ b/lstchain/scripts/benchmarks/charge_benchmark.py @@ -4,15 +4,9 @@ import sys from pathlib import Path -import ctaplot import matplotlib.pyplot as plt import tables from astropy.table import Table -from ctaplot.plots.calib import ( - plot_charge_resolution, - plot_photoelectron_true_reco, - plot_pixels_pe_spectrum, -) from matplotlib.backends.backend_pdf import PdfPages from lstchain.io.config import ( @@ -49,6 +43,18 @@ def main(): + try: + import ctaplot + except ModuleNotFoundError: + print("ctaplot is needed for this script, please install using `pip install ctaplot`", file=sys.stderr) + sys.exit(1) + + from ctaplot.plots.calib import ( + plot_charge_resolution, + plot_photoelectron_true_reco, + plot_pixels_pe_spectrum, + ) + ctaplot.set_style() output_dir = args.output_dir.absolute() diff --git a/lstchain/scripts/lstchain_mc_sensitivity.py b/lstchain/scripts/lstchain_mc_sensitivity.py index 60c21e5677..38fc0d3785 100644 --- a/lstchain/scripts/lstchain_mc_sensitivity.py +++ b/lstchain/scripts/lstchain_mc_sensitivity.py @@ -13,13 +13,12 @@ --o /output/path """ - +import sys import argparse import os import warnings import astropy.units as u -import ctaplot import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -37,8 +36,6 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) -ctaplot.set_style() - parser = argparse.ArgumentParser(description="Compute MC sensitivity curve.") parser.add_argument('--input-file-gamma-dl2', '--gdl2', type=str, @@ -57,6 +54,15 @@ def main(): args = parser.parse_args() + + try: + import ctaplot + except ModuleNotFoundError: + print("ctaplot is needed for this script, please install using `pip install ctaplot`", file=sys.stderr) + sys.exit(1) + + ctaplot.set_style() + ntelescopes_gamma = 1 n_bins_energy = 20 # Number of energy bins diff --git a/lstchain/visualization/plot_dl2.py b/lstchain/visualization/plot_dl2.py index 7fe0e547f2..54fa7aa1e4 100644 --- a/lstchain/visualization/plot_dl2.py +++ b/lstchain/visualization/plot_dl2.py @@ -6,7 +6,6 @@ import os import astropy.units as u -import ctaplot import joblib import matplotlib import matplotlib.pyplot as plt @@ -176,8 +175,12 @@ def energy_results(dl2_data, points_outfile=None, plot_outfile=None): ------- fig, axes: `matplotlib.pyplot.figure`, `matplotlib.pyplot.axes` """ - fig, axes = plt.subplots(2, 2, figsize=(12, 8)) + try: + import ctaplot + except ModuleNotFoundError: + raise ModuleNotFoundError("This function needs ctaplot. Please install ctaplot: pip install ctaplot") + fig, axes = plt.subplots(2, 2, figsize=(12, 8)) ctaplot.plot_energy_resolution(dl2_data.mc_energy.values * u.TeV, dl2_data.reco_energy.values * u.TeV, ax=axes[0, 0], bias_correction=False) @@ -562,6 +565,11 @@ def plot_roc_gamma(dl2_data, energy_bins=None, ax=None, **kwargs): ------- ax: `matplotlib.pyplot.axis` """ + try: + import ctaplot + except ModuleNotFoundError: + raise ModuleNotFoundError("This function needs ctaplot. Please install ctaplot: pip install ctaplot") + if energy_bins is None: ax = ctaplot.plot_roc_curve_gammaness(dl2_data.mc_type, dl2_data.gammaness, ax=ax, @@ -595,6 +603,10 @@ def plot_energy_resolution(dl2_data, ax=None, bias_correction=False, cta_req_nor ------- ax: `matplotlib.pyplot.axes` """ + try: + import ctaplot + except ModuleNotFoundError: + raise ModuleNotFoundError("This function needs ctaplot. Please install ctaplot: pip install ctaplot") ax = ctaplot.plot_energy_resolution(dl2_data.mc_energy.values * u.TeV, dl2_data.reco_energy.values * u.TeV, @@ -628,6 +640,10 @@ def plot_angular_resolution(dl2_data, ax=None, bias_correction=False, cta_req_no ------- ax: `matplotlib.pyplot.axes` """ + try: + import ctaplot + except ModuleNotFoundError: + raise ModuleNotFoundError("This function needs ctaplot. Please install ctaplot: pip install ctaplot") ax = ctaplot.plot_angular_resolution_per_energy(dl2_data.mc_alt.values * u.rad, dl2_data.reco_alt.values * u.rad, @@ -660,6 +676,10 @@ def direction_results(dl2_data, points_outfile=None, plot_outfile=None): ------- fig, axes: `matplotlib.pyplot.figure`, `matplotlib.pyplot.axes` """ + try: + import ctaplot + except ModuleNotFoundError: + raise ModuleNotFoundError("This function needs ctaplot. Please install ctaplot: pip install ctaplot") fig, axes = plt.subplots(2, 2, figsize=(15, 12)) diff --git a/setup.py b/setup.py index f890ce3b2e..2a807053f6 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,6 @@ def find_scripts(script_dir, prefix): 'bokeh~=2.0', 'ctapipe~=0.19.2', 'ctapipe_io_lst~=0.22.0', - 'ctaplot~=0.6.2', 'eventio>=1.9.1,<2.0.0a0', # at least 1.1.1, but not 2 'gammapy~=1.1', 'h5py', @@ -66,7 +65,7 @@ def find_scripts(script_dir, prefix): 'jinja2~=3.0.2', # pinned for bokeh 1.0 compatibility ], extras_require={ - "all": tests_require + docs_require, + "all": tests_require + docs_require + ["ctaplot~=0.6.2"], "tests": tests_require, "docs": docs_require, },