diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index b1b29e74..7bc2211b 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -753,6 +753,9 @@ def predict_by_points( else: raise ValueError("Please set the input image first using set_image().") + if isinstance(point_coords_batch, dict): + point_coords_batch = gpd.GeoDataFrame.from_features(point_coords_batch) + if isinstance(point_coords_batch, str) or isinstance( point_coords_batch, gpd.GeoDataFrame ): @@ -774,13 +777,16 @@ def predict_by_points( elif isinstance(point_coords_batch, list): if point_crs is not None: - point_coords_batch = common.coords_to_xy( + point_coords_batch_crs = common.coords_to_xy( self.source, point_coords_batch, point_crs ) + else: + point_coords_batch_crs = point_coords_batch + num_points = len(point_coords_batch) + points = [] - points.append([[point] for point in point_coords_batch]) + points.append([[point] for point in point_coords_batch_crs]) - num_points = len(points) if point_labels_batch is None: labels = np.array([[1] for i in range(num_points)]) elif isinstance(point_labels_batch, list): @@ -790,7 +796,7 @@ def predict_by_points( labels = point_labels_batch points = np.array(points[0]) - labels = np.array(labels[0]) + labels = np.array(labels) elif isinstance(point_coords_batch, np.ndarray): points = point_coords_batch @@ -836,6 +842,7 @@ def predict_by_points( foreground=True, unique=unique, mask_multiplier=mask_multiplier, + dtype=dtype, **kwargs, )