diff --git a/Cargo.lock b/Cargo.lock index fee7506..d51b288 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -677,7 +677,7 @@ dependencies = [ [[package]] name = "polars-h3" -version = "0.4.3" +version = "0.4.4" dependencies = [ "h3o", "polars", diff --git a/Cargo.toml b/Cargo.toml index 18c10b8..6bed3d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "polars-h3" -version = "0.4.3" +version = "0.4.4" edition = "2021" [lib] @@ -16,6 +16,7 @@ h3o = "0.4.0" rayon = "1.10.0" [profile.release] +panic = "unwind" lto = true diff --git a/benchmarks/engine.py b/benchmarks/engine.py index 3d776b2..eaac250 100644 --- a/benchmarks/engine.py +++ b/benchmarks/engine.py @@ -7,10 +7,11 @@ - Attempted to also benchmark H3-Pandas, but project appears to be abandoned and doesn't work with h3 >= 4.0.0. """ -import json import random +import statistics import time -from dataclasses import asdict, dataclass, field +from collections import defaultdict +from dataclasses import dataclass, field from typing import Literal import duckdb @@ -33,8 +34,10 @@ class BenchmarkResult: library: Library name: str - seconds: float + avg_seconds: float num_rows: int + num_iterations: int + std_seconds: float = 0.0 @property def num_rows_human(self) -> str: @@ -46,9 +49,10 @@ def num_rows_human(self) -> str: return f"{self.num_rows / 1_000_000:,.0f}M" def __repr__(self) -> str: - return ( - f"{self.library}::{self.name}::{self.num_rows_human} = {self.seconds:.2f}s" - ) + if self.num_iterations == 0: + return f"{self.library}::{self.name}::{self.num_rows_human} = {self.avg_seconds:.2f}s" + else: + return f"{self.library}::{self.name}::{self.num_rows_human} = {self.avg_seconds:.2f}s ± {self.std_seconds:.2f}s" @dataclass @@ -276,10 +280,11 @@ def run_all( for library in libraries: func = config["funcs"][library] - start = time.perf_counter() + perf_times = [] for _ in range(self.config.num_iterations): + start = time.perf_counter() result_df = func(df.head(num_rows)) - end = time.perf_counter() + perf_times.append(time.perf_counter() - start) if self.config.verbose: print(f"Library: {library}") @@ -290,8 +295,12 @@ def run_all( BenchmarkResult( name=func_name, library=library, # type: ignore - seconds=(end - start), + avg_seconds=statistics.mean(perf_times), + std_seconds=statistics.stdev(perf_times) + if len(perf_times) > 1 + else 0, num_rows=num_rows, + num_iterations=self.config.num_iterations, ) ) print("done...") @@ -644,19 +653,45 @@ def _get_grid_paths_py_h3(self, df: pl.DataFrame) -> pl.DataFrame: ) +def _pretty_print_avg_results(results: list[BenchmarkResult]): + by_name = defaultdict(list) + + for d in results: + by_name[d.name].append(d) + + multiples = [] + for speeds in by_name.values(): + fastest = min(v.avg_seconds for v in speeds) + for v in speeds: + multiples.append((v.library, v.avg_seconds / fastest)) + + by_lib = defaultdict(list) + for lib, mult in multiples: + by_lib[lib].append(mult) + + median_by_lib = {lib: round(statistics.median(ms), 2) for lib, ms in by_lib.items()} + avg_by_lib = {lib: round(sum(ms) / len(ms), 2) for lib, ms in by_lib.items()} + + print("\n\n======= Benchmark Final Results =======\n") + print(f"{'Library':<10} {'Median':<8} {'Average':<8}") + print("-" * 26) + for lib in median_by_lib: + print(f"{lib:<10} {median_by_lib[lib]:<8} {avg_by_lib[lib]:<8}") + + if __name__ == "__main__": fast_factor = 1 param_config = ParamConfig( resolution=9, grid_ring_distance=3, - num_iterations=1, + num_iterations=3, libraries="all", difficulty_to_num_rows={ "basic": 10_000_000 // fast_factor, "medium": 10_000_000 // fast_factor, "complex": 100_000 // fast_factor, }, - functions=["latlng_to_cell"], + # functions=["latlng_to_cell"], # verbose=True, ) benchmark = Benchmark(config=param_config) @@ -664,10 +699,15 @@ def _get_grid_paths_py_h3(self, df: pl.DataFrame) -> pl.DataFrame: prev_func = None for result in results: if prev_func != result.name: - print(f"\n{result.name}") + print(f"\n{result.name} (num_iterations={param_config.num_iterations})") prev_func = result.name print(result) + _pretty_print_avg_results(results) + + import json + from dataclasses import asdict + if param_config.functions == "all": with open("benchmarks/benchmarks-results.json", "w") as f: json.dump([asdict(r) for r in results], f, indent=2) diff --git a/polars_h3/__init__.py b/polars_h3/__init__.py index a8255de..356de8a 100644 --- a/polars_h3/__init__.py +++ b/polars_h3/__init__.py @@ -41,6 +41,7 @@ cell_area, edge_length, get_num_cells, + get_pentagons, great_circle_distance, ) from .core.traversal import ( @@ -97,6 +98,7 @@ "directed_edge_to_boundary", "great_circle_distance", "average_hexagon_area", + "get_pentagons", "cell_area", "edge_length", "average_hexagon_edge_length", diff --git a/polars_h3/core/_consts.py b/polars_h3/core/_consts.py new file mode 100644 index 0000000..40b1447 --- /dev/null +++ b/polars_h3/core/_consts.py @@ -0,0 +1,285 @@ +NUM_CELLS_BY_RESOLUTION = { + 0: 122, + 1: 842, + 2: 5882, + 3: 41162, + 4: 288122, + 5: 2016842, + 6: 14117882, + 7: 98825162, + 8: 691776122, + 9: 4842432842, + 10: 33897029882, + 11: 237279209162, + 12: 1660954464122, + 13: 11626681248842, + 14: 81386768741882, + 15: 569707381193162, +} + + +AVG_EDGE_LENGTH_M = { + 0: 1107712.591, + 1: 418676.00549, + 2: 158244.6558, + 3: 59810.85794, + 4: 22606.3794, + 5: 8544.408276, + 6: 3229.482772, + 7: 1220.629759, + 8: 461.354684, + 9: 174.3756680, + 10: 65.907807, + 11: 24.910561, + 12: 9.415526, + 13: 3.559893, + 14: 1.348575, + 15: 0.509713, +} + +AVG_AREA_M2 = { + 0: 4357449416078.3833, + 1: 609788441794.1332, + 2: 86801780398.99721, + 3: 12393434655.08816, + 4: 1770347654.491307, + 5: 252903858.1819449, + 6: 36129062.164412454, + 7: 5161293.359717191, + 8: 737327.5975944176, + 9: 105332.51342720671, + 10: 15047.50190766435, + 11: 2149.643129451879, + 12: 307.091875631606, + 13: 43.870267947282954, + 14: 6.2671811353243125, + 15: 0.895311590760579, +} + + +PENTAGONS_BY_RESOLUTION = { + 0: [ + 576636674163867647, + 576988517884755967, + 577340361605644287, + 577832942814887935, + 578219970907865087, + 578536630256664575, + 578712552117108735, + 579029211465908223, + 579416239558885375, + 579908820768129023, + 580260664489017343, + 580612508209905663, + ], + 1: [ + 581109487465660415, + 581461331186548735, + 581813174907437055, + 582305756116680703, + 582692784209657855, + 583009443558457343, + 583185365418901503, + 583502024767700991, + 583889052860678143, + 584381634069921791, + 584733477790810111, + 585085321511698431, + ], + 2: [ + 585609238802333695, + 585961082523222015, + 586312926244110335, + 586805507453353983, + 587192535546331135, + 587509194895130623, + 587685116755574783, + 588001776104374271, + 588388804197351423, + 588881385406595071, + 589233229127483391, + 589585072848371711, + ], + 3: [ + 590112357393367039, + 590464201114255359, + 590816044835143679, + 591308626044387327, + 591695654137364479, + 592012313486163967, + 592188235346608127, + 592504894695407615, + 592891922788384767, + 593384503997628415, + 593736347718516735, + 594088191439405055, + ], + 4: [ + 594615896891195391, + 594967740612083711, + 595319584332972031, + 595812165542215679, + 596199193635192831, + 596515852983992319, + 596691774844436479, + 597008434193235967, + 597395462286213119, + 597888043495456767, + 598239887216345087, + 598591730937233407, + ], + 5: [ + 599119489002373119, + 599471332723261439, + 599823176444149759, + 600315757653393407, + 600702785746370559, + 601019445095170047, + 601195366955614207, + 601512026304413695, + 601899054397390847, + 602391635606634495, + 602743479327522815, + 603095323048411135, + ], + 6: [ + 603623087690219519, + 603974931411107839, + 604326775131996159, + 604819356341239807, + 605206384434216959, + 605523043783016447, + 605698965643460607, + 606015624992260095, + 606402653085237247, + 606895234294480895, + 607247078015369215, + 607598921736257535, + ], + 7: [ + 608126687200149503, + 608478530921037823, + 608830374641926143, + 609322955851169791, + 609709983944146943, + 610026643292946431, + 610202565153390591, + 610519224502190079, + 610906252595167231, + 611398833804410879, + 611750677525299199, + 612102521246187519, + ], + 8: [ + 612630286812839935, + 612982130533728255, + 613333974254616575, + 613826555463860223, + 614213583556837375, + 614530242905636863, + 614706164766081023, + 615022824114880511, + 615409852207857663, + 615902433417101311, + 616254277137989631, + 616606120858877951, + ], + 9: [ + 617133886438375423, + 617485730159263743, + 617837573880152063, + 618330155089395711, + 618717183182372863, + 619033842531172351, + 619209764391616511, + 619526423740415999, + 619913451833393151, + 620406033042636799, + 620757876763525119, + 621109720484413439, + ], + 10: [ + 621637486065516543, + 621989329786404863, + 622341173507293183, + 622833754716536831, + 623220782809513983, + 623537442158313471, + 623713364018757631, + 624030023367557119, + 624417051460534271, + 624909632669777919, + 625261476390666239, + 625613320111554559, + ], + 11: [ + 626141085692858367, + 626492929413746687, + 626844773134635007, + 627337354343878655, + 627724382436855807, + 628041041785655295, + 628216963646099455, + 628533622994898943, + 628920651087876095, + 629413232297119743, + 629765076018008063, + 630116919738896383, + ], + 12: [ + 630644685320225279, + 630996529041113599, + 631348372762001919, + 631840953971245567, + 632227982064222719, + 632544641413022207, + 632720563273466367, + 633037222622265855, + 633424250715243007, + 633916831924486655, + 634268675645374975, + 634620519366263295, + ], + 13: [ + 635148284947595327, + 635500128668483647, + 635851972389371967, + 636344553598615615, + 636731581691592767, + 637048241040392255, + 637224162900836415, + 637540822249635903, + 637927850342613055, + 638420431551856703, + 638772275272745023, + 639124118993633343, + ], + 14: [ + 639651884574965767, + 640003728295854087, + 640355572016742407, + 640848153225986055, + 641235181318963207, + 641551840667762695, + 641727762528206855, + 642044421877006343, + 642431449969983495, + 642924031179227143, + 643275874900115463, + 643627718621003783, + ], + 15: [ + 644155484202336256, + 644507327923224576, + 644859171644112896, + 645351752853356544, + 645738780946333696, + 646055440295133184, + 646231362155577344, + 646548021504376832, + 646935049597353984, + 647427630806597632, + 647779474527485952, + 648131318248374272, + ], +} diff --git a/polars_h3/core/metrics.py b/polars_h3/core/metrics.py index 2abb914..c8d3918 100644 --- a/polars_h3/core/metrics.py +++ b/polars_h3/core/metrics.py @@ -6,6 +6,8 @@ import polars as pl from polars.plugins import register_plugin_function +from . import _consts + if TYPE_CHECKING: from polars_h3.typing import IntoExprColumn @@ -108,26 +110,8 @@ def average_hexagon_area(resolution: IntoExprColumn, unit: str = "km^2") -> pl.E if unit not in ["km^2", "m^2"]: raise ValueError("Unit must be either 'km^2' or 'm^2'") - avg_area_m2 = { - 0: 4357449416078.3833, - 1: 609788441794.1332, - 2: 86801780398.99721, - 3: 12393434655.08816, - 4: 1770347654.491307, - 5: 252903858.1819449, - 6: 36129062.164412454, - 7: 5161293.359717191, - 8: 737327.5975944176, - 9: 105332.51342720671, - 10: 15047.50190766435, - 11: 2149.643129451879, - 12: 307.091875631606, - 13: 43.870267947282954, - 14: 6.2671811353243125, - } - resolution_expr = pl.col(resolution) if isinstance(resolution, str) else resolution - area_m2 = resolution_expr.replace(avg_area_m2) + area_m2 = resolution_expr.replace(_consts.AVG_AREA_M2) return pl.when(pl.lit(unit) == "km^2").then(area_m2 / 1_000_000).otherwise(area_m2) @@ -219,27 +203,8 @@ def average_hexagon_edge_length( if unit not in ["km", "m"]: raise ValueError("Unit must be either 'km' or 'm'") - avg_edge_length_m = { - 0: 1107712.591, - 1: 418676.00549999997, - 2: 158244.6558, - 3: 59810.85794, - 4: 22606.3794, - 5: 8544.408276, - 6: 3229.482772, - 7: 1220.629759, - 8: 461.354684, - 9: 174.37566800000002, - 10: 65.907807, - 11: 24.910561, - 12: 9.415526, - 13: 3.559893, - 14: 1.348575, - 15: 0.509713, - } - resolution_expr = pl.col(resolution) if isinstance(resolution, str) else resolution - edge_length_m = resolution_expr.replace(avg_edge_length_m) + edge_length_m = resolution_expr.replace(_consts.AVG_EDGE_LENGTH_M) return ( pl.when(pl.lit(unit) == "km") @@ -277,50 +242,38 @@ def get_num_cells(resolution: IntoExprColumn) -> pl.Expr: └─────────────┴──────────┘ ``` """ - num_cells = { - 0: 122, - 1: 842, - 2: 5882, - 3: 41162, - 4: 288122, - 5: 2016842, - 6: 14117882, - 7: 98825162, - 8: 691776122, - 9: 4842432842, - 10: 33897029882, - 11: 237279209162, - 12: 1660954464122, - 13: 11626681248842, - 14: 81386768741882, - 15: 569707381193162, - } resolution_expr = pl.col(resolution) if isinstance(resolution, str) else resolution - return resolution_expr.replace(num_cells) - - -# def get_res0_cells() -> pl.Expr: -# """ -# Get all resolution 0 cells. -# """ -# return register_plugin_function( -# args=[], -# plugin_path=LIB, -# function_name="get_res0_cells", -# is_elementwise=True, -# ) - - -# def get_pentagons(resolution: HexResolution) -> pl.Expr: -# """ -# Get all pentagon cells at the given resolution. -# """ -# _assert_valid_resolution(resolution) -# return register_plugin_function( -# args=[], -# plugin_path=LIB, -# function_name="get_pentagons", -# is_elementwise=True, -# kwargs={"resolution": resolution}, -# ) + return resolution_expr.cast(pl.UInt8).replace(_consts.NUM_CELLS_BY_RESOLUTION) + + +def get_pentagons(resolution: IntoExprColumn) -> pl.Expr: + """ + Get the number of pentagons at a given resolution. + + #### Parameters + - `resolution`: IntoExprColumn + Column or expression with the H3 resolution (0 to 15). + + #### Returns + Expr + Expression returning the number of cells as an integer. + + #### Examples + ```python + >>> df = pl.DataFrame({"resolution": [5]}, schema={"resolution": pl.UInt64}) + >>> df.with_columns(count=polars_h3.get_num_cells("resolution")) + shape: (1, 2) + ┌─────────────┬──────────┐ + │ resolution │ count │ + │ --- │ --- │ + │ u64 │ i64 │ + ╞══════════════╪══════════╡ + │ 5 │ 2016842 │ + └─────────────┴──────────┘ + ``` + """ + raise NotImplementedError("Not implemented") + + # resolution_expr = pl.col(resolution) if isinstance(resolution, str) else resolution + # return resolution_expr.replace(_consts.PENTAGONS_BY_RESOLUTION) diff --git a/pyproject.toml b/pyproject.toml index 915384c..4549da4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ homepage = "https://github.com/Filimonov/polars-h3" [project] name = "polars-h3" -version = "0.4.3" +version = "0.4.4" description = "H3 bindings for Polars" readme = "README.md" requires-python = ">=3.9" diff --git a/tests/test_indexing.py b/tests/test_indexing.py index 7500a75..87294a6 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -45,6 +45,14 @@ def test_latlng_to_cell_invalid_resolution(input_lat, input_lng, resolution): ) +def test_latlng_to_cell_missing_lat_lng(): + df = pl.DataFrame({"lat": [None], "lng": [None]}) + with pytest.raises(pl.exceptions.ComputeError): + df.with_columns( + h3_cell=polars_h3.latlng_to_cell("lat", "lng", 9, return_dtype=pl.UInt64) + ) + + @pytest.mark.parametrize( "input_lat,input_lng", [ diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 81f5fe1..4bc486e 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -310,37 +310,59 @@ def test_average_hexagon_edge_length(test_params): # assert df["count"].to_list()[0] == expected_count -# def test_get_res0_cells(): -# df = pl.DataFrame({"dummy": [1]}).with_columns( -# [ -# polars_h3.get_res0_cells().alias("cells_int"), -# ] +# def test_get_pentagons(): +# dicts = ( +# pl.DataFrame({"h3_resolution": list(range(0, 16))}) +# .with_columns( +# pentagons=polars_h3.get_pentagons(pl.col("h3_resolution")).list.sort(), +# ) +# .to_dicts() # ) - -# assert len(df["cells_int"][0]) == 122 -# assert len(df["cells_str"][0]) == 122 - - -# @pytest.mark.parametrize( -# "resolution, expected_valid", -# [ -# pytest.param(-1, False, id="negative_res"), -# pytest.param(16, False, id="too_high_res"), -# pytest.param(0, True, id="valid_res_0"), -# pytest.param(5, True, id="valid_res_5"), -# ], -# ) -# def test_get_pentagons(resolution: int, expected_valid: bool): -# df = pl.DataFrame({"resolution": [resolution]}).with_columns( -# [ -# polars_h3.get_pentagons("resolution").alias("pent_int"), -# polars_h3.get_pentagons_string("resolution").alias("pent_str"), -# ] -# ) - -# if expected_valid: -# assert len(df["pent_int"][0]) == 12 # Always 12 pentagons per resolution -# assert len(df["pent_str"][0]) == 12 -# else: -# assert df["pent_int"][0] is None -# assert df["pent_str"][0] is None +# for val in dicts: +# if val["h3_resolution"] == 0: +# assert val["pentagons"] == [ +# 576636674163867647, +# 576988517884755967, +# 577340361605644287, +# 577832942814887935, +# 578219970907865087, +# 578536630256664575, +# 578712552117108735, +# 579029211465908223, +# 579416239558885375, +# 579908820768129023, +# 580260664489017343, +# 580612508209905663, +# ] +# elif val["h3_resolution"] == 7: +# assert val["pentagons"] == [ +# 608126687200149503, +# 608478530921037823, +# 608830374641926143, +# 609322955851169791, +# 609709983944146943, +# 610026643292946431, +# 610202565153390591, +# 610519224502190079, +# 610906252595167231, +# 611398833804410879, +# 611750677525299199, +# 612102521246187519, +# ] +# elif val["h3_resolution"] == 15: +# assert val["pentagons"] == [ +# 644155484202336256, +# 644507327923224576, +# 644859171644112896, +# 645351752853356544, +# 645738780946333696, +# 646055440295133184, +# 646231362155577344, +# 646548021504376832, +# 646935049597353984, +# 647427630806597632, +# 647779474527485952, +# 648131318248374272, +# ] +# else: +# assert val["h3_resolution"] in list(range(0, 16)) diff --git a/uv.lock b/uv.lock index 781e21b..c9c07c9 100644 --- a/uv.lock +++ b/uv.lock @@ -1472,7 +1472,7 @@ wheels = [ [[package]] name = "polars-h3" -version = "0.4.0" +version = "0.4.4" source = { editable = "." } dependencies = [ { name = "polars" },