Skip to content

Commit

Permalink
feat(ingest/backend): add and use 'x-total-record' header for /get-or…
Browse files Browse the repository at this point in the history
…iginal-metadata endpoint (#2857)
  • Loading branch information
corneliusroemer authored Sep 23, 2024
1 parent 5975698 commit 64ec017
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,13 @@ open class SubmissionController(
@ApiResponse(
responseCode = "200",
description = GET_ORIGINAL_METADATA_RESPONSE_DESCRIPTION,
headers = [
Header(
name = "x-total-records",
description = "The total number of records sent in responseBody",
schema = Schema(type = "integer"),
),
],
)
@ApiResponse(
responseCode = "423",
Expand All @@ -369,16 +376,29 @@ open class SubmissionController(
@HiddenParam authenticatedUser: AuthenticatedUser,
@RequestParam compression: CompressionFormat?,
): ResponseEntity<StreamingResponseBody> {
val stillProcessing = submitModel.checkIfStillProcessingSubmittedData()
if (stillProcessing) {
return ResponseEntity.status(HttpStatus.LOCKED).build()
}

val headers = HttpHeaders()
headers.contentType = MediaType.parseMediaType(MediaType.APPLICATION_NDJSON_VALUE)
if (compression != null) {
headers.add(HttpHeaders.CONTENT_ENCODING, compression.compressionName)
}

val stillProcessing = submitModel.checkIfStillProcessingSubmittedData()
if (stillProcessing) {
return ResponseEntity.status(HttpStatus.LOCKED).build()
}
val totalRecords = submissionDatabaseService.countOriginalMetadata(
authenticatedUser,
organism,
groupIdsFilter?.takeIf { it.isNotEmpty() },
statusesFilter?.takeIf { it.isNotEmpty() },
)
headers.add("x-total-records", totalRecords.toString())
// TODO(https://github.com/loculus-project/loculus/issues/2778)
// There's a possibility that the totalRecords change between the count and the actual query
// this is not too bad, if the client ends up with a few more records than expected
// We just need to make sure the etag used is from before the count
// Alternatively, we could read once to file while counting and then stream the file

val streamBody = streamTransactioned(compression) {
submissionDatabaseService.streamOriginalMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -950,13 +950,12 @@ open class SubmissionDatabaseService(
)
}

fun streamOriginalMetadata(
private fun originalMetadataFilter(
authenticatedUser: AuthenticatedUser,
organism: Organism,
groupIdsFilter: List<Int>?,
statusesFilter: List<Status>?,
fields: List<String>?,
): Sequence<AccessionVersionOriginalMetadata> {
): Op<Boolean> {
val organismCondition = SequenceEntriesView.organismIs(organism)
val groupCondition = getGroupCondition(groupIdsFilter, authenticatedUser)
val statusCondition = if (statusesFilter != null) {
Expand All @@ -966,6 +965,33 @@ open class SubmissionDatabaseService(
}
val conditions = organismCondition and groupCondition and statusCondition

return conditions
}

fun countOriginalMetadata(
authenticatedUser: AuthenticatedUser,
organism: Organism,
groupIdsFilter: List<Int>?,
statusesFilter: List<Status>?,
): Long = SequenceEntriesView
.selectAll()
.where(
originalMetadataFilter(
authenticatedUser,
organism,
groupIdsFilter,
statusesFilter,
),
)
.count()

fun streamOriginalMetadata(
authenticatedUser: AuthenticatedUser,
organism: Organism,
groupIdsFilter: List<Int>?,
statusesFilter: List<Status>?,
fields: List<String>?,
): Sequence<AccessionVersionOriginalMetadata> {
val originalMetadata = SequenceEntriesView.originalDataColumn
.extract<Map<String, String>>("metadata")
.alias("original_metadata")
Expand All @@ -976,7 +1002,14 @@ open class SubmissionDatabaseService(
SequenceEntriesView.accessionColumn,
SequenceEntriesView.versionColumn,
)
.where(conditions)
.where(
originalMetadataFilter(
authenticatedUser,
organism,
groupIdsFilter,
statusesFilter,
),
)
.fetchSize(streamBatchSize)
.asSequence()
.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ class GetOriginalMetadataEndpointTest(
@Test
fun `GIVEN no sequence entries in database THEN returns empty response`() {
val response = submissionControllerClient.getOriginalMetadata()

val responseBody = response.expectNdjsonAndGetContent<MetadataMap>()

response.andExpect(status().isOk)
.andExpect(header().string("x-total-records", `is`("0")))
assertThat(responseBody, `is`(emptyList()))
}

Expand All @@ -63,6 +65,9 @@ class GetOriginalMetadataEndpointTest(
val response = submissionControllerClient.getOriginalMetadata()

val responseBody = response.expectNdjsonAndGetContent<AccessionVersionOriginalMetadata>()

response.andExpect(status().isOk)
.andExpect(header().string("x-total-records", `is`(DefaultFiles.NUMBER_OF_SEQUENCES.toString())))
assertThat(responseBody.size, `is`(DefaultFiles.NUMBER_OF_SEQUENCES))
}

Expand Down Expand Up @@ -150,6 +155,8 @@ class GetOriginalMetadataEndpointTest(
groupIdsFilter = listOf(g0),
statusesFilter = listOf(Status.APPROVED_FOR_RELEASE),
)
response.andExpect(status().isOk)
.andExpect(header().string("x-total-records", `is`(expectedAccessionVersions.count().toString())))
val responseBody = response.expectNdjsonAndGetContent<AccessionVersionOriginalMetadata>()

assertThat(responseBody, hasSize(expected.size))
Expand Down
33 changes: 21 additions & 12 deletions ingest/scripts/call_loculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,20 +312,29 @@ def get_submitted(config: Config):
"statusesFilter": [],
}

logger.info("Getting previously submitted sequences")
while True:
logger.info("Getting previously submitted sequences")

response = make_request(HTTPMethod.GET, url, config, params=params)
response = make_request(HTTPMethod.GET, url, config, params=params)
expected_record_count = int(response.headers["x-total-records"])

entries: list[dict[str, Any]] = []
try:
entries = list(jsonlines.Reader(response.iter_lines()).iter())
except jsonlines.Error as err:
response_summary = response.text
max_error_length = 100
if len(response_summary) > max_error_length:
response_summary = response_summary[:50] + "\n[..]\n" + response_summary[-50:]
logger.error(f"Error decoding JSON from /get-original-metadata: {response_summary}")
raise ValueError from err
entries: list[dict[str, Any]] = []
try:
entries = list(jsonlines.Reader(response.iter_lines()).iter())
except jsonlines.Error as err:
response_summary = response.text
max_error_length = 100
if len(response_summary) > max_error_length:
response_summary = response_summary[:50] + "\n[..]\n" + response_summary[-50:]
logger.error(f"Error decoding JSON from /get-original-metadata: {response_summary}")
raise ValueError from err

if len(entries) == expected_record_count:
f"Got {len(entries)} records as expected"
break
logger.error(f"Got incomplete original metadata stream: expected {len(entries)}"
f"records but got {expected_record_count}. Retrying after 60 seconds.")
sleep(60)

# Initialize the dictionary to store results
submitted_dict: dict[str, dict[str, str | list]] = {}
Expand Down

0 comments on commit 64ec017

Please sign in to comment.