Skip to content

Commit

Permalink
Improve study view treatment api performance
Browse files Browse the repository at this point in the history
  • Loading branch information
kalletlak committed Jul 12, 2023
1 parent 719e5d2 commit 48f4d69
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 141 deletions.
1 change: 1 addition & 0 deletions db-scripts/src/main/resources/migration.sql
Original file line number Diff line number Diff line change
Expand Up @@ -1026,4 +1026,5 @@ UPDATE `info` SET `DB_SCHEMA_VERSION`="2.13.0";
-- Redefine clinical_event_data.VALUE as varchar(3000) NOT NULL.
ALTER TABLE `clinical_event_data` MODIFY COLUMN `VALUE` varchar(3000) NOT NULL;
-- Add single-column indexes on the clinical_event_data KEY and VALUE columns.
CREATE INDEX idx_clinical_event_key ON clinical_event_data (`KEY`);
CREATE INDEX idx_clinical_event_value ON clinical_event_data (`VALUE`);
-- Add an index on sample.STABLE_ID.
-- NOTE(review): if this index was added after schema version 2.13.1 was already released,
-- databases already at 2.13.1 presumably will not re-run this block and will miss it - confirm
-- whether a new version bump is needed.
CREATE INDEX idx_sample_stable_id ON sample (`STABLE_ID`);
-- Record schema version 2.13.1 in the info table.
UPDATE `info` SET `DB_SCHEMA_VERSION`="2.13.1";
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@ public interface TreatmentRepository {
public Map<String, List<Treatment>> getTreatmentsByPatientId(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Map<String, List<ClinicalEventSample>> getSamplesByPatientId(List<String> sampleIds, List<String> studyIds);
public List<Treatment> getTreatments(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Map<String, List<ClinicalEventSample>> getShallowSamplesByPatientId(List<String> sampleIds, List<String> studyIds);
public Map<String, List<ClinicalEventSample>> getSamplesByPatientId(List<String> sampleIds, List<String> studyIds);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Set<String> getAllUniqueTreatments(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key);
public Map<String, List<ClinicalEventSample>> getShallowSamplesByPatientId(List<String> sampleIds, List<String> studyIds);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Boolean hasTreatmentData(List<String> studies, String key);
public Boolean hasTreatmentData(List<String> studies, ClinicalEventKeyCode key);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Boolean hasSampleTimelineData(List<String> studies);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Boolean studyIdHasTreatments(String studyId, ClinicalEventKeyCode key);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ public interface TreatmentMapper {

List<ClinicalEventSample> getAllShallowSamples(List<String> sampleIds, List<String> studyIds);

Set<String> getAllUniqueTreatments(List<String> sampleIds, List<String> studyIds, String key);

Boolean hasTreatmentData(List<String> sampleIds, List<String> studyIds, String key);

Boolean hasSampleTimelineData(List<String> sampleIds, List<String> studyIds);

Boolean studyIdHasTreatments(String studyId, String key);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@ public class TreatmentMyBatisRepository implements TreatmentRepository {

/**
 * Fetches the treatments for the given samples, studies and key via
 * {@code getTreatments} and groups them by patient ID.
 *
 * @param sampleIds sample identifiers forwarded to {@code getTreatments}
 * @param studyIds  study identifiers forwarded to {@code getTreatments}
 * @param key       clinical event key code forwarded to {@code getTreatments}
 * @return a map from patient ID to the treatments recorded for that patient
 */
@Override
public Map<String, List<Treatment>> getTreatmentsByPatientId(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key) {
    List<Treatment> treatments = getTreatments(sampleIds, studyIds, key);
    return treatments.stream().collect(groupingBy(Treatment::getPatientId));
}

@Override
public List<Treatment> getTreatments(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key) {
return treatmentMapper.getAllTreatments(sampleIds, studyIds, key.getKey())
.stream()
.flatMap(treatment -> splitIfDelimited(treatment, key))
.collect(groupingBy(Treatment::getPatientId));


.collect(Collectors.toList());
}

private Stream<Treatment> splitIfDelimited(Treatment unsplitTreatment, ClinicalEventKeyCode key) {
Expand Down Expand Up @@ -61,31 +66,12 @@ public Map<String, List<ClinicalEventSample>> getShallowSamplesByPatientId(List<
}

/**
 * Collects the distinct treatment values returned by the mapper for the given
 * samples, studies and key. When the key is delimited, each raw value is split
 * on the key's delimiter and every piece is counted as its own treatment.
 *
 * @param sampleIds sample identifiers forwarded to the mapper
 * @param studyIds  study identifiers forwarded to the mapper
 * @param key       clinical event key code; supplies the key string, the delimited flag and the delimiter
 * @return the set of unique treatment strings
 */
@Override
public Set<String> getAllUniqueTreatments(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key) {
    Set<String> rawValues = treatmentMapper.getAllUniqueTreatments(sampleIds, studyIds, key.getKey());
    return rawValues.stream()
        // String.split treats the delimiter as a regular expression.
        .flatMap(rawValue -> key.isDelimited()
            ? Arrays.stream(rawValue.split(key.getDelimiter()))
            : Stream.of(rawValue))
        .collect(Collectors.toSet());
}

@Override
public Boolean hasTreatmentData(List<String> studies, String key) {
return treatmentMapper.hasTreatmentData(null, studies, key);
public Boolean hasTreatmentData(List<String> studies, ClinicalEventKeyCode key) {
return treatmentMapper.hasTreatmentData(null, studies, key.getKey());
}

/**
 * Reports whether the given studies have sample timeline data, as determined
 * by the mapper's {@code hasSampleTimelineData} query.
 *
 * @param studies study identifiers, passed to the mapper as its studyIds argument
 * @return the mapper's result
 */
@Override
public Boolean hasSampleTimelineData(List<String> studies) {
// No sample filter is supplied here: null is passed for the mapper's sampleIds argument.
return treatmentMapper.hasSampleTimelineData(null, studies);
}

/**
 * Reports whether a single study has treatment data for the given key, by
 * delegating to the mapper's {@code studyIdHasTreatments} query.
 *
 * @param studyId identifier of the study to check
 * @param key     clinical event key code; its string key is what gets passed to the mapper
 * @return the mapper's result
 */
@Override
public Boolean studyIdHasTreatments(String studyId, ClinicalEventKeyCode key) {
return treatmentMapper.studyIdHasTreatments(studyId, key.getKey());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
INNER JOIN sample ON patient.INTERNAL_ID = sample.PATIENT_ID
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
<include refid="where"/>
AND clinical_event.EVENT_TYPE = 'TREATMENT'
AND clinical_event_data.KEY = #{key}
</select>

Expand Down Expand Up @@ -48,19 +49,6 @@
<include refid="where"/>
</select>

<!--
Returns the distinct clinical_event_data.VALUE strings for events whose data KEY equals #{key}.
Joins clinical_event to its data rows, patient, sample and cancer_study, and applies the shared
"where" fragment before the KEY filter. This query does not filter on clinical_event.EVENT_TYPE.
-->
<select id="getAllUniqueTreatments" resultType="java.lang.String">
SELECT
DISTINCT clinical_event_data.VALUE
FROM
clinical_event
INNER JOIN clinical_event_data ON clinical_event.CLINICAL_EVENT_ID = clinical_event_data.CLINICAL_EVENT_ID
INNER JOIN patient ON clinical_event.PATIENT_ID = patient.INTERNAL_ID
INNER JOIN sample ON patient.INTERNAL_ID = sample.PATIENT_ID
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
<include refid="where"/>
AND clinical_event_data.KEY = #{key}
</select>

<select id="hasTreatmentData" resultType="java.lang.Boolean">
SELECT EXISTS(SELECT
*
Expand All @@ -71,6 +59,7 @@
INNER JOIN sample ON patient.INTERNAL_ID = sample.PATIENT_ID
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
<include refid="where"/>
AND clinical_event.EVENT_TYPE = 'TREATMENT'
AND clinical_event_data.KEY = #{key} LIMIT 1
)
</select>
Expand All @@ -88,20 +77,6 @@
AND clinical_event_data.KEY = 'SAMPLE_ID'
AND (clinical_event.EVENT_TYPE LIKE 'Sample Acquisition' OR clinical_event.EVENT_TYPE LIKE 'SPECIMEN') LIMIT 1)
</select>


<!--
Returns true when at least one clinical_event_data row with KEY = #{key} exists for a patient in
the study whose CANCER_STUDY_IDENTIFIER equals #{studyId}. Unlike the sample-scoped queries, this
one does not join the sample table or use the shared "where" fragment, and it does not filter on
clinical_event.EVENT_TYPE.
-->
<select id="studyIdHasTreatments" resultType="java.lang.Boolean">
SELECT EXISTS(SELECT
*
FROM
clinical_event
INNER JOIN clinical_event_data ON clinical_event.CLINICAL_EVENT_ID = clinical_event_data.CLINICAL_EVENT_ID
INNER JOIN patient ON clinical_event.PATIENT_ID = patient.INTERNAL_ID
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
WHERE
clinical_event_data.KEY = #{key}
AND cancer_study.CANCER_STUDY_IDENTIFIER = #{studyId} LIMIT 1)
</select>

<sql id="where">
<where>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,22 +133,9 @@ public void getShallowSamplesByPatientId() {
Assert.assertEquals(actual, expected);
}

// Checks that getAllUniqueTreatments returns exactly {"Madeupanib"} for sample TCGA-A1-A0SD-01
// in study_tcga_pub when queried with ClinicalEventKeyCode.Agent.
@Test
public void getAllUniqueTreatments() {
HashSet<String> expected = new HashSet<>(Collections.singletonList("Madeupanib"));

Set<String> actual = treatmentRepository.getAllUniqueTreatments(
Collections.singletonList("TCGA-A1-A0SD-01"),
Collections.singletonList("study_tcga_pub"),
ClinicalEventKeyCode.Agent
);

// NOTE(review): arguments are passed as (actual, expected). If Assert is org.junit.Assert the
// documented order is (expected, actual); this only affects the failure message - confirm.
Assert.assertEquals(actual, expected);
}

@Test
public void hasTreatmentData() {
Boolean actual = treatmentRepository.hasTreatmentData(Collections.singletonList("study_tcga_pub"), ClinicalEventKeyCode.Agent.getKey());
Boolean actual = treatmentRepository.hasTreatmentData(Collections.singletonList("study_tcga_pub"), ClinicalEventKeyCode.Agent);

Assert.assertEquals(actual, true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ private Pair<List<String>, List<String>> filterIds(List<String> sampleIds, List<
}
Set<String> studiesWithTreatments = studyIds.stream()
.distinct()
.filter(studyId -> treatmentRepository.studyIdHasTreatments(studyId, key))
.filter(studyId -> treatmentRepository.hasTreatmentData(Collections.singletonList(studyId), key))
.collect(Collectors.toSet());

ArrayList<String> filteredSampleIds = new ArrayList<>();
Expand Down Expand Up @@ -158,62 +158,37 @@ public List<PatientTreatmentRow> getAllPatientTreatmentRows(
sampleIds = filteredIds.getLeft();
studyIds = filteredIds.getRight();

Map<String, List<Treatment>> treatmentsByPatient =
treatmentRepository.getTreatmentsByPatientId(sampleIds, studyIds, key);
Map<String, List<ClinicalEventSample>> samplesByPatient = treatmentRepository
.getShallowSamplesByPatientId(sampleIds, studyIds)
.entrySet()
.stream()
.filter(e -> treatmentsByPatient.containsKey(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Set<String> treatments = treatmentRepository.getAllUniqueTreatments(sampleIds, studyIds, key);

return treatments.stream()
.map(t -> createPatientTreatmentRowForTreatment(t, treatmentsByPatient, samplesByPatient))
.collect(Collectors.toList());
}

private PatientTreatmentRow createPatientTreatmentRowForTreatment(
String treatment,
Map<String, List<Treatment>> treatmentsByPatient,
Map<String, List<ClinicalEventSample>> samplesByPatient
) {
// find all the patients that have received this treatment
List<Map.Entry<String, List<Treatment>>> matchingPatients = matchingPatients(treatment, treatmentsByPatient);
.getShallowSamplesByPatientId(sampleIds, studyIds);

// from those patients, extract the unique samples
Set<ClinicalEventSample> samples = matchingPatients
Map<String, List<Treatment>> treatmentSet = treatmentRepository.getTreatments(sampleIds, studyIds, key)
.stream()
.map(Map.Entry::getKey)
.flatMap(patient -> samplesByPatient.getOrDefault(patient, new ArrayList<>()).stream())
.collect(Collectors.toSet());


return new PatientTreatmentRow(treatment, matchingPatients.size(), samples);
}
.collect(groupingBy(Treatment::getTreatment));

private List<Map.Entry<String, List<Treatment>>> matchingPatients(
String treatment,
Map<String, List<Treatment>> treatmentsByPatient
) {
return treatmentsByPatient.entrySet().stream()
.filter(p -> p.getValue().stream().anyMatch(t -> t.getTreatment().equals(treatment)))
return treatmentSet.entrySet()
.stream()
.map(entry -> {
String treatment = entry.getKey();
Set<String> patientIds = entry.getValue().stream().map(Treatment::getPatientId).collect(toSet());
Set<ClinicalEventSample> clinicalEventSamples = patientIds
.stream()
.flatMap(patientId -> samplesByPatient.getOrDefault(patientId, new ArrayList<>()).stream())
.collect(toSet());
return new PatientTreatmentRow(treatment, patientIds.size(), clinicalEventSamples);
})
.collect(toList());
}

@Override
public Boolean containsTreatmentData(List<String> studies, ClinicalEventKeyCode key) {
return treatmentRepository.hasTreatmentData(studies, key.getKey());
return treatmentRepository.hasTreatmentData(studies, key);
}

@Override
public Boolean containsSampleTreatmentData(List<String> studyIds, ClinicalEventKeyCode key) {
studyIds = studyIds.stream()
.filter(studyId -> treatmentRepository.studyIdHasTreatments(studyId, key))
.filter(studyId -> treatmentRepository.hasTreatmentData(Collections.singletonList(studyId), key))
.collect(Collectors.toList());
if(studyIds.size() > 0 && treatmentRepository.hasSampleTimelineData(studyIds)) {
return treatmentRepository.hasTreatmentData(studyIds, key.getKey());
}
return false;
return studyIds.size() > 0 && treatmentRepository.hasSampleTimelineData(studyIds);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,11 @@ public class TreatmentServiceImplTest {
private TreatmentRepository treatmentRepository;

@Test
public void getAllSampleTreatmentsSingleRow() {
mockTreatmentsByPatient(
makeTreatment("madeupanib", "P0", 0, 10)
);
public void getAllPatientTreatmentRows() {
mockSamplesByPatient(
makeSample("S0", "P0", 5)
);
mockAllTreatments("madeupanib");
mockTreatments(makeTreatment("madeupanib", "P0", 0, 10));

PatientTreatmentRow rowA = makePatientRow("madeupanib", 1, Collections.singletonList("S0"), Collections.singletonList("P0"));
List<PatientTreatmentRow> expected = Collections.singletonList(rowA);
Expand All @@ -43,15 +40,14 @@ public void getAllSampleTreatmentsSingleRow() {
}

@Test
public void getAllSampleTreatmentsOneSampleTwoTreatmentsOnePatient() {
mockTreatmentsByPatient(
public void getAllPatientTreatmentRowsOneSampleTwoTreatmentsOnePatient() {
mockTreatments(
makeTreatment("madeupanib", "P0", 0, 10),
makeTreatment("fakedrugazol", "P0", 0, 10)
);
mockSamplesByPatient(
makeSample("S0", "P0", 5)
);
mockAllTreatments("madeupanib", "fakedrugazol");


PatientTreatmentRow rowA = makePatientRow("fakedrugazol", 1, Collections.singletonList("S0"), Collections.singletonList("P0"));
Expand All @@ -63,15 +59,14 @@ public void getAllSampleTreatmentsOneSampleTwoTreatmentsOnePatient() {
}

@Test
public void getAllSampleTreatmentsTwoSamplesOnePatientOneTreatment() {
mockTreatmentsByPatient(
public void getAllPatientTreatmentRowsTwoSamplesOnePatientOneTreatment() {
mockTreatments(
makeTreatment("fakedrugazol", "P0", 0, 10)
);
mockSamplesByPatient(
makeSample("S0", "P0", 5),
makeSample("S1", "P0", 10)
);
mockAllTreatments("fakedrugazol");

// even though there are two samples, you expect a count of 1,
// because this is from the patient level, and both samples are for the same patient
Expand All @@ -83,16 +78,15 @@ public void getAllSampleTreatmentsTwoSamplesOnePatientOneTreatment() {
}

@Test
public void getAllSampleTreatmentsTwoSamplesTwoPatientsTwoTreatments() {
mockTreatmentsByPatient(
public void getAllPatientTreatmentRowsTwoSamplesTwoPatientsTwoTreatments() {
mockTreatments(
makeTreatment("fakedrugazol", "P0", 0, 10),
makeTreatment("fakedrugazol", "P1", 0, 10)
);
mockSamplesByPatient(
makeSample("S0", "P0", 5),
makeSample("S1", "P1", 10)
);
mockAllTreatments("fakedrugazol");

// now there are two patients, so you expect a count of two
PatientTreatmentRow rowA = makePatientRow("fakedrugazol", 2, Arrays.asList("S0", "S1"), Arrays.asList("P0", "P1"));
Expand All @@ -103,16 +97,15 @@ public void getAllSampleTreatmentsTwoSamplesTwoPatientsTwoTreatments() {
}

@Test
public void getAllSampleTreatmentsTwoSamplesTwoPatientsTwoDifferentTreatments() {
mockTreatmentsByPatient(
public void getAllPatientTreatmentRowsTwoSamplesTwoPatientsTwoDifferentTreatments() {
mockTreatments(
makeTreatment("fakedrugazol", "P0", 0, 10),
makeTreatment("madeupanib", "P1", 0, 10)
);
mockSamplesByPatient(
makeSample("S0", "P0", 5),
makeSample("S1", "P1", 10)
);
mockAllTreatments("fakedrugazol", "madeupanib");

PatientTreatmentRow rowA = makePatientRow("fakedrugazol", 1, Collections.singletonList("S0"), Collections.singletonList("P0"));
PatientTreatmentRow rowB = makePatientRow("madeupanib", 1, Collections.singletonList("S1"), Collections.singletonList("P1"));
Expand Down Expand Up @@ -252,6 +245,11 @@ private void mockTreatmentsByPatient(Treatment... treatments) {
.thenReturn(treatmentsByPatient);
}

// Stubs treatmentRepository.getTreatments so that any arguments yield the given treatments as a list.
private void mockTreatments(Treatment... treatments) {
    List<Treatment> stubbedTreatments = Arrays.stream(treatments).collect(Collectors.toList());
    Mockito.when(treatmentRepository.getTreatments(Mockito.any(), Mockito.any(), Mockito.any()))
        .thenReturn(stubbedTreatments);
}

private void mockSamplesByPatient(ClinicalEventSample... samples) {
Map<String, List<ClinicalEventSample>> samplesByPatient = Arrays.stream(samples)
.collect(Collectors.groupingBy(ClinicalEventSample::getPatientId));
Expand All @@ -261,12 +259,6 @@ private void mockSamplesByPatient(ClinicalEventSample... samples) {
.thenReturn(samplesByPatient);
}

// Stubs treatmentRepository.getAllUniqueTreatments so that any arguments yield the given names as a set.
private void mockAllTreatments(String... treatments) {
    Mockito.when(treatmentRepository.getAllUniqueTreatments(Mockito.any(), Mockito.any(), Mockito.any()))
        .thenReturn(new HashSet<>(Arrays.asList(treatments)));
}

private Treatment makeTreatment(String treatment, String patientId, Integer start, Integer stop) {
Treatment t = new Treatment();
t.setTreatment(treatment);
Expand Down

0 comments on commit 48f4d69

Please sign in to comment.