Skip to content

Commit

Permalink
Merge pull request cBioPortal#10261 from kalletlak/timeline-performance
Browse files Browse the repository at this point in the history
Improve study-view treatment api's performance
  • Loading branch information
kalletlak authored Jul 13, 2023
2 parents 2712b20 + dc487c2 commit b3400e6
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 171 deletions.
2 changes: 1 addition & 1 deletion db-scripts/src/main/resources/cgds.sql
Original file line number Diff line number Diff line change
Expand Up @@ -755,5 +755,5 @@ CREATE TABLE `resource_study` (
);

-- THIS MUST BE KEPT IN SYNC WITH db.version PROPERTY IN pom.xml
INSERT INTO info VALUES ('2.13.0', NULL);
INSERT INTO info VALUES ('2.13.1', NULL);

7 changes: 7 additions & 0 deletions db-scripts/src/main/resources/migration.sql
Original file line number Diff line number Diff line change
Expand Up @@ -1021,3 +1021,10 @@ ALTER TABLE `mutation_event` CHANGE COLUMN `ONCOTATOR_PROTEIN_POS_START` `PROTEI
ALTER TABLE `mutation_event` CHANGE COLUMN `ONCOTATOR_PROTEIN_POS_END` `PROTEIN_POS_END` int(11);
UPDATE `info` SET `DB_SCHEMA_VERSION`="2.13.0";


##version: 2.13.1
ALTER TABLE `clinical_event_data` MODIFY COLUMN `VALUE` varchar(3000) NOT NULL;
CREATE INDEX idx_clinical_event_key ON clinical_event_data (`KEY`);
CREATE INDEX idx_clinical_event_value ON clinical_event_data (`VALUE`);
CREATE INDEX idx_sample_stable_id ON sample (`STABLE_ID`);
UPDATE `info` SET `DB_SCHEMA_VERSION`="2.13.1";
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@ public interface TreatmentRepository {
public Map<String, List<Treatment>> getTreatmentsByPatientId(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Map<String, List<ClinicalEventSample>> getSamplesByPatientId(List<String> sampleIds, List<String> studyIds);
public List<Treatment> getTreatments(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Map<String, List<ClinicalEventSample>> getShallowSamplesByPatientId(List<String> sampleIds, List<String> studyIds);
public Map<String, List<ClinicalEventSample>> getSamplesByPatientId(List<String> sampleIds, List<String> studyIds);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Set<String> getAllUniqueTreatments(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key);
public Map<String, List<ClinicalEventSample>> getShallowSamplesByPatientId(List<String> sampleIds, List<String> studyIds);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Integer getTreatmentCount(List<String> studies, String key);
public Boolean hasTreatmentData(List<String> studies, ClinicalEventKeyCode key);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Integer getSampleCount(List<String> studies);
public Boolean hasSampleTimelineData(List<String> studies);

@Cacheable(cacheResolver = "generalRepositoryCacheResolver", condition = "@cacheEnabledConfig.getEnabled()")
public Boolean studyIdHasTreatments(String studyId, ClinicalEventKeyCode key);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ public interface TreatmentMapper {

List<ClinicalEventSample> getAllShallowSamples(List<String> sampleIds, List<String> studyIds);

Set<String> getAllUniqueTreatments(List<String> sampleIds, List<String> studyIds, String key);
Boolean hasTreatmentData(List<String> sampleIds, List<String> studyIds, String key);

Integer getTreatmentCount(List<String> sampleIds, List<String> studyIds, String key);

Integer getSampleCount(List<String> sampleIds, List<String> studyIds);

Boolean studyIdHasTreatments(String studyId, String key);
Boolean hasSampleTimelineData(List<String> sampleIds, List<String> studyIds);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@ public class TreatmentMyBatisRepository implements TreatmentRepository {

/**
 * Returns the treatments produced by {@link #getTreatments} for the given
 * samples, studies and key, grouped by the patient id of each treatment.
 *
 * @param sampleIds sample ids forwarded unchanged to {@code getTreatments}
 * @param studyIds  study ids forwarded unchanged to {@code getTreatments}
 * @param key       clinical event key code forwarded unchanged to {@code getTreatments}
 * @return a map from patient id to that patient's treatments
 */
@Override
public Map<String, List<Treatment>> getTreatmentsByPatientId(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key) {
    // Fetch the flat list first, then bucket it by owning patient.
    List<Treatment> flatTreatments = getTreatments(sampleIds, studyIds, key);
    return flatTreatments.stream().collect(groupingBy(Treatment::getPatientId));
}

@Override
public List<Treatment> getTreatments(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key) {
return treatmentMapper.getAllTreatments(sampleIds, studyIds, key.getKey())
.stream()
.flatMap(treatment -> splitIfDelimited(treatment, key))
.collect(groupingBy(Treatment::getPatientId));


.collect(Collectors.toList());
}

private Stream<Treatment> splitIfDelimited(Treatment unsplitTreatment, ClinicalEventKeyCode key) {
Expand Down Expand Up @@ -61,31 +66,12 @@ public Map<String, List<ClinicalEventSample>> getShallowSamplesByPatientId(List<
}

@Override
public Set<String> getAllUniqueTreatments(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key) {
return treatmentMapper.getAllUniqueTreatments(sampleIds, studyIds, key.getKey())
.stream()
.flatMap(treatment -> {
if (key.isDelimited()) {
return Arrays.stream(treatment.split(key.getDelimiter()));
} else {
return Stream.of(treatment);
}
})
.collect(Collectors.toSet());
}

@Override
public Integer getTreatmentCount(List<String> studies, String key) {
return treatmentMapper.getTreatmentCount(null, studies, key);
}

@Override
public Integer getSampleCount(List<String> studies) {
return treatmentMapper.getSampleCount(null, studies);
public Boolean hasTreatmentData(List<String> studies, ClinicalEventKeyCode key) {
return treatmentMapper.hasTreatmentData(null, studies, key.getKey());
}

@Override
public Boolean studyIdHasTreatments(String studyId, ClinicalEventKeyCode key) {
return treatmentMapper.studyIdHasTreatments(studyId, key.getKey());
public Boolean hasSampleTimelineData(List<String> studies) {
return treatmentMapper.hasSampleTimelineData(null, studies);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
INNER JOIN sample ON patient.INTERNAL_ID = sample.PATIENT_ID
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
<include refid="where"/>
AND clinical_event.EVENT_TYPE = 'TREATMENT'
AND clinical_event_data.KEY = #{key}
</select>

Expand Down Expand Up @@ -48,35 +49,24 @@
<include refid="where"/>
</select>

<select id="getAllUniqueTreatments" resultType="java.lang.String">
SELECT
DISTINCT clinical_event_data.VALUE
<select id="hasTreatmentData" resultType="java.lang.Boolean">
SELECT EXISTS(SELECT
*
FROM
clinical_event
INNER JOIN clinical_event_data ON clinical_event.CLINICAL_EVENT_ID = clinical_event_data.CLINICAL_EVENT_ID
INNER JOIN patient ON clinical_event.PATIENT_ID = patient.INTERNAL_ID
INNER JOIN sample ON patient.INTERNAL_ID = sample.PATIENT_ID
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
<include refid="where"/>
AND clinical_event_data.KEY = #{key}
AND clinical_event.EVENT_TYPE = 'TREATMENT'
AND clinical_event_data.KEY = #{key} LIMIT 1
)
</select>

<select id="getTreatmentCount" resultType="java.lang.Integer">
SELECT
COUNT(clinical_event.PATIENT_ID)
FROM
clinical_event
INNER JOIN clinical_event_data ON clinical_event.CLINICAL_EVENT_ID = clinical_event_data.CLINICAL_EVENT_ID
INNER JOIN patient ON clinical_event.PATIENT_ID = patient.INTERNAL_ID
INNER JOIN sample ON patient.INTERNAL_ID = sample.PATIENT_ID
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
<include refid="where"/>
AND clinical_event_data.KEY = #{key}
</select>

<select id="getSampleCount" resultType="java.lang.Integer">
SELECT
COUNT(sample.STABLE_ID)
<select id="hasSampleTimelineData" resultType="java.lang.Boolean">
SELECT EXISTS(SELECT
*
FROM
clinical_event
INNER JOIN clinical_event_data ON clinical_event.CLINICAL_EVENT_ID = clinical_event_data.CLINICAL_EVENT_ID
Expand All @@ -85,20 +75,7 @@
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
<include refid="where"/>
AND clinical_event_data.KEY = 'SAMPLE_ID'
AND (clinical_event.EVENT_TYPE LIKE 'Sample Acquisition' OR clinical_event.EVENT_TYPE LIKE 'SPECIMEN')
</select>

<select id="studyIdHasTreatments" resultType="java.lang.Boolean">
SELECT
count(cancer_study.CANCER_STUDY_IDENTIFIER) > 0
FROM
clinical_event
INNER JOIN clinical_event_data ON clinical_event.CLINICAL_EVENT_ID = clinical_event_data.CLINICAL_EVENT_ID
INNER JOIN patient ON clinical_event.PATIENT_ID = patient.INTERNAL_ID
INNER JOIN cancer_study ON patient.CANCER_STUDY_ID = cancer_study.CANCER_STUDY_ID
WHERE
clinical_event_data.KEY = #{key}
AND cancer_study.CANCER_STUDY_IDENTIFIER = #{studyId}
AND (clinical_event.EVENT_TYPE LIKE 'Sample Acquisition' OR clinical_event.EVENT_TYPE LIKE 'SPECIMEN') LIMIT 1)
</select>

<sql id="where">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,32 +134,20 @@ public void getShallowSamplesByPatientId() {
}

@Test
public void getAllUniqueTreatments() {
HashSet<String> expected = new HashSet<>(Collections.singletonList("Madeupanib"));

Set<String> actual = treatmentRepository.getAllUniqueTreatments(
Collections.singletonList("TCGA-A1-A0SD-01"),
Collections.singletonList("study_tcga_pub"),
ClinicalEventKeyCode.Agent
);

Assert.assertEquals(actual, expected);
}
public void hasTreatmentData() {

Assert.assertEquals(true, treatmentRepository.hasTreatmentData(Collections.singletonList("study_tcga_pub"), ClinicalEventKeyCode.Agent));

Assert.assertEquals(false, treatmentRepository.hasTreatmentData(Collections.singletonList("acc_tcga"), ClinicalEventKeyCode.Agent));

@Test
public void getTreatmentCount() {
Integer expected = 1;
Integer actual = treatmentRepository.getTreatmentCount(Collections.singletonList("study_tcga_pub"), ClinicalEventKeyCode.Agent.getKey());

Assert.assertEquals(actual, expected);
}

@Test
public void getSampleCount() {
Integer expected = 2;
Integer actual = treatmentRepository.getSampleCount(Collections.singletonList("study_tcga_pub"));
Assert.assertEquals(actual, expected);
public void hasSampleTimelineData() {

Assert.assertEquals(true, treatmentRepository.hasSampleTimelineData(Collections.singletonList("study_tcga_pub")));

Assert.assertEquals(false, treatmentRepository.hasSampleTimelineData(Collections.singletonList("acc_tcga")));
}

}
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@
<tomcat.session.timeout>720</tomcat.session.timeout>

<!-- THIS SHOULD BE KEPT IN SYNC WITH THE VERSION IN CGDS.SQL -->
<db.version>2.13.0</db.version>
<db.version>2.13.1</db.version>
</properties>

<modules>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;

@Service
public class TreatmentServiceImpl implements TreatmentService {
@Autowired
Expand All @@ -25,7 +23,7 @@ private Pair<List<String>, List<String>> filterIds(List<String> sampleIds, List<
}
Set<String> studiesWithTreatments = studyIds.stream()
.distinct()
.filter(studyId -> treatmentRepository.studyIdHasTreatments(studyId, key))
.filter(studyId -> treatmentRepository.hasTreatmentData(Collections.singletonList(studyId), key))
.collect(Collectors.toSet());

ArrayList<String> filteredSampleIds = new ArrayList<>();
Expand All @@ -40,16 +38,16 @@ private Pair<List<String>, List<String>> filterIds(List<String> sampleIds, List<
}
return new ImmutablePair<>(filteredSampleIds, filteredStudyIds);
}

@Override
public List<SampleTreatmentRow> getAllSampleTreatmentRows(List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key) {
Pair<List<String>, List<String>> filteredIds = filterIds(sampleIds, studyIds, key);
sampleIds = filteredIds.getLeft();
studyIds = filteredIds.getRight();

Map<String, List<ClinicalEventSample>> samplesByPatient =
Map<String, List<ClinicalEventSample>> samplesByPatient =
treatmentRepository.getSamplesByPatientId(sampleIds, studyIds);
Map<String, List<Treatment>> treatmentsByPatient =
Map<String, List<Treatment>> treatmentsByPatient =
treatmentRepository.getTreatmentsByPatientId(sampleIds, studyIds, key);

Stream<SampleTreatmentRow> rows = samplesByPatient.keySet().stream()
Expand Down Expand Up @@ -150,71 +148,51 @@ Stream<SampleTreatmentRow> toRows() {
);
}
}

@Override
public List<PatientTreatmentRow> getAllPatientTreatmentRows(
List<String> sampleIds, List<String> studyIds, ClinicalEventKeyCode key
) {
Pair<List<String>, List<String>> filteredIds = filterIds(sampleIds, studyIds, key);
sampleIds = filteredIds.getLeft();
studyIds = filteredIds.getRight();

Map<String, List<Treatment>> treatmentsByPatient =
treatmentRepository.getTreatmentsByPatientId(sampleIds, studyIds, key);
Map<String, List<ClinicalEventSample>> samplesByPatient = treatmentRepository
.getShallowSamplesByPatientId(sampleIds, studyIds)
.entrySet()
.stream()
.filter(e -> treatmentsByPatient.containsKey(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Set<String> treatments = treatmentRepository.getAllUniqueTreatments(sampleIds, studyIds, key);

return treatments.stream()
.map(t -> createPatientTreatmentRowForTreatment(t, treatmentsByPatient, samplesByPatient))
.collect(Collectors.toList());
}

private PatientTreatmentRow createPatientTreatmentRowForTreatment(
String treatment,
Map<String, List<Treatment>> treatmentsByPatient,
Map<String, List<ClinicalEventSample>> samplesByPatient
) {
// find all the patients that have received this treatment
List<Map.Entry<String, List<Treatment>>> matchingPatients = matchingPatients(treatment, treatmentsByPatient);
Map<String, List<ClinicalEventSample>> samplesByPatient = treatmentRepository
.getShallowSamplesByPatientId(sampleIds, studyIds);

// from those patients, extract the unique samples
Set<ClinicalEventSample> samples = matchingPatients
Map<String, List<Treatment>> treatmentSet = treatmentRepository.getTreatments(sampleIds, studyIds, key)
.stream()
.map(Map.Entry::getKey)
.flatMap(patient -> samplesByPatient.getOrDefault(patient, new ArrayList<>()).stream())
.collect(Collectors.toSet());


return new PatientTreatmentRow(treatment, matchingPatients.size(), samples);
}
.collect(Collectors.groupingBy(Treatment::getTreatment));

private List<Map.Entry<String, List<Treatment>>> matchingPatients(
String treatment,
Map<String, List<Treatment>> treatmentsByPatient
) {
return treatmentsByPatient.entrySet().stream()
.filter(p -> p.getValue().stream().anyMatch(t -> t.getTreatment().equals(treatment)))
.collect(toList());
/*
Transform treatmentSet (treatments grouped by treatment name) into a list of PatientTreatmentRow:
- each key in treatmentSet is a treatment name
- the number of unique patient ids among that treatment's entries gives the row's count
- the ClinicalEventSamples of those unique patients become the row's samples
*/
return treatmentSet.entrySet()
.stream()
.map(entry -> {
String treatment = entry.getKey();
Set<String> patientIds = entry.getValue().stream().map(Treatment::getPatientId).collect(Collectors.toSet());
Set<ClinicalEventSample> clinicalEventSamples = patientIds
.stream()
.flatMap(patientId -> samplesByPatient.getOrDefault(patientId, new ArrayList<>()).stream())
.collect(Collectors.toSet());
return new PatientTreatmentRow(treatment, patientIds.size(), clinicalEventSamples);
})
.collect(Collectors.toList());
}

@Override
public Boolean containsTreatmentData(List<String> studies, ClinicalEventKeyCode key) {
return treatmentRepository.getTreatmentCount(studies, key.getKey()) > 0;
return treatmentRepository.hasTreatmentData(studies, key);
}

@Override
public Boolean containsSampleTreatmentData(List<String> studyIds, ClinicalEventKeyCode key) {
studyIds = studyIds.stream()
.filter(studyId -> treatmentRepository.studyIdHasTreatments(studyId, key))
.filter(studyId -> treatmentRepository.hasTreatmentData(Collections.singletonList(studyId), key))
.collect(Collectors.toList());
Integer sampleCount = treatmentRepository.getSampleCount(studyIds);
Integer treatmentCount = treatmentRepository.getTreatmentCount(studyIds, key.getKey());

return sampleCount * treatmentCount > 0;
return studyIds.size() > 0 && treatmentRepository.hasSampleTimelineData(studyIds);
}
}
Loading

0 comments on commit b3400e6

Please sign in to comment.