diff --git a/edx/analytics/tasks/common/spark.py b/edx/analytics/tasks/common/spark.py
index 640c1b6aff..fb3bc17310 100644
--- a/edx/analytics/tasks/common/spark.py
+++ b/edx/analytics/tasks/common/spark.py
@@ -14,6 +14,7 @@
     ManifestInputTargetMixin, convert_to_manifest_input_if_necessary, remove_manifest_target_if_exists
 )
 from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin
+from edx.analytics.tasks.util.spark_util import load_and_filter
 from edx.analytics.tasks.util.url import UncheckedExternalURL, get_target_from_url, url_path_join
 
 _file_path_to_package_meta_path = {}
@@ -201,6 +202,14 @@ class EventLogSelectionMixinSpark(EventLogSelectionDownstreamMixin):
         description='Whether or not to process event log source directly with spark',
         default=False
     )
+    cache_rdd = luigi.BoolParameter(
+        description='Whether or not to cache the filtered event RDD',
+        default=False
+    )
+    rdd_checkpoint_directory = luigi.Parameter(
+        description='Path to the directory where the RDD can be checkpointed',
+        default=None
+    )
 
     def __init__(self, *args, **kwargs):
         """
@@ -275,6 +284,37 @@ def get_event_log_dataframe(self, spark, *args, **kwargs):
         )
         return dataframe
 
+    def get_user_location_schema(self):
+        from pyspark.sql.types import StructType, StringType, IntegerType
+        schema = StructType().add("user_id", IntegerType(), True) \
+            .add("course_id", StringType(), True) \
+            .add("ip", StringType(), True) \
+            .add("timestamp", StringType(), True) \
+            .add("event_date", StringType(), True)
+        return schema
+
+    def get_dataframe(self, spark, *args, **kwargs):
+        input_source = self.get_input_source(*args)
+        user_location_schema = self.get_user_location_schema()
+        master_rdd = spark.sparkContext.union(
+            # filter out unwanted data as much as possible within each rdd before union
+            map(
+                lambda target: load_and_filter(spark, target.path, self.lower_bound_date_string,
+                                               self.upper_bound_date_string),
+                input_source
+            )
+        )
+        if self.rdd_checkpoint_directory:
+            # set the checkpoint location before checkpointing
+            spark.sparkContext.setCheckpointDir(self.rdd_checkpoint_directory)
+            master_rdd.checkpoint()
+        if self.cache_rdd:
+            master_rdd.cache()
+        dataframe = spark.createDataFrame(master_rdd, schema=user_location_schema)
+        if 'user_id' not in dataframe.columns:  # rename columns if they weren't named properly by createDataFrame
+            dataframe = dataframe.toDF('user_id', 'course_id', 'ip', 'timestamp', 'event_date')
+        return dataframe
+
 
 class SparkJobTask(SparkMixin, OverwriteOutputMixin, EventLogSelectionDownstreamMixin, PySparkTask):
     """
diff --git a/edx/analytics/tasks/insights/location_per_course.py b/edx/analytics/tasks/insights/location_per_course.py
index 5f767256cf..6fc4865a75 100644
--- a/edx/analytics/tasks/insights/location_per_course.py
+++ b/edx/analytics/tasks/insights/location_per_course.py
@@ -217,25 +217,23 @@ def run(self):
         super(LastDailyIpAddressOfUserTaskSpark, self).run()
 
     def spark_job(self, *args):
-        from edx.analytics.tasks.util.spark_util import get_event_predicate_labels, get_course_id, get_event_time_string
+        from edx.analytics.tasks.util.spark_util import validate_course_id
         from pyspark.sql.functions import udf
         from pyspark.sql.window import Window
         from pyspark.sql.types import StringType
 
-        df = self.get_event_log_dataframe(self._spark, *args)
-        get_event_time = udf(get_event_time_string, StringType())
-        get_courseid = udf(get_course_id, StringType())
-        df = df.withColumn('course_id', get_courseid(df['context'])) \
-            .withColumn('timestamp', get_event_time(df['time']))
+        df = self.get_dataframe(self._spark, *args)
+        validate_courseid = udf(validate_course_id, StringType())
+        df = df.withColumn('course_id', validate_courseid(df['course_id']))
 
         df.createOrReplaceTempView('location')
         query = """
         SELECT timestamp, ip, user_id, course_id, dt FROM (
             SELECT
-                event_date as dt, context.user_id as user_id, course_id, timestamp, ip,
-                ROW_NUMBER() over ( PARTITION BY event_date, context.user_id, course_id ORDER BY timestamp desc) as rank
+                event_date as dt, user_id, course_id, timestamp, ip,
+                ROW_NUMBER() over ( PARTITION BY event_date, user_id, course_id ORDER BY timestamp desc) as rank
             FROM location
-            WHERE ip <> '' AND timestamp <> '' AND context.user_id <> ''
+            WHERE ip <> '' AND timestamp <> '' AND user_id <> ''
         ) user_location
         WHERE rank = 1
         """
diff --git a/edx/analytics/tasks/util/spark_util.py b/edx/analytics/tasks/util/spark_util.py
index 8d03647e92..04915419db 100644
--- a/edx/analytics/tasks/util/spark_util.py
+++ b/edx/analytics/tasks/util/spark_util.py
@@ -1,7 +1,11 @@
 """Support for spark tasks"""
+import json
+import re
+
 import edx.analytics.tasks.util.opaque_key_util as opaque_key_util
 from edx.analytics.tasks.util.constants import PredicateLabels
 
+PATTERN_JSON = re.compile(r'^.*?(\{.*\})\s*$')
 
 def get_event_predicate_labels(event_type, event_source):
     """
@@ -53,6 +57,58 @@ def get_event_time_string(event_time):
         return ''
 
 
+def filter_event_logs(row, lower_bound_date_string, upper_bound_date_string):
+    """Return a (user_id, course_id, ip, timestamp, event_date) tuple for an event, or () if it should be dropped."""
+    if row is None:
+        return ()
+    context = row.get('context', '')
+    raw_time = row.get('time', '')
+    if not context or not raw_time:
+        return ()
+    course_id = context.get('course_id', '').encode('utf-8')
+    user_id = context.get('user_id', None)
+    time = get_event_time_string(raw_time).encode('utf-8')
+    ip = row.get('ip', '').encode('utf-8')
+    if not user_id or not time:
+        return ()
+    date_string = raw_time.split("T")[0].encode('utf-8')
+    if date_string < lower_bound_date_string or date_string >= upper_bound_date_string:
+        return ()  # discard events outside the date interval
+    return (user_id, course_id, ip, time, date_string)
+
+
+def parse_json_event(line, nested=False):
+    """
+    Parse a tracking log input line as JSON to create a dict representation.
+    """
+    try:
+        parsed = json.loads(line)
+    except Exception:
+        if not nested:
+            json_match = PATTERN_JSON.match(line)
+            if json_match:
+                return parse_json_event(json_match.group(1), nested=True)
+        return None
+    return parsed
+
+
+def load_and_filter(spark_session, file, lower_bound_date_string, upper_bound_date_string):
+    """Load a tracking log file into an RDD of (user_id, course_id, ip, timestamp, event_date) tuples."""
+    return spark_session.sparkContext.textFile(file) \
+        .map(parse_json_event) \
+        .map(lambda row: filter_event_logs(row, lower_bound_date_string, upper_bound_date_string)) \
+        .filter(bool)
+
+
+def validate_course_id(course_id):
+    """Return the normalized course_id if it is a valid course key, otherwise an empty string."""
+    course_id = opaque_key_util.normalize_course_id(course_id)
+    if course_id:
+        if opaque_key_util.is_valid_course_id(course_id):
+            return course_id
+    return ''
+
+
 def get_course_id(event_context, from_url=False):
     """
     Gets course_id from event's data.
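
As a usage illustration, here is a minimal, self-contained sketch (not part of the diff) of how the `PATTERN_JSON` / `parse_json_event` fallback added in `spark_util.py` is intended to behave: a plain JSON tracking line parses directly, a line with a non-JSON prefix is retried on its trailing `{...}` payload, and anything unparseable becomes `None`, which the `.filter(bool)` step in `load_and_filter` later drops. The helper bodies are mirrored here so the snippet runs without the edx-analytics-pipeline package installed; the sample log lines are invented.

```python
import json
import re

# Mirrors PATTERN_JSON and parse_json_event() from edx/analytics/tasks/util/spark_util.py above.
PATTERN_JSON = re.compile(r'^.*?(\{.*\})\s*$')


def parse_json_event(line, nested=False):
    """Parse a tracking log line as JSON, retrying on the trailing {...} payload if needed."""
    try:
        return json.loads(line)
    except Exception:
        if not nested:
            match = PATTERN_JSON.match(line)
            if match:
                return parse_json_event(match.group(1), nested=True)
        return None


payload = '{"event_type": "play_video", "time": "2018-04-01T08:00:00.123456+00:00", "context": {"user_id": 42}}'
prefixed = 'Apr  1 08:00:01 edxapp tracking: ' + payload  # e.g. a syslog-style prefix

print(parse_json_event(payload)['event_type'])           # play_video -- plain JSON parses directly
print(parse_json_event(prefixed)['context']['user_id'])  # 42 -- the regex recovers the trailing {...}
print(parse_json_event('not json at all'))               # None -- later dropped by .filter(bool)
```

The same `filter(bool)` idiom also discards the empty tuples that `filter_event_logs` returns for events outside the date interval or missing `user_id`/`time`, so only complete (user_id, course_id, ip, timestamp, event_date) rows reach `createDataFrame`.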