Skip to content

Commit

Permalink
Make get_tsyganenko_params(list) -> list
Browse files Browse the repository at this point in the history
ddasilva committed Nov 23, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent b7bedb3 commit a10ec99
Showing 4 changed files with 65 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_test.yml
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ version = "0.1.0"
authors = [{name = "Daniel da Silva", email = "[email protected]"}]
license = {file = "LICENSE"}
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.9"
requires-python = ">=3.8"
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
60 changes: 41 additions & 19 deletions rbinvariantslib/models.py
Original file line number Diff line number Diff line change
@@ -5,11 +5,9 @@
In this module, all grids returned are in units of Re and all magnetic
fields are in units of Gauss.
"""
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
import os
from typing import cast, Dict, List, Tuple, Union

from astropy import constants, units
from cdasws import CdasWs
@@ -159,7 +157,6 @@ def interpolate(self, point, radius=0.1):

interp_result = pv.PolyData(interp.GetOutput())
B = tuple(np.array(interp_result["B"])[0])
B = cast(Tuple[float, float, float], B)

return B

@@ -416,11 +413,13 @@ def get_tsyganenko(
"""Helper function to get one of the tsyganenko fields on an LFM grid.
Parameters
-----------
----------
model_name : {'T96', 'TS05'}
Name of the magnetic field model to use.
params : dictionary of string to array
Parameters to support Tsyganenko magnetic field mode
Input parameters to the Tsyganenko model. Keys are at minimum:
"Pdyn", "SymH", "By" and "Bz". For the TS05 model, optional keys
are "W1" through "W6".
time : datetime, no timezone
Time to support the Tsyganenko magnetic field model
x_re_sm_grid : array of shape (m, n, p)
@@ -515,7 +514,7 @@ def get_tsyganenko_on_lfm_grid(
"""Helper function to get one of the tsyganenko fields on an LFM grid.
Parameters
-----------
----------
model_name : {'T96', 'TS05'}
Name of the magnetic field model to use.
params : dictionary of string to array
@@ -564,8 +563,9 @@ def get_tsyganenko_params(times):
Returns
-------
params : dict, str to array
dictionary mapping variable to array of parameters
params : dist or list of dicts
If list of times is passed, returns list of dicts. otherwise, just
returns dict. Each dicts mapping variable to float paramters.
"""
# Massage time argument into list if it is just a single datetime
times_list = []
@@ -597,21 +597,43 @@ def get_tsyganenko_params(times):
cdas_data = {v: np.array(data_result[v]) for v in var_names + ['Epoch']}

# Interpolate Tsyganenko parameters from CDAS data
params_dict = {}
fill_value_max = 99.0

for our_col, cdas_col in col_map.items():
mask = (cdas_data[cdas_col] < 99.0) # skip fill values

if len(times_list) == 1:
if len(times_list) == 1:
params_dict = {}

for our_col, cdas_col in col_map.items():
mask = (cdas_data[cdas_col] < fill_value_max) # skip fill values
(params_dict[our_col],) = np.interp(
date2num(times_list), date2num(cdas_data['Epoch'])[mask], cdas_data[cdas_col][mask]
date2num(times_list),
date2num(cdas_data['Epoch'])[mask],
cdas_data[cdas_col][mask]
)
else:
params_dict[our_col] = np.interp(
date2num(times_list), date2num(cdas_data['Epoch'])[mask], cdas_data[cdas_col][mask]

return_value = params_dict
else:
# Interpolate into dict of arrays
dict_of_arrays = {}
for our_col, cdas_col in col_map.items():
mask = (cdas_data[cdas_col] < fill_value_max) # skip fill values
dict_of_arrays[our_col] = np.interp(
date2num(times_list),
date2num(cdas_data['Epoch'])[mask],
cdas_data[cdas_col][mask]
)

# Convert to list of dicts
return_value = []

for i in range(len(times_list)):
cur_dict = {}

for our_col in col_map.keys():
cur_dict[our_col] = float(dict_of_arrays[our_col][i])

return_value.append(cur_dict)

return params_dict
return return_value


def get_swmf_cdf_model(path, xaxis=None, yaxis=None, zaxis=None):
@@ -748,7 +770,7 @@ def get_model(model_type, path, **kwargs):
Path to file on disk
Returns
--------
-------
model : :py:class:`~MagneticFieldModel`
Grid and Magnetic field values on that grid.
"""
33 changes: 22 additions & 11 deletions rbinvariantslib/tests/test_tsyganenko_cdaweb.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,8 @@ def test_single_time():
"""Test calling get_tsyganenko_params() with a single time"""
time = datetime(2013, 10, 2, 13, 5)
params = models.get_tsyganenko_params(time)
expected = {'Pdyn': 5.460000038146973, 'SymH': -56.0, 'By': -1.0700000524520874, 'Bz': 5.579999923706055}
expected = {'Pdyn': 5.460000038146973, 'SymH': -56.0,
'By': -1.0700000524520874, 'Bz': 5.579999923706055}

for key, value in params.items():
assert isinstance(value, float)
@@ -24,14 +25,24 @@ def test_multiple_times():
datetime(2013, 10, 2, 13, 10),
datetime(2013, 10, 2, 13, 15),
]
params = models.get_tsyganenko_params(times)
expected = {
'Pdyn': np.array([5.46000004, 4.88999987, 4.80000019]),
'SymH': np.array([-56., -56., -56.]),
'By': np.array([-1.07000005, 2.43000007, -1.26999998]),
'Bz': np.array([5.57999992, 3.72000003, 4.15999985])
}

for key, value in params.items():
assert_allclose(value, expected[key], rtol=.01, atol=0.1)
params_dict_list = models.get_tsyganenko_params(times)

expected = [
{'Pdyn': 5.460000038146973, 'SymH': -56.0,
'By': -1.0700000524520874, 'Bz': 5.579999923706055},
{'Pdyn': 4.889999866485596, 'SymH': -56.0,
'By': 2.430000066757202, 'Bz': 3.7200000286102295},
{'Pdyn': 4.800000190734863, 'SymH': -56.0,
'By': -1.2699999809265137, 'Bz': 4.159999847412109}
]

assert len(params_dict_list) == len(expected), \
'Params list not expected length'

for got_dict, expected_dict in zip(params_dict_list, expected):
assert len(got_dict) == len(expected_dict), \
'Dict has unexpected number of keys'

for key, got_value in got_dict.items():
assert isinstance(got_value, float)
assert abs(got_value - expected_dict[key]) < .1

0 comments on commit a10ec99

Please sign in to comment.