Skip to content

Commit

Permalink
Use rdata package for reading and writing files
Browse files (browse the repository at this point in the history)
  • Loading branch information
trossi committed May 30, 2024
1 parent 3dfb931 commit 024588a
Showing 1 changed file with 36 additions and 44 deletions.
80 changes: 36 additions & 44 deletions hmsc/utils/export_rds_utils.py
Original file line number | Diff line number | Diff line change
@@ -1,50 +1,42 @@
import numpy as np
import ujson as json
import pandas as pd
import pyreadr
import os
import rdata
import tensorflow as tf
import xarray as xr


# NOTE(review): this span comes from a rendered commit diff whose indentation and +/- markers
# were lost. It interleaves two definitions: the header and body of a pyreadr/JSON-based
# load_model_from_rds (presumably the removed side of the diff -- confirm against the
# repository) and the recursive helper convert_to_numpy.
def load_model_from_rds(rds_file_path):
# convert_to_numpy(obj): recursively replace xarray / TensorFlow values with the result of
# their NumPy conversion methods.
#   - xr.DataArray -> obj.to_numpy()
#   - tf.Tensor    -> obj.numpy()
#   - dict         -> new dict with the same keys, values converted recursively
#   - list         -> new list, elements converted recursively
#   - anything else is returned unchanged
def convert_to_numpy(obj):
if isinstance(obj, xr.DataArray):
return obj.to_numpy()

# (load_model_from_rds body) read the file with pyreadr and JSON-decode the string stored at
# [None][None][0] of the read result.
long_str = pyreadr.read_r(rds_file_path)
hmsc_obj = json.loads(long_str[None][None][0])

# (load_model_from_rds body) return the decoded object together with its "hM" entry.
return hmsc_obj, hmsc_obj.get("hM")
if isinstance(obj, tf.Tensor):
return obj.numpy()

# Containers are rebuilt rather than modified in place.
if isinstance(obj, dict):
new = {}
for key, value in obj.items():
new[key] = convert_to_numpy(value)
return new

if isinstance(obj, list):
new = []
for value in obj:
new.append(convert_to_numpy(value))
return new

# Fallthrough: no conversion applies to this type.
return obj

def save_chains_postList_to_rds(postList, postList_file_path, nChains, elapsedTime=-1, flag_save_eta=True):
    """Serialise posterior samples to a JSON string and store it in an RDS file.

    Every sample ``postList[chain][i]`` is turned into a dict of nested Python
    lists (via ``.numpy().tolist()``), collected under ``json_data[chain][i]``,
    dumped to one JSON string and written with pyreadr as a 1x1 data frame.

    Args:
        postList: indexable collection of chains; each chain is an indexable
            collection of parameter mappings.
        postList_file_path: destination path handed to ``pyreadr.write_rds``.
        nChains: number of chains read from ``postList``.
        elapsedTime: value stored under the top-level ``"time"`` key.
        flag_save_eta: when false, ``"Eta"`` is stored as ``None``.
    """

    def to_list(tensor):
        # Tensor-like value -> nested Python lists.
        return tensor.numpy().tolist()

    def keyed_by_level(values, n_levels):
        # Pair each entry with its position, using np.arange for the keys.
        return dict(zip(np.arange(n_levels), values))

    json_data = {chain: {} for chain in range(nChains)}
    json_data["time"] = elapsedTime

    for chain in range(nChains):
        chain_samples = postList[chain]
        for sample_idx in range(len(chain_samples)):
            params = chain_samples[sample_idx]

            sample_data = {
                "Beta": to_list(params["Beta"]),
                "BetaSel": [to_list(par) for par in params["BetaSel"]],
                "Gamma": to_list(params["Gamma"]),
                "iV": to_list(params["iV"]),
                # "rhoInd" is stored shifted by one.
                "rhoInd": to_list(params["rhoInd"] + 1),
                "sigma": to_list(params["sigma"]),
            }

            n_levels = len(params["AlphaInd"])
            sample_data["Lambda"] = keyed_by_level([to_list(par) for par in params["Lambda"]], n_levels)
            sample_data["Psi"] = keyed_by_level([to_list(par) for par in params["Psi"]], n_levels)
            sample_data["Delta"] = keyed_by_level([to_list(par) for par in params["Delta"]], n_levels)
            if flag_save_eta:
                sample_data["Eta"] = keyed_by_level([to_list(par) for par in params["Eta"]], n_levels)
            else:
                sample_data["Eta"] = None
            # "Alpha" is built from "AlphaInd", shifted by one.
            sample_data["Alpha"] = keyed_by_level([to_list(par + 1) for par in params["AlphaInd"]], n_levels)

            if params["wRRR"] is None:
                sample_data["wRRR"] = None
                sample_data["PsiRRR"] = None
                sample_data["DeltaRRR"] = None
            else:
                sample_data["wRRR"] = to_list(params["wRRR"])
                sample_data["PsiRRR"] = to_list(params["PsiRRR"])
                sample_data["DeltaRRR"] = to_list(params["DeltaRRR"])

            json_data[chain][sample_idx] = sample_data

    json_str = json.dumps(json_data)
    frame = pd.DataFrame([[json_str]])
    pyreadr.write_rds(postList_file_path, frame, compress="gzip")
def load_model_from_rds(rds_file_path):
    """Read an RDS file with ``rdata`` and return its converted contents.

    The object returned by ``rdata.read_rds`` is passed through
    ``convert_to_numpy``.

    Returns:
        A 2-tuple ``(converted_object, converted_object["hM"])``.
    """
    converted = convert_to_numpy(rdata.read_rds(rds_file_path))
    model = converted["hM"]
    return converted, model


def save_chains_postList_to_rds(postList, postList_file_path, nChains, elapsedTime=-1, flag_save_eta=True):
    """Write posterior samples and the elapsed time to an RDS file via ``rdata``.

    ``postList`` is passed through ``convert_to_numpy`` and stored under
    ``"list"``; ``elapsedTime`` is stored under ``"time"``.

    Args:
        postList: nested collection of samples, indexed as ``[chain][sample]``.
        postList_file_path: destination path handed to ``rdata.write_rds``.
        nChains: accepted for interface compatibility; not read here.
        elapsedTime: value stored under ``"time"``.
        flag_save_eta: when false, each sample's ``"Eta"`` entry is set to
            ``None`` in the structure returned by ``convert_to_numpy``.
    """
    chains = convert_to_numpy(postList)

    if not flag_save_eta:
        # Index-based access is kept so any indexable container behaves as before.
        for chain_idx in range(len(chains)):
            samples = chains[chain_idx]
            for sample_idx in range(len(samples)):
                samples[sample_idx]["Eta"] = None

    rdata.write_rds(postList_file_path, {"list": chains, "time": elapsedTime})

0 comments on commit 024588a

Please sign in to comment.