diff --git a/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/BigQueryClient.scala b/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/BigQueryClient.scala index 073596b477..f76c79ead4 100644 --- a/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/BigQueryClient.scala +++ b/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/BigQueryClient.scala @@ -72,8 +72,16 @@ object BigQueryUtil { } +/** A query job that may delay execution. */ +trait QueryJob { + def waitForResult(): Unit + val jobReference: Option[JobReference] + val query: String + val table: TableReference +} + /** A simple BigQuery client. */ -class BigQueryClient private (private val projectId: String, auth: Option[Either[Credential, String]]) { +class BigQueryClient private (private val projectId: String, auth: Option[Either[Credential, String]]) { self => private val SCOPES = List(BigqueryScopes.BIGQUERY).asJava @@ -126,9 +134,9 @@ class BigQueryClient private (private val projectId: String, auth: Option[Either /** Get rows from a query. */ def getQueryRows(sqlQuery: String): Iterator[TableRow] = { - val (tableRef, jobRef) = queryIntoTable(sqlQuery) - jobRef.foreach(j => waitForJobs(j)) - getTableRows(tableRef) + val queryJob = queryIntoTable(sqlQuery) + queryJob.waitForResult() + getTableRows(queryJob.table) } /** Get rows from a table. */ @@ -169,30 +177,34 @@ class BigQueryClient private (private val projectId: String, auth: Option[Either } /** Execute a query and save results into a temporary table. */ - def queryIntoTable(sqlQuery: String): (TableReference, Option[JobReference]) = { - prepareStagingDataset() - - logger.info(s"Executing BigQuery for query: $sqlQuery") - + def queryIntoTable(sqlQuery: String): QueryJob = { try { val sourceTimes = BigQueryUtil.extractTables(sqlQuery).map(t => BigInt(getTable(t).getLastModifiedTime)) - val table = getCacheDestinationTable(sqlQuery).get - val time = BigInt(getTable(table).getLastModifiedTime) + val temp = getCacheDestinationTable(sqlQuery).get + val time = BigInt(getTable(temp).getLastModifiedTime) if (sourceTimes.forall(_ < time)) { - logger.info(s"Cache hit, existing destination table: ${BigQueryIO.toTableSpec(table)}") - (table, None) + logger.info(s"Cache hit for query: $sqlQuery") + logger.info(s"Existing destination table: ${BigQueryIO.toTableSpec(temp)}") + new QueryJob { + override def waitForResult(): Unit = {} + override val jobReference: Option[JobReference] = None + override val query: String = sqlQuery + override val table: TableReference = temp + } } else { val temp = temporaryTable(TABLE_PREFIX) - logger.info(s"Cache invalid, new destination table: ${BigQueryIO.toTableSpec(temp)}") + logger.info(s"Cache invalid for query: $sqlQuery") + logger.info(s"New destination table: ${BigQueryIO.toTableSpec(temp)}") setCacheDestinationTable(sqlQuery, temp) - (temp, Some(makeQuery(sqlQuery, temp))) + makeBigQueryJob(sqlQuery, temp) } } catch { case _: Throwable => val temp = temporaryTable(TABLE_PREFIX) - logger.info(s"Cache miss, new destination table: ${BigQueryIO.toTableSpec(temp)}") + logger.info(s"Cache miss for query: $sqlQuery") + logger.info(s"New destination table: ${BigQueryIO.toTableSpec(temp)}") setCacheDestinationTable(sqlQuery, temp) - (temp, Some(makeQuery(sqlQuery, temp))) + makeBigQueryJob(sqlQuery, temp) } } @@ -212,23 +224,30 @@ class BigQueryClient private (private val projectId: String, auth: Option[Either writeTableRows(BigQueryIO.parseTableSpec(tableSpec), rows, schema, writeDisposition, createDisposition) /** Wait for all jobs to finish. */ - def waitForJobs(jobReferences: JobReference*): Unit = { - val ids = jobReferences.map(_.getJobId).toBuffer - var allDone: Boolean = false - while (!allDone && ids.nonEmpty) { - val pollJobs = ids.map(bigquery.jobs().get(projectId, _).execute()) - pollJobs.foreach { j => - val error = j.getStatus.getErrorResult + def waitForJobs(jobs: QueryJob*): Unit = { + val numTotal = jobs.size + var pendingJobs = jobs.filter(_.jobReference.isDefined) + + while (pendingJobs.nonEmpty) { + val remainingJobs = pendingJobs.filter { j => + val jobId = j.jobReference.get.getJobId + val poll = bigquery.jobs().get(projectId, jobId).execute() + val error = poll.getStatus.getErrorResult if (error != null) { - throw new RuntimeException(s"BigQuery failed: $error") + throw new RuntimeException(s"Query job failed: id: $jobId, error: $error") + } + if (poll.getStatus.getState == "DONE") { + logJobStatistics(j.query, poll) + false + } else { + true } } - val done = pollJobs.count(_.getStatus.getState == "DONE") - logger.info(s"BigQuery jobs: $done out of ${pollJobs.size}") - allDone = done == pollJobs.size - if (allDone) { - pollJobs.foreach(logJobStatistics) - } else { + + pendingJobs = remainingJobs + val numDone = numTotal - pendingJobs.size + logger.info(s"Query: $numDone out of $numTotal completed") + if (pendingJobs.nonEmpty) { Thread.sleep(10000) } } @@ -270,34 +289,42 @@ class BigQueryClient private (private val projectId: String, auth: Option[Either new JobReference().setProjectId(projectId).setJobId(fullJobId) } - private def makeQuery(sqlQuery: String, destinationTable: TableReference): JobReference = { - val queryConfig: JobConfigurationQuery = new JobConfigurationQuery() - .setQuery(sqlQuery) - .setAllowLargeResults(true) - .setFlattenResults(false) - .setPriority(PRIORITY) - .setCreateDisposition("CREATE_IF_NEEDED") - .setWriteDisposition("WRITE_EMPTY") - .setDestinationTable(destinationTable) - - val jobConfig = new JobConfiguration().setQuery(queryConfig) - val jobReference = createJobReference(projectId, JOB_ID_PREFIX) - val job = new Job().setConfiguration(jobConfig).setJobReference(jobReference) - bigquery.jobs().insert(projectId, job).execute().getJobReference + private def makeBigQueryJob(sqlQuery: String, destinationTable: TableReference): QueryJob = new QueryJob { + override def waitForResult(): Unit = self.waitForJobs(this) + override lazy val jobReference: Option[JobReference] = { + logger.info(s"Executing query: $sqlQuery") + val queryConfig: JobConfigurationQuery = new JobConfigurationQuery() + .setQuery(sqlQuery) + .setAllowLargeResults(true) + .setFlattenResults(false) + .setPriority(PRIORITY) + .setCreateDisposition("CREATE_IF_NEEDED") + .setWriteDisposition("WRITE_EMPTY") + .setDestinationTable(destinationTable) + + val jobConfig = new JobConfiguration().setQuery(queryConfig) + val jobReference = createJobReference(projectId, JOB_ID_PREFIX) + val job = new Job().setConfiguration(jobConfig).setJobReference(jobReference) + Some(bigquery.jobs().insert(projectId, job).execute().getJobReference) + } + override val query: String = sqlQuery + override val table: TableReference = destinationTable } - private def logJobStatistics(job: Job): Unit = { + private def logJobStatistics(sqlQuery: String, job: Job): Unit = { val jobId = job.getJobReference.getJobId val stats = job.getStatistics + logger.info(s"Query completed: jobId: $jobId") + logger.info(s"Query: $sqlQuery") val elapsed = PERIOD_FORMATTER.print(new Period(stats.getEndTime - stats.getCreationTime)) val pending = PERIOD_FORMATTER.print(new Period(stats.getStartTime - stats.getCreationTime)) val execution = PERIOD_FORMATTER.print(new Period(stats.getEndTime - stats.getStartTime)) - logger.info(s"Job $jobId: elapsed: $elapsed, pending: $pending, execution: $execution") + logger.info(s"Elapsed: $elapsed, pending: $pending, execution: $execution") val bytes = FileUtils.byteCountToDisplaySize(stats.getQuery.getTotalBytesProcessed) val cacheHit = stats.getQuery.getCacheHit - logger.info(s"Job $jobId: total bytes processed: $bytes, cache hit: $cacheHit") + logger.info(s"Total bytes processed: $bytes, cache hit: $cacheHit") } private def getTable(table: TableReference): Table = { diff --git a/scio-core/src/main/scala/com/spotify/scio/ScioContext.scala b/scio-core/src/main/scala/com/spotify/scio/ScioContext.scala index abd68f2ea4..b8df4891cb 100644 --- a/scio-core/src/main/scala/com/spotify/scio/ScioContext.scala +++ b/scio-core/src/main/scala/com/spotify/scio/ScioContext.scala @@ -142,7 +142,7 @@ class ScioContext private[scio] (val options: DataflowPipelineOptions, private v private var _pipeline: Pipeline = null private var _isClosed: Boolean = false private val _promises: MBuffer[(Promise[Tap[_]], Tap[_])] = MBuffer.empty - private val _bigQueryJobs: MBuffer[JobReference] = MBuffer.empty + private val _queryJobs: MBuffer[QueryJob] = MBuffer.empty private val _accumulators: MSet[String] = MSet.empty /** Wrap a [[com.google.cloud.dataflow.sdk.values.PCollection PCollection]]. */ @@ -210,8 +210,8 @@ class ScioContext private[scio] (val options: DataflowPipelineOptions, private v /** Close the context. No operation can be performed once the context is closed. */ def close(): ScioResult = { - if (_bigQueryJobs.nonEmpty) { - bigQueryClient.waitForJobs(_bigQueryJobs: _*) + if (_queryJobs.nonEmpty) { + bigQueryClient.waitForJobs(_queryJobs: _*) } _isClosed = true @@ -334,9 +334,9 @@ class ScioContext private[scio] (val options: DataflowPipelineOptions, private v if (this.isTest) { this.getTestInput(BigQueryIO(sqlQuery)) } else { - val (tableRef, jobRef) = this.bigQueryClient.queryIntoTable(sqlQuery) - jobRef.foreach(j => _bigQueryJobs.append(j)) - wrap(this.applyInternal(GBigQueryIO.Read.from(tableRef).withoutValidation())).setName(sqlQuery) + val queryJob = this.bigQueryClient.queryIntoTable(sqlQuery) + _queryJobs.append(queryJob) + wrap(this.applyInternal(GBigQueryIO.Read.from(queryJob.table).withoutValidation())).setName(sqlQuery) } } diff --git a/scio-core/src/main/scala/com/spotify/scio/io/Taps.scala b/scio-core/src/main/scala/com/spotify/scio/io/Taps.scala index eae5c9893c..e8fdd85757 100644 --- a/scio-core/src/main/scala/com/spotify/scio/io/Taps.scala +++ b/scio-core/src/main/scala/com/spotify/scio/io/Taps.scala @@ -50,9 +50,9 @@ trait Taps { private def bigQueryTap(sqlQuery: String): BigQueryTap = { val bq = BigQueryClient.defaultInstance() - val (tableRef, jobRef) = bq.queryIntoTable(sqlQuery) - jobRef.foreach(j => bq.waitForJobs(j)) - BigQueryTap(tableRef) + val queryJob = bq.queryIntoTable(sqlQuery) + queryJob.waitForResult() + BigQueryTap(queryJob.table) } /**