Skip to content

Commit

Permalink
python 3.11 and numpy compatibility changes
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoNeureiter committed May 10, 2023
1 parent 0420520 commit 4a99811
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
7 changes: 2 additions & 5 deletions sbayes/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,7 @@ def read_dictionary(dataframe, search_key):
param_dict = {}
for column_name in dataframe.columns:
if column_name.startswith(search_key):
param_dict[column_name] = dataframe[column_name].to_numpy(
dtype=np.float
)

param_dict[column_name] = dataframe[column_name].to_numpy(dtype=float)
return param_dict

def parse_weights(self, parameters: pd.DataFrame) -> dict[str, NDArray]:
Expand All @@ -218,7 +215,7 @@ def parse_weights(self, parameters: pd.DataFrame) -> dict[str, NDArray]:
weights = {}
for f in self.feature_names:
weights[f] = np.column_stack(
[parameters[f"w_{c}_{f}"].to_numpy(dtype=np.float) for c in components]
[parameters[f"w_{c}_{f}"].to_numpy(dtype=float) for c in components]
)

return weights
Expand Down
4 changes: 2 additions & 2 deletions sbayes/sampling/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def write_header(self, sample: Sample):
self.match_clusters = False

# Initialize cluster_sum array for matching
self.cluster_sum = np.zeros((sample.n_clusters, sample.n_objects), dtype=np.int)
self.cluster_sum = np.zeros((sample.n_clusters, sample.n_objects), dtype=int)

# Cluster sizes
for i in range(sample.n_clusters):
Expand Down Expand Up @@ -233,7 +233,7 @@ def write_header(self, sample: Sample):
# Nothing to match
self.match_clusters = False

self.cluster_sum = np.zeros((sample.n_clusters, sample.n_objects), dtype=np.int)
self.cluster_sum = np.zeros((sample.n_clusters, sample.n_objects), dtype=int)

def _write_sample(self, sample):
if self.match_clusters:
Expand Down
5 changes: 3 additions & 2 deletions sbayes/sampling/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ def propose_new_sources(
n_features = sample_old.n_features

MODE = "gibbs"

if MODE == "gibbs":
sample_new, log_q, log_q_back = self.gibbs_sample_source(
sample_new, sample_old, object_subset=changed_objects
Expand Down Expand Up @@ -1025,7 +1026,7 @@ def grow_cluster(self, sample: Sample) -> tuple[Sample, float, float]:
return sample, 0, -np.inf

# Choose a random candidate and add it to the cluster
object_add = random.choice(candidates.nonzero()[0])
object_add = np.random.choice(candidates.nonzero()[0])
sample_new.clusters.add_object(z_id, object_add)

# Transition probability when growing
Expand Down Expand Up @@ -1068,7 +1069,7 @@ def shrink_cluster(self, sample: Sample) -> tuple[Sample, float, float]:

# Cluster is big enough: shrink
removal_candidates = self.get_removal_candidates(cluster_current)
object_remove = random.choice(removal_candidates)
object_remove = np.random.choice(removal_candidates)
sample_new.clusters.remove_object(z_id, object_remove)

# Transition probability when shrinking.
Expand Down

0 comments on commit 4a99811

Please sign in to comment.