diff --git a/src/main/scala/com/mozilla/telemetry/streaming/EventsToAmplitude.scala b/src/main/scala/com/mozilla/telemetry/streaming/EventsToAmplitude.scala
index b1aaf149..9c49a786 100644
--- a/src/main/scala/com/mozilla/telemetry/streaming/EventsToAmplitude.scala
+++ b/src/main/scala/com/mozilla/telemetry/streaming/EventsToAmplitude.scala
@@ -99,6 +99,14 @@ object EventsToAmplitude extends StreamingJobBase {
       descr = "In batch mode, pings will be packed into maxParallelRequests * multiplier partitions",
       required = false,
       default = Some(1))
+    val dataSource: ScallopOption[String] = opt[String](
+      descr = "Data source for batch mode: `heka` or `bigquery`",
+      required = false,
+      default = Some("heka"))
+    val bqSourceTable: ScallopOption[String] = opt[String](
+      descr = "Source table, used when dataSource=='bigquery'",
+      required = false,
+      default = None)
 
     conflicts(kafkaBroker, List(from, to, fileLimit, minDelay, maxParallelRequests))
 
     validateOpt (sample) {
@@ -235,7 +243,6 @@ object EventsToAmplitude extends StreamingJobBase {
   }
 
   def sendBatchEvents(spark: SparkSession, opts: Opts): Unit = {
-    val config = readConfigFile(opts.configFilePath())
     val maxParallelRequests = opts.maxParallelRequests()
     val partitionMultiplier = opts.partitionMultiplier()
 
@@ -247,39 +254,74 @@ object EventsToAmplitude extends StreamingJobBase {
     val httpSinkConfig = AmplitudeHttpSink.Config.withMetrics(spark)
     val httpSink = AmplitudeHttpSink(apiKey = apiKey, url = opts.url(), httpSinkConfig)
 
+    if (opts.dataSource.get.contains("bigquery")) {
+      // ignore date, read from filtered table populated in Airflow DAG
+
+      val rawEvents = spark.read.format("bigquery")
+        .option("table", opts.bqSourceTable.get.get)
+        .load()
+
+      val events_json = rawEvents
+        .withColumn("event_json_escaped",
+          f.to_json(f.struct(
+            f.col("device_id"), f.col("session_id"), f.col("insert_id"),
+            f.col("event_type"), f.col("time"), f.col("event_properties"),
+            f.col("user_properties"), f.col("app_version"), f.col("os_name"),
+            f.col("os_version"), f.col("country"), f.col("city")
+          )))
+        .withColumn("event_json",
+          f.regexp_replace(f.regexp_replace(
+            f.regexp_replace(f.col("event_json_escaped"), "\\\\\"", "\""),
+            "\"\\{", "{"), "\\}\"", "}"))
+        .select("device_id", "event_json")
+
+      log.info("Sending to Amplitude...")
+      import spark.implicits._
+      events_json.repartition(maxParallelRequests, f.col("device_id")) // Bug 1484819
+        .select(f.array(f.col("event_json")))
+        .as[Seq[String]]
+        .foreachPartition { it =>
+          httpSink.batchAndProcess(it)
+          java.lang.Thread.sleep(minDelay)
+        }
+      log.info("Done!")
+    } else {
+      val config = readConfigFile(opts.configFilePath())
 
-    datesBetween(opts.from(), opts.to.get).foreach { currentDate =>
-      val dataset = com.mozilla.telemetry.heka.Dataset(config.source)
-      val topLevelFields = TOP_LEVEL_PING_FIELDS(config.source)
+      datesBetween(opts.from(), opts.to.get).foreach { currentDate =>
 
-      val pings = config.getBatchFilters.filter{
-        case(name, _) => topLevelFields.contains(name)
-      }.foldLeft(dataset){
-        case(d, (key, values)) => d.where(key) {
+        val dataset = com.mozilla.telemetry.heka.Dataset(config.source)
+        val topLevelFields = TOP_LEVEL_PING_FIELDS(config.source)
+
+        val pings = config.getBatchFilters.filter {
+          case (name, _) => topLevelFields.contains(name)
+        }.foldLeft(dataset) {
+          case (d, (key, values)) => d.where(key) {
           case v if values.contains(v) => true
         }
       }.where("submissionDate") {
         case date if date == currentDate => true
       }.records(opts.fileLimit.get, Some(maxParallelRequests * partitionMultiplier))
-        .map(m => Row(m.toByteArray))
+          .map(m => Row(m.toByteArray))
 
-      val schema = StructType(List(
+        val schema = StructType(List(
         StructField("value", BinaryType, true)
-      ))
+        ))
 
-      val pingsDataFrame = spark.createDataFrame(pings, schema)
+        val pingsDataFrame = spark.createDataFrame(pings, schema)
 
-      log.info(s"Processing events for ${pingsDataFrame.count()} pings on $currentDate")
+        log.info(s"Processing events for ${pingsDataFrame.count()} pings on $currentDate")
 
-      import spark.implicits._
+        import spark.implicits._
 
-      getEvents(config, pingsDataFrame, opts.sample(), opts.raiseOnError())
-        .repartition(maxParallelRequests, f.col("clientId")) // Bug 1484819
-        .map(_.events)
-        .foreachPartition { it =>
-          httpSink.batchAndProcess(it)
-          java.lang.Thread.sleep(minDelay)
-        }
+        getEvents(config, pingsDataFrame, opts.sample(), opts.raiseOnError())
+          .repartition(maxParallelRequests, f.col("clientId")) // Bug 1484819
+          .map(_.events)
+          .foreachPartition { it =>
+            httpSink.batchAndProcess(it)
+            java.lang.Thread.sleep(minDelay)
+          }
+      }
     }
 
     spark.stop()
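
A minimal, self-contained sketch of the JSON un-escaping step in the bigquery branch above, runnable in a spark-shell (the sample row and the two-column subset are hypothetical). It assumes columns such as event_properties and user_properties arrive from the source table as pre-serialized JSON strings, so f.to_json quotes and escapes them; the three regexp_replace calls then strip that escaping so Amplitude receives nested objects rather than quoted strings:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.{functions => f}

    val spark = SparkSession.builder.master("local[*]").getOrCreate()
    import spark.implicits._

    // Hypothetical input row: event_properties is already a JSON string.
    val df = Seq(("client-1", """{"branch":"control"}"""))
      .toDF("device_id", "event_properties")

    val escaped = df.withColumn("event_json_escaped",
      f.to_json(f.struct(f.col("device_id"), f.col("event_properties"))))
    // {"device_id":"client-1","event_properties":"{\"branch\":\"control\"}"}

    val unescaped = escaped.withColumn("event_json",
      f.regexp_replace(f.regexp_replace(
        f.regexp_replace(f.col("event_json_escaped"), "\\\\\"", "\""), // \" -> "
        "\"\\{", "{"),                                                 // "{ -> {
        "\\}\"", "}"))                                                 // }" -> }
    // {"device_id":"client-1","event_properties":{"branch":"control"}}

Note that this string surgery also rewrites any literal \" , "{ or }" sequences occurring inside property values; presumably that is acceptable for the pre-filtered table this job reads.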
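Two details of the new send path mirror the existing heka branch: the repartition on device_id (the analogue of clientId, per the Bug 1484819 comment) keeps each device's events in a single partition, and Thread.sleep(minDelay) throttles successive batches within a partition. The .select(f.array(f.col("event_json"))).as[Seq[String]] step wraps each row into the one-element Seq[String] shape that httpSink.batchAndProcess already consumes via .map(_.events) on the heka side. One small caveat worth flagging: opts.bqSourceTable.get.get will throw NoSuchElementException when bqSourceTable is not supplied (its default is None) while dataSource is bigquery, since no validateOpt ties the two options together.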