Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
iris-garden committed Jul 22, 2024
1 parent af61085 commit 6a8d038
Show file tree
Hide file tree
Showing 6 changed files with 1,629 additions and 0 deletions.
1,352 changes: 1,352 additions & 0 deletions hail/python/hail/docs/tutorials/ggplot.ipynb

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions hail/python/hail/ggplot2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typeguard import install_import_hook

# typeguard's import hook only instruments modules that are imported *after*
# it has been installed, so it must run before the `.ggplot2` import below;
# installing it afterwards leaves the already-imported module unchecked.
install_import_hook("hail.ggplot2")

from .ggplot2 import (  # noqa: E402  (must follow install_import_hook)
    aes,
    extend,
    geom_histogram,
    geom_line,
    geom_point,
    ggplot,
    show,
    undo,
)

# Public API re-exported from the `.ggplot2` implementation module.
__all__ = [
    "aes",
    "extend",
    "geom_point",
    "geom_line",
    "geom_histogram",
    "ggplot",
    "undo",
    "show",
]
210 changes: 210 additions & 0 deletions hail/python/hail/ggplot2/ggplot2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from dataclasses import asdict, replace
from textwrap import dedent, indent
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from altair import X2, Chart, LayerChart, X, Y
from pandas import DataFrame

import hail as hl
from hail import MatrixTable, Table
from hail.ggplot.utils import typeguard_dataclass

"""
how are we going to cache stats/any other transformations to the data?
id(data), we can cache all transformations on that with a compound, ordered key of stats
for example, (stat1, stat2, stat3) would be distinct from (stat2, stat3, stat1)
and we can get as much of the beginning of that path as possible from the cache, so if we have (stat1, stat2, stat3) in the cache, and attempt to do (stat1, stat2, stat4, stat3), we can get the first two steps from that cache
and in fact we will need to store each new path as its own thing, so if we do stat1 then stat2 then stat3 then roll it back and do stat1 then stat2 then stat4, we should end up with the following keys in the cache under the key for the original data:
()
(stat1)
(stat1, stat2)
(stat1, stat2, stat3)
(stat1, stat2, stat4)
how does the user interact with the cache? if we undo the addition of a geom or stat, we can roll it back, and if they reapply the same stat, we pull from it
no caching ir, if the user wants to use the caching, they should make a stat subclass
expose a function so they can do that
and then the way to change it within the plot object is to use stat_function or whatever to supply your own thing that returns a table
value should be (hail agg, dataframe)
do this for identity too so we cache the df for it
"""


### types ###
# The hail dataset types accepted as plot or layer data.
Data = Union[Table, MatrixTable]


@typeguard_dataclass
class Mapping:
    """Aesthetic mapping: the name bound to each supported aesthetic.

    A field left as ``None`` means that aesthetic is unset; unset aesthetics
    are skipped when mappings are merged.
    """

    x: Optional[str]
    y: Optional[str]
    # TODO add the rest of the supported aesthetic names
    color: Optional[str]


# Supported geoms; each value is the suffix of the altair `mark_<geom>` method used to draw it.
Geom = Literal["bar", "line", "circle"]
# Supported statistical transformations applied to the data before drawing.
Stat = Literal["identity", "bin"]


@typeguard_dataclass
class Layer:
    """A single plot layer: a geom plus the stat, mapping, data and parameters it is built from."""

    # aesthetics specific to this layer
    mapping: Mapping
    # data specific to this layer, if any
    data: Optional[Data]
    # which altair mark to draw; None means no mark method is applied
    geom: Optional[Geom]
    # transformation applied to the data before it is drawn
    stat: Stat
    # FIXME if there's only one type per param name we can make this a typeddict
    # extra stat/geom parameters (e.g. "bins" for the bin stat)
    params: Dict[str, Any]


@typeguard_dataclass
class Plot:
    """A plot specification: default data and mapping plus an ordered list of layers."""

    # default data for the plot, if any
    data: Optional[Data]
    # plot-level aesthetics
    mapping: Mapping
    # layers in the order they were added
    layers: list[Layer]


### module-level variables ###
# Undo history: id(plot) -> the earlier versions of that plot, oldest first.
_plot_cache: Dict[int, List[Plot]] = {}
# Memoised stat results: a key starting with id(data) and the stat name ->
# (stat result, dataframe handed to altair).
_stat_cache: Dict[Tuple[Any, ...], Tuple[Any, DataFrame]] = {}


### constructor functions ###
def aes(x: Optional[str] = None, y: Optional[str] = None, color: Optional[str] = None) -> Mapping:
    """Build an aesthetic `Mapping`; any aesthetic not supplied is left unset (``None``)."""
    return Mapping(x=x, y=y, color=color)


def geom_histogram(mapping: Mapping = aes(), data: Optional[Data] = None, bins: int = 30) -> Layer:
    """Create a histogram layer: ``bar`` geom over the ``bin`` stat with ``bins`` bins."""
    return Layer(mapping=mapping, data=data, geom="bar", stat="bin", params={"bins": bins})


def geom_line(mapping: Mapping = aes(), data: Optional[Data] = None) -> Layer:
    """Create a line layer: ``line`` geom over the ``identity`` stat, with no extra parameters."""
    return Layer(mapping=mapping, data=data, geom="line", stat="identity", params={})


def geom_point(mapping: Mapping = aes(), data: Optional[Data] = None) -> Layer:
    """Create a point layer: ``circle`` geom over the ``identity`` stat, with no extra parameters."""
    return Layer(mapping=mapping, data=data, geom="circle", stat="identity", params={})


def ggplot(data: Optional[Data] = None, mapping: Mapping = aes()) -> Plot:
    """Create a plot with no layers and register an empty undo history for it."""
    empty_plot = Plot(data=data, mapping=mapping, layers=[])
    # A brand-new plot has no earlier versions to roll back to.
    _plot_cache[id(empty_plot)] = []
    return empty_plot


### functionality ###
def extend(plot: Plot, other: Any) -> Plot:
    """Return a new plot consisting of ``plot`` with ``other`` added to it.

    * A `Mapping` overrides the plot-level aesthetics that it sets; aesthetics
      left as ``None`` in ``other`` keep their current value.
    * A `Layer` is appended to the plot's layers.

    ``plot`` itself is not modified. The new plot's undo history is the
    history of ``plot`` followed by ``plot``.

    Raises:
        ValueError: if ``other`` is neither a `Mapping` nor a `Layer`.
    """
    if isinstance(other, Mapping):
        overrides = {name: value for name, value in asdict(other).items() if value is not None}
        changes: Dict[str, Any] = {"mapping": replace(plot.mapping, **overrides)}
    elif isinstance(other, Layer):
        changes = {"layers": [*plot.layers, other]}
    else:
        raise ValueError("unsupported addition to plot")

    new_plot = replace(plot, **changes)
    # The base plot's own entry is kept rather than deleted: the same base is
    # routinely extended more than once (`p + geom_point()`, then
    # `p + geom_line()`), and a plot that was never registered simply starts
    # with an empty history instead of raising KeyError.
    history = _plot_cache.get(id(plot), [])
    _plot_cache[id(new_plot)] = [*history, plot]
    return new_plot


setattr(Plot, "__add__", extend)


_altair_configure_mark_keys = {"color"}
_altair_encode_keys = {"x": X, "x2": X2, "y": Y}


def show(plot: Plot) -> Optional[Union[Chart, LayerChart]]:
    """Build the altair chart for ``plot``, one chart per layer, layered in order.

    For each layer the plot-level and layer-level mappings are merged (the
    layer wins; ``None`` entries are ignored), the layer's stat is applied to
    the layer's data (falling back to the plot's data), and the resulting
    dataframe is drawn with the layer's geom.

    Stat results are memoised in ``_stat_cache``. The key contains the identity
    of the data plus everything the stat depends on, so histograms over a
    different ``x`` or with a different bin count never share an entry.

    Returns:
        The chart for a single layer, the layered chart for several layers,
        or ``None`` when the plot has no layers.

    Raises:
        ValueError: if a layer has no data available, if the ``bin`` stat is
            used without an ``x`` aesthetic, or if the stat is unknown.
    """
    base_chart = None
    for layer in plot.layers:
        mapping_dict = {
            **{k: v for k, v in asdict(plot.mapping).items() if v is not None},
            **{k: v for k, v in asdict(layer.mapping).items() if v is not None},
        }
        source = layer.data if layer.data is not None else plot.data
        if source is None:
            raise ValueError("no data supplied for layer")

        # TODO should we break the stat stuff out to its own function?
        # Extra keyword arguments for each altair encoding channel.
        encode_kwargs = {"x": {}, "x2": {}, "y": {}}
        if layer.stat == "identity":
            cache_key = (id(source), layer.stat)
            cached = _stat_cache.get(cache_key)
            if cached is None:
                cached = (source, source.to_pandas())
                _stat_cache[cache_key] = cached
            _, df = cached
        elif layer.stat == "bin":
            x = mapping_dict.get("x")
            if x is None:
                raise ValueError("x must be supplied for stat bin")
            bins = layer.params["bins"]
            cache_key = (id(source), layer.stat, x, bins)
            cached = _stat_cache.get(cache_key)
            if cached is None:
                hist = source.aggregate(
                    hl.agg.hist(
                        source[x],
                        source.aggregate(hl.agg.min(source[x])),
                        source.aggregate(hl.agg.max(source[x])),
                        bins,
                    )
                )
                edges = hist["bin_edges"]
                # One row per bin: left edge, right edge, count.
                df = DataFrame([
                    {x: left, "x2": right, "y": freq}
                    for left, right, freq in zip(edges, edges[1:], hist["bin_freq"])
                ])
                cached = (hist, df)
                _stat_cache[cache_key] = cached
            _, df = cached
            # These describe how the binned dataframe is encoded, so they must
            # be applied on cache hits as well as on misses.
            encode_kwargs["x"] = {"bin": "binned"}
            mapping_dict["x2"] = "x2"
            mapping_dict["y"] = "y"
        else:
            raise ValueError("unknown stat")

        chart = Chart(df)
        if layer.geom is not None:
            # NOTE(review): `color` is handed to the mark as a constant value,
            # not encoded as a data channel — confirm this is intended.
            chart = getattr(chart, f"mark_{layer.geom}")(**{
                k: v for k, v in mapping_dict.items() if k in _altair_configure_mark_keys
            })
        chart = chart.encode(**{
            k: _altair_encode_keys[k](v, **encode_kwargs[k])
            for k, v in mapping_dict.items()
            if k in _altair_encode_keys
        })
        base_chart = chart if base_chart is None else base_chart + chart
    return base_chart


def undo(plot: Plot, *, depth: int = 1) -> Plot:
    """Return the version of ``plot`` from ``depth`` additions ago.

    The returned plot's undo history is truncated to the versions that came
    before it, and the entry for ``plot`` is dropped from the history cache.

    Raises:
        ValueError: if ``depth`` is less than 1.
        KeyError: if ``plot`` has no recorded history.
        IndexError: if ``depth`` exceeds the number of earlier versions.
    """
    # depth=0 would index the *oldest* version and a negative depth would
    # index from the front of the history, so both are rejected up front.
    if depth < 1:
        raise ValueError("depth must be at least 1")
    try:
        history = _plot_cache[id(plot)]
    except KeyError:
        raise KeyError("plot has no recorded history to undo") from None
    if depth > len(history):
        raise IndexError(f"cannot undo {depth} step(s): plot only has {len(history)} earlier version(s)")

    old_plot = history[-depth]
    _plot_cache[id(old_plot)] = history[:-depth]
    del _plot_cache[id(plot)]
    return old_plot


def plot_to_string(plot: Plot) -> str:
    """Format ``plot`` as a multi-line ``Plot(...)`` description of its data, mapping and layers.

    The mapping and layers are rendered with ``str`` and passed through
    ``indent_tail`` so that their continuation lines are indented.
    """
    return dedent(f"""\
Plot(
data = {plot.data},
mapping = {indent_tail(str(plot.mapping), 3)},
layers = {indent_tail(str(plot.layers), 3)},
)""")


def indent_tail(string: str, indent_level: int = 1) -> str:
    """Indent every line of ``string`` except the first by ``indent_level`` spaces.

    The text is split at its first newline; the first line and the newline are
    kept unchanged and the remainder is indented with `textwrap.indent`
    (which leaves whitespace-only lines untouched).
    """
    first_line, newline, remainder = string.partition("\n")
    return first_line + newline + indent(remainder, " " * indent_level)


setattr(Plot, "__str__", plot_to_string)


def mapping_to_string(mapping: Mapping) -> str:
    """Format ``mapping`` as a multi-line ``Mapping(...)`` description of its x and y aesthetics."""
    # NOTE(review): `color` is not included in the output — confirm whether that is intentional.
    return dedent(f"""\
Mapping(
x = {mapping.x},
y = {mapping.y},
)""")


setattr(Mapping, "__str__", mapping_to_string)
38 changes: 38 additions & 0 deletions hail/python/hail/ggplot2/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from dataclasses import dataclass, fields
from functools import wraps
from typing import Any, Callable, TypeVar, Union

from typeguard import check_type

ReturnType = TypeVar("ReturnType")
WrappedDecorator = Callable[[ReturnType], ReturnType]


def typeguard_dataclass(cls: ReturnType = None, /, **kwargs: Any) -> Union[ReturnType, WrappedDecorator]:
    """
    Creates a `dataclass` that is `frozen` by default and has runtime typechecking for its fields.

    Can be used bare (``@typeguard_dataclass``) or with keyword arguments
    (``@typeguard_dataclass(frozen=False)``). Every keyword argument other than
    ``frozen`` is forwarded to `dataclass` unchanged.

    * Frozen classes are validated once, in ``__post_init__``, after the
      generated ``__init__`` has stored every field.
    * Non-frozen classes get a ``__setattr__`` that type-checks every
      assignment (including those made by the generated ``__init__``) and
      rejects names that are not fields with a ``TypeError``.
    """

    @wraps(dataclass)
    def wrapper(cls: ReturnType) -> ReturnType:
        def __setattr__(obj: ReturnType, name: str, value: Any) -> None:
            types = [_field.type for _field in fields(obj) if _field.name == name]
            if not types:
                raise TypeError(f"'{getattr(cls, '__name__', str(cls))}' has no field '{name}'.")
            # Zero-argument `super()` only works in functions defined inside a
            # class body (it needs the implicit `__class__` cell); this
            # function is attached afterwards, so store the value explicitly.
            object.__setattr__(obj, name, check_type(value, types[0]))

        def __post_init__(obj: ReturnType) -> None:
            for _field in fields(obj):
                check_type(getattr(obj, _field.name), _field.type)

        frozen = kwargs.get("frozen", True)
        # A frozen dataclass refuses a user-supplied `__setattr__`, so frozen
        # classes are validated after construction instead.
        if frozen:
            setattr(cls, "__post_init__", __post_init__)
        else:
            setattr(cls, "__setattr__", __setattr__)
        # Return what `dataclass` returns: with `slots=True` it builds a new
        # class rather than modifying `cls` in place.
        return dataclass(cls, frozen=frozen, **{k: v for k, v in kwargs.items() if k != "frozen"})

    return wrapper if cls is None else wrapper(cls)
2 changes: 2 additions & 0 deletions hail/python/pinned-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# This file was autogenerated by uv via the following command:
# uv pip compile --python-version 3.9 --python-platform linux hail/python/requirements.txt --output-file=hail/python/pinned-requirements.txt
altair==5
typeguard==4
aiodns==2.0.0
# via
# -c hail/python/hailtop/pinned-requirements.txt
Expand Down
2 changes: 2 additions & 0 deletions hail/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
-c dataproc-pre-installed-requirements.txt
-r hailtop/requirements.txt

altair==5
typeguard==4
avro>=1.10,<1.12
bokeh>=3,<3.4
decorator<5
Expand Down

0 comments on commit 6a8d038

Please sign in to comment.