Skip to content

Commit

Permalink
convert indices to int8 before stacking
Browse files Browse the repository at this point in the history
  • Loading branch information
avantikalal committed Jul 16, 2024
1 parent c5f3c14 commit 7b20023
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/grelu/sequence/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,11 @@ def strings_to_indices(
strings
), "All input sequences must have the same length."
return np.stack(
[[BASE_TO_INDEX_HASH[base] for base in string] for string in strings]
).astype(np.int8)
[
np.array([BASE_TO_INDEX_HASH[base] for base in string], dtype(np.int8))
for string in strings
]
)


def indices_to_one_hot(indices: np.ndarray) -> Tensor:
Expand Down

0 comments on commit 7b20023

Please sign in to comment.