Skip to content

Commit

Permalink
Update predict_by_points
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 20, 2024
1 parent 03015d7 commit 103f78c
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,9 @@ def predict_by_points(
else:
raise ValueError("Please set the input image first using set_image().")

if isinstance(point_coords_batch, dict):
point_coords_batch = gpd.GeoDataFrame.from_features(point_coords_batch)

if isinstance(point_coords_batch, str) or isinstance(
point_coords_batch, gpd.GeoDataFrame
):
Expand All @@ -774,13 +777,16 @@ def predict_by_points(

elif isinstance(point_coords_batch, list):
if point_crs is not None:
point_coords_batch = common.coords_to_xy(
point_coords_batch_crs = common.coords_to_xy(
self.source, point_coords_batch, point_crs
)
else:
point_coords_batch_crs = point_coords_batch
num_points = len(point_coords_batch)

points = []
points.append([[point] for point in point_coords_batch])
points.append([[point] for point in point_coords_batch_crs])

num_points = len(points)
if point_labels_batch is None:
labels = np.array([[1] for i in range(num_points)])
elif isinstance(point_labels_batch, list):
Expand All @@ -790,7 +796,7 @@ def predict_by_points(
labels = point_labels_batch

points = np.array(points[0])
labels = np.array(labels[0])
labels = np.array(labels)

elif isinstance(point_coords_batch, np.ndarray):
points = point_coords_batch
Expand Down Expand Up @@ -836,6 +842,7 @@ def predict_by_points(
foreground=True,
unique=unique,
mask_multiplier=mask_multiplier,
dtype=dtype,
**kwargs,
)

Expand Down

0 comments on commit 103f78c

Please sign in to comment.