Updated Python libraries to work with current Scala libraries in dataframes #113

Open
wants to merge 7 commits into base: master
Empty file added python/README.md
Empty file.
20 changes: 0 additions & 20 deletions python/__init__.py

This file was deleted.

9 changes: 6 additions & 3 deletions python/magellan/column.py
@@ -16,11 +16,14 @@

from pyspark import SparkContext
from pyspark.sql.column import Column, _create_column_from_literal
from pyspark.sql.functions import col as _col

def _bin_op(name, doc="binary operator"):
    """ Create a method for given binary operator
    """
    def _(col, other):
        if isinstance(other, str):
            other = _col(other)
        jc = other._jc if isinstance(other, Column) else other
        jcol = col._jc
        sc = SparkContext._active_spark_context
@@ -56,9 +59,9 @@ def _(col, other):
    _.__doc__ = doc
    return _

within = _bin_op("magellan.catalyst.Within")
intersects = _bin_op("magellan.catalyst.Intersects")
transform = _unary_op("magellan.catalyst.Transformer")
within = _bin_op("org.apache.spark.sql.types.Within")
intersects = _bin_op("org.apache.spark.sql.types.Intersects")
transform = _unary_op("org.apache.spark.sql.types.Transformer")
Column.within = within
Column.intersects = intersects
Column.transform = transform
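
For context, a minimal usage sketch of the patched operators. The DataFrame and column names are illustrative assumptions, and Magellan's optimizer rules are assumed to be injected already:

# Hypothetical usage: keep only the points that fall inside some polygon.
# Assumes points_df has a "point" column and polygons_df a "polygon" column.
from pyspark.sql.functions import col

joined = points_df.join(
    polygons_df,
    col("point").within(col("polygon"))  # Column.within as monkey-patched above
)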
27 changes: 27 additions & 0 deletions python/magellan/dataframe.py
@@ -0,0 +1,27 @@
from pyspark.sql import DataFrame
from spylon.spark.utils import SparkJVMHelpers


def _immutable_scala_map(jvm_helpers, dict_like):
    # Build an immutable scala.collection.immutable.Map from a Python dict and return it.
    return jvm_helpers.to_scala_map(dict_like).toMap(jvm_helpers.jvm.scala.Predef.conforms())


def index(df, precision):
    jvm_helpers = SparkJVMHelpers(df._sc)
    jdf = df._jdf
    sparkSession = jdf.sparkSession()
    SpatialJoinHint = df._sc._jvm.magellan.catalyst.SpatialJoinHint
    Dataset = df._sc._jvm.org.apache.spark.sql.Dataset

    new_jdf = Dataset(
        sparkSession,
        SpatialJoinHint(
            jdf.logicalPlan(),
            _immutable_scala_map(jvm_helpers, {"magellan.index.precision": str(precision)})
        ),
        jdf.exprEnc())

    return DataFrame(new_jdf, df.sql_ctx)


DataFrame.index = index
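
A brief, hypothetical usage sketch of the index helper defined above; the precision value and DataFrame names are illustrative:

import magellan.dataframe  # monkey-patches DataFrame.index as defined above

# Hint Magellan's optimizer to build a spatial index on the polygons
# with an (illustrative) precision of 30 before joining.
indexed_polygons = polygons_df.index(30)
joined = points_df.join(indexed_polygons,
                        points_df["point"].within(indexed_polygons["polygon"]))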
79 changes: 66 additions & 13 deletions python/magellan/types.py
@@ -17,10 +17,10 @@
import json
import sys

from itertools import izip
from itertools import izip, repeat

from pyspark import SparkContext
from pyspark.sql.types import DataType, UserDefinedType, StructField, StructType, \
from pyspark.sql.types import DataType, UserDefinedType, Row, StructField, StructType, \
    ArrayType, DoubleType, IntegerType

__all__ = ['Point']
@@ -69,14 +69,16 @@ def scalaUDT(cls):
"""
The class name of the paired Scala UDT.
"""
return "magellan.PointUDT"
return "org.apache.spark.sql.types.PointUDT"

    def serialize(self, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        if isinstance(obj, Point):
            return obj
            # pnt = Row(IntegerType(), DoubleType(), DoubleType(), DoubleType(), DoubleType(), DoubleType(), DoubleType())
            # return pnt("udt", obj.x, obj.y, obj.x, obj.y, obj.x, obj.y)
            return 1, obj.x, obj.y, obj.x, obj.y, obj.x, obj.y
        else:
            raise TypeError("cannot serialize %r of type %r" % (obj, type(obj)))

@@ -87,9 +89,9 @@ def deserialize(self, datum):
        if isinstance(datum, Point):
            return datum
        else:
            assert len(datum) == 2, \
                "PointUDT.deserialize given row with length %d but requires 2" % len(datum)
            return Point(datum[0], datum[1])
            assert len(datum) == 7, \
                "PointUDT.deserialize given row with length %d but requires 7" % len(datum)
            return Point(datum[5], datum[6])

    def simpleString(self):
        return 'point'
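
The seven-element layout used above appears to mirror the Scala-side point encoding: a type code, the bounding box, then the coordinates. A small round-trip sketch under that assumption (and assuming Point(x, y) constructs a point from its coordinates):

# Assumed layout: (type_code, xmin, ymin, xmax, ymax, x, y)
udt = PointUDT()
datum = udt.serialize(Point(1.0, 2.0))   # -> (1, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0)
p = udt.deserialize(datum)
assert (p.x, p.y) == (1.0, 2.0)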
@@ -173,14 +175,19 @@ def scalaUDT(cls):
"""
The class name of the paired Scala UDT.
"""
return "magellan.PolygonUDT"
return "org.apache.spark.sql.types.PolygonUDT"

    def serialize(self, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        if isinstance(obj, Polygon):
            return obj
            x_list = []
            y_list = []
            for p in obj.points:
                x_list.append(p.x)
                y_list.append(p.y)
            return 5, min(x_list), min(y_list), max(x_list), max(y_list), obj.indices, x_list, y_list
        else:
            raise TypeError("cannot serialize %r of type %r" % (obj, type(obj)))

@@ -191,9 +198,9 @@ def deserialize(self, datum):
        if isinstance(datum, Polygon):
            return datum
        else:
            assert len(datum) == 2, \
            assert len(datum) == 8, \
                "PolygonUDT.deserialize given row with length %d but requires 8" % len(datum)
            return Polygon(datum[0], [self.pointUDT.deserialize(point) for point in datum[1]])
            return Polygon(datum[5], [self.pointUDT.deserialize(point) for point in zip(repeat(1), datum[6], datum[7], datum[6], datum[7], datum[6], datum[7])])

    def simpleString(self):
        return 'polygon'
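
Analogously, the polygon serialization above appears to use an eight-element layout: type code 5, the bounding box, the ring indices, then the x and y coordinate lists. A hedged round-trip sketch under that assumption:

# Assumed layout: (5, xmin, ymin, xmax, ymax, indices, xcoordinates, ycoordinates)
udt = PolygonUDT()
ring = [Point(0.0, 0.0), Point(0.0, 1.0), Point(1.0, 1.0), Point(0.0, 0.0)]
datum = udt.serialize(Polygon([0], ring))
poly = udt.deserialize(datum)
assert poly.indices == [0] and poly.size == 4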
@@ -223,6 +230,20 @@ class Polygon(Shape):

    def __init__(self, indices = [], points = []):
        self.indices = indices
        self.xcoordinates = [p.x for p in points]
        self.ycoordinates = [p.y for p in points]
        if points:
            self.xmin = min(self.xcoordinates)
            self.ymin = min(self.ycoordinates)
            self.xmax = max(self.xcoordinates)
            self.ymax = max(self.ycoordinates)
        else:
            self.xmin = None
            self.ymin = None
            self.xmax = None
            self.ymax = None
        self.boundingBox = BoundingBox(self.xmin, self.ymin, self.xmax, self.ymax)
        self.size = len(points)
        self.points = points

    def __str__(self):
@@ -289,14 +310,19 @@ def scalaUDT(cls):
"""
The class name of the paired Scala UDT.
"""
return "magellan.PolyLineUDT"
return "org.apache.spark.sql.types.PolyLineUDT"

    def serialize(self, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        if isinstance(obj, PolyLine):
            return obj
            x_list = []
            y_list = []
            for p in obj.points:
                x_list.append(p.x)
                y_list.append(p.y)
            return 3, min(x_list), min(y_list), max(x_list), max(y_list), obj.indices, x_list, y_list
        else:
            raise TypeError("cannot serialize %r of type %r" % (obj, type(obj)))

@@ -335,6 +361,20 @@ class PolyLine(Shape):

    def __init__(self, indices = [], points = []):
        self.indices = indices
        self.xcoordinates = [p.x for p in points]
        self.ycoordinates = [p.y for p in points]
        if points:
            self.xmin = min(self.xcoordinates)
            self.ymin = min(self.ycoordinates)
            self.xmax = max(self.xcoordinates)
            self.ymax = max(self.ycoordinates)
        else:
            self.xmin = None
            self.ymin = None
            self.xmax = None
            self.ymax = None
        self.boundingBox = BoundingBox(self.xmin, self.ymin, self.xmax, self.ymax)
        self.size = len(points)
        self.points = points

    def __str__(self):
@@ -385,3 +425,16 @@ def _inbound_shape_converter(json_string):
def _create_row_inbound_converter(dataType):
    return lambda *a: dataType.fromInternal(a)

class BoundingBox(object):

    def __init__(self, xmin, ymin, xmax, ymax):
        self.xmin = xmin
        self.ymin = ymin
        self.xmax = xmax
        self.ymax = ymax

    def intersects(self, other):
        # Two axis-aligned boxes overlap iff they overlap on both the x and y axes.
        if other.xmin <= self.xmax and other.xmax >= self.xmin and \
                other.ymin <= self.ymax and other.ymax >= self.ymin:
            return True
        else:
            return False
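
A quick sanity check of the overlap test, assuming the standard axis-aligned overlap semantics; the coordinates are illustrative:

a = BoundingBox(0.0, 0.0, 2.0, 2.0)
b = BoundingBox(1.0, 1.0, 3.0, 3.0)   # overlaps a
c = BoundingBox(5.0, 5.0, 6.0, 6.0)   # disjoint from a
assert a.intersects(b) and not a.intersects(c)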
10 changes: 10 additions & 0 deletions python/magellan/utils.py
@@ -0,0 +1,10 @@
from spylon.spark.utils import SparkJVMHelpers


def inject_rules(spark_session):
    from magellan.column import *
    from magellan.dataframe import *

    jvm_helpers = SparkJVMHelpers(spark_session._sc)
    magellan_utils = jvm_helpers.import_scala_object('magellan.Utils')
    magellan_utils.injectRules(spark_session._jsparkSession)
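
A short sketch of the intended bootstrap, assuming the Magellan JAR and spylon are available to the driver; the app name is illustrative:

from pyspark.sql import SparkSession
from magellan.utils import inject_rules

spark = SparkSession.builder.appName("magellan-demo").getOrCreate()
inject_rules(spark)  # registers Magellan's optimizer rules and patches Column/DataFrame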
3 changes: 3 additions & 0 deletions python/requirements.txt
@@ -0,0 +1,3 @@
spylon
pyspark
pytest
21 changes: 21 additions & 0 deletions python/setup.py
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""The setup script."""
from setuptools import setup, find_packages

requirements = [
    # TODO: put package requirements here
]
setup(
    name='magellan',
    version='1.0.5',
    description="Magellan",
    long_description="Magellan",
    author="harsha2010",
    url='https://github.com/harsha2010/magellan',
    packages=[package for package in find_packages() if package.startswith('magellan')],
    include_package_data=True,
    install_requires=requirements,
    zip_safe=False,
    keywords='magellan',
)
Empty file added python/tests/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions python/tests/conftest.py
@@ -0,0 +1,59 @@
from pyspark import SparkConf, SparkContext, SQLContext, HiveContext
from pyspark.sql import DataFrame
import pytest
from base.spark.extensions import *

import os

pending = pytest.mark.xfail

ROOT = os.path.abspath(os.path.join(__file__, '../..'))

@pytest.fixture(scope="session")
def sparkContext():
    conf = SparkConf() \
        .setAppName('py.test')

    sc = SparkContext(conf=conf)

    # disable logging
    sc.setLogLevel("OFF")

    return sc


@pytest.fixture(scope="session")
def sqlContext(sparkContext):
    return SQLContext(sparkContext)


@pytest.fixture(scope="session")
def hiveContext(sparkContext):
    return HiveContext(sparkContext)


def dfassert(left, right, useSet=False, skipExtraColumns=False):
    if not isinstance(right, DataFrame):
        right = left.sql_ctx.createDataFrame(right)

    if skipExtraColumns:
        columns = list(set(left.columns) & set(right.columns))
        left = left[columns]
        right = right[columns]

    assert sorted(left.columns) == sorted(right.columns)

    def _orderableColumns(df):
        # Column objects carry no dataType in PySpark, so look the type up via the schema.
        return [col for col in df.columns if df.schema[col].dataType.typeName() != 'array']

    left = left[sorted(left.columns)]
    right = right[sorted(right.columns)]

    converter = set if useSet else list

    orderedLeft = left.orderBy(*_orderableColumns(left)) if _orderableColumns(left) else left
    orderedRight = right.orderBy(*_orderableColumns(right)) if _orderableColumns(right) else right

    assert converter(orderedLeft.collect()) == converter(orderedRight.collect())
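
A minimal sketch of a test that uses the sqlContext fixture together with dfassert; the test name and data are illustrative:

from pyspark.sql import Row

def test_dfassert_roundtrip(sqlContext):
    df = sqlContext.createDataFrame([Row(id=1, label="a"), Row(id=2, label="b")])
    # Passing a plain list of Rows: dfassert converts it to a DataFrame before comparing.
    dfassert(df, [Row(id=1, label="a"), Row(id=2, label="b")])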