
Commit

with rdd
rao-abdul-mannan committed Jun 12, 2018
1 parent 2b7b94b commit 86fd011
Showing 3 changed files with 99 additions and 9 deletions.
40 changes: 40 additions & 0 deletions edx/analytics/tasks/common/spark.py
@@ -14,6 +14,7 @@
ManifestInputTargetMixin, convert_to_manifest_input_if_necessary, remove_manifest_target_if_exists
)
from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin
from edx.analytics.tasks.util.spark_util import load_and_filter
from edx.analytics.tasks.util.url import UncheckedExternalURL, get_target_from_url, url_path_join

_file_path_to_package_meta_path = {}
@@ -201,6 +202,14 @@ class EventLogSelectionMixinSpark(EventLogSelectionDownstreamMixin):
description='Whether or not to process event log source directly with spark',
default=False
)
    cache_rdd = luigi.BoolParameter(
        description='Whether or not to cache the master RDD before it is converted to a dataframe.',
        default=False
    )
    rdd_checkpoint_directory = luigi.Parameter(
        description='Path to the directory where the RDD should be checkpointed.',
        default=None
    )
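    # Usage note: luigi surfaces these parameters on the command line as --cache-rdd and
    # --rdd-checkpoint-directory (e.g. --rdd-checkpoint-directory s3://example-bucket/checkpoints/,
    # where the S3 path is purely illustrative).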

def __init__(self, *args, **kwargs):
"""
@@ -275,6 +284,37 @@ def get_event_log_dataframe(self, spark, *args, **kwargs):
)
return dataframe

    def get_user_location_schema(self):
        """Return the (user_id, course_id, ip, timestamp, event_date) schema for the user-location dataframe."""
        from pyspark.sql.types import IntegerType, StringType, StructType
        schema = StructType() \
            .add("user_id", IntegerType(), True) \
            .add("course_id", StringType(), True) \
            .add("ip", StringType(), True) \
            .add("timestamp", StringType(), True) \
            .add("event_date", StringType(), True)
        return schema

    def get_dataframe(self, spark, *args, **kwargs):
        input_source = self.get_input_source(*args)
        user_location_schema = self.get_user_location_schema()
        master_rdd = spark.sparkContext.union(
            # filter out unwanted data as much as possible within each rdd before union
            map(
                lambda target: load_and_filter(spark, target.path, self.lower_bound_date_string,
                                               self.upper_bound_date_string),
                input_source
            )
        )
        if self.rdd_checkpoint_directory:
            # set the checkpoint location, then checkpoint to truncate the rdd's lineage
            # (localCheckpoint() would ignore the configured directory)
            spark.sparkContext.setCheckpointDir(self.rdd_checkpoint_directory)
            master_rdd.checkpoint()
        if self.cache_rdd:
            master_rdd.cache()
        dataframe = spark.createDataFrame(master_rdd, schema=user_location_schema)
        if 'user_id' not in dataframe.columns:
            # rename columns if they weren't named properly by createDataFrame
            dataframe = dataframe.toDF('user_id', 'course_id', 'ip', 'timestamp', 'event_date')
        return dataframe


class SparkJobTask(SparkMixin, OverwriteOutputMixin, EventLogSelectionDownstreamMixin, PySparkTask):
"""
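For reference, here is a minimal standalone sketch (the SparkSession setup and the S3 paths are illustrative, not taken from this repository) of the persistence behaviour that cache_rdd and rdd_checkpoint_directory toggle in get_dataframe:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('rdd-persistence-sketch').getOrCreate()
rdd = spark.sparkContext.textFile('s3://example-bucket/tracking-logs/*.gz')  # illustrative path

# checkpoint() writes the RDD to the configured directory and truncates its lineage;
# it requires setCheckpointDir() to have been called first.
spark.sparkContext.setCheckpointDir('s3://example-bucket/checkpoints/')  # illustrative path
rdd.checkpoint()

# cache() keeps computed partitions in executor memory for reuse; it does not truncate lineage.
rdd.cache()

# Both calls are lazy and only take effect once an action materializes the RDD.
print(rdd.count())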
16 changes: 7 additions & 9 deletions edx/analytics/tasks/insights/location_per_course.py
@@ -217,25 +217,23 @@ def run(self):
super(LastDailyIpAddressOfUserTaskSpark, self).run()

    def spark_job(self, *args):
-        from edx.analytics.tasks.util.spark_util import get_event_predicate_labels, get_course_id, get_event_time_string
+        from edx.analytics.tasks.util.spark_util import validate_course_id
        from pyspark.sql.functions import udf
        from pyspark.sql.window import Window
        from pyspark.sql.types import StringType
-        df = self.get_event_log_dataframe(self._spark, *args)
-        get_event_time = udf(get_event_time_string, StringType())
-        get_courseid = udf(get_course_id, StringType())
-        df = df.withColumn('course_id', get_courseid(df['context'])) \
-            .withColumn('timestamp', get_event_time(df['time']))
+        df = self.get_dataframe(self._spark, *args)
+        validate_courseid = udf(validate_course_id, StringType())
+        df = df.withColumn('course_id', validate_courseid(df['course_id']))
        df.createOrReplaceTempView('location')
        query = """
            SELECT
                timestamp, ip, user_id, course_id, dt
            FROM (
                SELECT
-                    event_date as dt, context.user_id as user_id, course_id, timestamp, ip,
-                    ROW_NUMBER() over ( PARTITION BY event_date, context.user_id, course_id ORDER BY timestamp desc) as rank
+                    event_date as dt, user_id, course_id, timestamp, ip,
+                    ROW_NUMBER() over ( PARTITION BY event_date, user_id, course_id ORDER BY timestamp desc) as rank
                FROM location
-                WHERE ip <> '' AND timestamp <> '' AND context.user_id <> ''
+                WHERE ip <> '' AND timestamp <> '' AND user_id <> ''
            ) user_location
            WHERE rank = 1
        """
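The window query above keeps, for each (event_date, user_id, course_id) group, the row with the latest timestamp. An equivalent DataFrame-API formulation, shown only as a sketch against the df produced by get_dataframe, would be:

from pyspark.sql import Window
from pyspark.sql.functions import col, desc, row_number

w = Window.partitionBy('event_date', 'user_id', 'course_id').orderBy(desc('timestamp'))
last_ip = (
    df.filter((col('ip') != '') & (col('timestamp') != '') & (col('user_id').isNotNull()))
      .withColumn('rank', row_number().over(w))
      .where(col('rank') == 1)
      .select('timestamp', 'ip', 'user_id', 'course_id', col('event_date').alias('dt'))
)

Because user_id is an IntegerType column in the schema above, the sketch filters on isNotNull() rather than comparing it with an empty string.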
52 changes: 52 additions & 0 deletions edx/analytics/tasks/util/spark_util.py
@@ -1,7 +1,11 @@
"""Support for spark tasks"""
import json
import re

import edx.analytics.tasks.util.opaque_key_util as opaque_key_util
from edx.analytics.tasks.util.constants import PredicateLabels

PATTERN_JSON = re.compile(r'^.*?(\{.*\})\s*$')

def get_event_predicate_labels(event_type, event_source):
"""
@@ -53,6 +57,54 @@ def get_event_time_string(event_time):
return ''


def filter_event_logs(row, lower_bound_date_string, upper_bound_date_string):
    """
    Reduce a parsed event to a (user_id, course_id, ip, timestamp, event_date) tuple.

    Returns an empty tuple for malformed events and for events outside the
    [lower_bound_date_string, upper_bound_date_string) interval, so such rows
    can be dropped with a simple .filter(bool).
    """
    if row is None:
        return ()
    context = row.get('context', '')
    raw_time = row.get('time', '')
    if not context or not raw_time:
        return ()
    course_id = context.get('course_id', '').encode('utf-8')
    user_id = context.get('user_id', None)
    time = get_event_time_string(raw_time).encode('utf-8')
    ip = row.get('ip', '').encode('utf-8')
    if not user_id or not time:
        return ()
    date_string = raw_time.split("T")[0].encode('utf-8')
    if date_string < lower_bound_date_string or date_string >= upper_bound_date_string:
        return ()  # discard events outside the date interval
    return (user_id, course_id, ip, time, date_string)


def parse_json_event(line, nested=False):
    """
    Parse a tracking log input line as JSON to create a dict representation.

    If the line does not parse as-is, the trailing JSON object is extracted with
    PATTERN_JSON and parsed once more; anything still unparseable yields None.
    """
    try:
        parsed = json.loads(line)
    except Exception:
        if not nested:
            json_match = PATTERN_JSON.match(line)
            if json_match:
                return parse_json_event(json_match.group(1), nested=True)
        return None
    return parsed


def load_and_filter(spark_session, file_path, lower_bound_date_string, upper_bound_date_string):
    """Load a tracking log file into an RDD of (user_id, course_id, ip, timestamp, event_date) tuples within the date interval."""
    return spark_session.sparkContext.textFile(file_path) \
        .map(parse_json_event) \
        .map(lambda row: filter_event_logs(row, lower_bound_date_string, upper_bound_date_string)) \
        .filter(bool)


def validate_course_id(course_id):
    """Normalize the given course_id and return it if it is a valid opaque key, otherwise return an empty string."""
    course_id = opaque_key_util.normalize_course_id(course_id)
    if course_id and opaque_key_util.is_valid_course_id(course_id):
        return course_id
    return ''

def get_course_id(event_context, from_url=False):
"""
Gets course_id from event's data.
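As an illustrative usage example of the helpers added above (the event line is made up, and the exact timestamp formatting depends on get_event_time_string):

sample_line = '{"time": "2018-06-11T12:34:56.789012+00:00", "ip": "10.0.0.1", "context": {"user_id": 42, "course_id": "course-v1:edX+DemoX+Demo_Course"}}'
event = parse_json_event(sample_line)  # -> dict for well-formed JSON, None otherwise
record = filter_event_logs(event, '2018-06-01', '2018-06-12')
# record is a (user_id, course_id, ip, timestamp, event_date) tuple because the event
# falls inside the interval; malformed or out-of-range events come back as ().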
