Skip to content

Commit

Permalink
Finish the parser prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Nov 15, 2024
1 parent 9f7ba5f commit 20e1a15
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/covvfit/_quasimultinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,11 +438,16 @@ class _ProblemData(NamedTuple):
)


def _is_scalar(value) -> bool:
return not hasattr(value, "__len__")


def _create_padded_array(
values: _OverDispersionType,
expected_lengths: list[int],
padding_length: int,
padding_value: float,
_out_dtype=float,
) -> Float[Array, "n_cities padding_length"]:
n_cities = len(expected_lengths)
if n_cities < 1:
Expand All @@ -452,24 +457,34 @@ def _create_padded_array(
f"Maximum length is {max(expected_lengths)}, which is greater than the padding {padding_length}."
)

out_array = jnp.full(shape=(n_cities, padding_length), fill_value=padding_value)
out_array = jnp.full(
shape=(n_cities, padding_length), fill_value=padding_value, dtype=_out_dtype
)

# First case: `values` argument is a single number (not an iterable)
if not hasattr(values, "__len__"):
if _is_scalar(values):
for i, length in enumerate(expected_lengths):
out_array = out_array.at[i, :length].set(values)
return out_array

# Second case: `values` argument is an iterable
# Second case: `values` argument is not a scalar, but rather an iterable:
if len(values) != n_cities:
raise ValueError(
f"Provided list has length {len(values)} rather than {n_cities}."
)

for i, (value, exp_len) in enumerate(zip(values, expected_lengths)):
pass

raise NotImplementedError
if _is_scalar(value): # For this city we have constant value provided
out_array = out_array.at[i, :exp_len].set(value)
else: # We have a vector of values provided
if len(value) != exp_len:
raise ValueError(
f"For {i}th component the provided array has length {len(value)} rather than {exp_len}."
)
vals = jnp.asarray(value, dtype=out_array.dtype)
out_array = out_array.at[i, :exp_len].set(vals)

return out_array


def _validate_and_pad(
Expand Down

0 comments on commit 20e1a15

Please sign in to comment.