Skip to content

Commit

Permalink
Add helper functions save_npz and load_npz
Browse files Browse the repository at this point in the history
  • Loading branch information
sindre0830 committed Oct 16, 2023
1 parent 712b42f commit e4ecfac
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import matplotlib.ticker as ticker
import tqdm
import collections
import scipy.sparse


def load_config(filepath: str) -> argparse.Namespace:
Expand Down Expand Up @@ -188,3 +189,14 @@ def plot_frequency_distribution(corpus, data_directory: str):
# save plot
save_plot(filepath)
plt.close()


def save_npz(filepath: str, x):
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, "wb") as file:
scipy.sparse.save_npz(file, x)


def load_npz(filepath: str):
with open(filepath, "rb") as file:
return scipy.sparse.load_npz(file)

0 comments on commit e4ecfac

Please sign in to comment.