diff --git a/src/strucclustutils/createcomplexreport.h b/src/strucclustutils/createcomplexreport.h index faa96a29..fc45fa9d 100644 --- a/src/strucclustutils/createcomplexreport.h +++ b/src/strucclustutils/createcomplexreport.h @@ -8,20 +8,17 @@ const double MAX_ASSIGNED_CHAIN_RATIO = 1.0; const double TOO_SMALL_MEAN = 1.0; const double TOO_SMALL_CV = 0.1; const double FILTERED_OUT = 0.0; -const unsigned int UNCLUSTERED = 0; -const unsigned int CLUSTERED = 1; -const unsigned int MIN_PTS = 2; +const bool UNCLUSTERED = false; +const bool CLUSTERED = true; const float BIT_SCORE_MARGIN = 0.7; const float DEF_BIT_SCORE = -1.0; const int UNINITIALIZED = 0; -const float LEARNING_RATE = 0.1; -const float DEFAULT_EPS = 0.1; -const unsigned int MULTIPLE_CHAINED_COMPOLEX = 2; +const unsigned int MULTIPLE_CHAINED_COMPLEX = 2; typedef std::pair compNameChainName_t; typedef std::map chainKeyToComplexId_t; typedef std::map> complexIdToChainKeys_t; typedef std::vector cluster_t; -typedef std::map, double> distMap_t; +typedef std::map, float> distMap_t; typedef std::string resultToWrite_t; typedef std::string chainName_t; typedef std::pair resultToWriteWithKey_t; diff --git a/src/strucclustutils/scorecomplex.cpp b/src/strucclustutils/scorecomplex.cpp index 2a649141..6441793a 100644 --- a/src/strucclustutils/scorecomplex.cpp +++ b/src/strucclustutils/scorecomplex.cpp @@ -275,22 +275,6 @@ bool compareChainToChainAlnByDbComplexId(const ChainToChainAln &first, const Cha return false; } -bool compareChainToChainAlnByClusterLabel(const ChainToChainAln &first, const ChainToChainAln &second) { - if (first.label < second.label) - return true; - if (first.label > second.label) - return false; - if (first.qChain.chainKey < second.qChain.chainKey) - return true; - if (first.qChain.chainKey > second.qChain.chainKey) - return false; - if (first.dbChain.chainKey < second.dbChain.chainKey) - return true; - if (first.dbChain.chainKey > second.dbChain.chainKey) - return false; - return false; -} - bool compareAssignment(const Assignment &first, const Assignment &second) { if (first.qTmScore > second.qTmScore) return true; @@ -311,114 +295,75 @@ bool compareNeighborWithDist(const NeighborsWithDist &first, const NeighborsWith return false; } -class DBSCANCluster { +class NearestNeighborsCluster { public: - DBSCANCluster(SearchResult &searchResult, double minCov) : searchResult(searchResult) { - cLabel = 0; - minClusterSize = (unsigned int) ((double) searchResult.qChainKeys.size() * minCov); + NearestNeighborsCluster(SearchResult &searchResult, std::set &finalClusters, double minCov) : searchResult(searchResult), finalClusters(finalClusters) { + minClusterSize = std::max(MULTIPLE_CHAINED_COMPLEX, (unsigned int) ((double) searchResult.qChainKeys.size() * minCov)); idealClusterSize = std::min(searchResult.qChainKeys.size(), searchResult.dbChainKeys.size()); - finalClusters.clear(); prevMaxClusterSize = 0; - maxDist = 0; - eps = DEFAULT_EPS; - learningRate = LEARNING_RATE; } - unsigned int getAlnClusters() { - // rbh filter + bool getAlnClusters() { filterAlnsByRBH(); fillDistMap(); - // To skip DBSCAN clustering when alignments are few enough. if (searchResult.alnVec.size() <= idealClusterSize) return checkClusteringNecessity(); - return runDBSCAN(); + return getNearestNeighbors(); } private: SearchResult &searchResult; - float eps; - float maxDist; - float learningRate; - unsigned int cLabel; + std::set &finalClusters; unsigned int prevMaxClusterSize; - unsigned int maxClusterSize; unsigned int idealClusterSize; unsigned int minClusterSize; - std::vector neighbors; - std::vector neighborsOfCurrNeighbor; + cluster_t neighbors; std::vector neighborsWithDist; std::set qFoundChainKeys; std::set dbFoundChainKeys; distMap_t distMap; - std::vector currClusters; - std::set finalClusters; std::map qBestBitScore; std::map dbBestBitScore; - unsigned int runDBSCAN() { - initializeAlnLabels(); - if (eps >= maxDist) - return finishDBSCAN(); + bool getNearestNeighbors() { + finalClusters.clear(); + for (size_t i=0; i < searchResult.alnVec.size(); i++) { + neighbors.clear(); + neighborsWithDist.clear(); + qFoundChainKeys.clear(); + dbFoundChainKeys.clear(); - for (size_t centerAlnIdx=0; centerAlnIdx < searchResult.alnVec.size(); centerAlnIdx++) { - ChainToChainAln ¢erAln = searchResult.alnVec[centerAlnIdx]; - if (centerAln.label != 0) - continue; + for (size_t j = 0; j < searchResult.alnVec.size(); j++) { + neighborsWithDist.emplace_back(j, i == j ? 0.0 : distMap[{std::min(i, j), std::max(i, j)}]); + } - getNeighbors(centerAlnIdx, neighbors); - if (neighbors.size() < MIN_PTS) - continue; + SORT_SERIAL(neighborsWithDist.begin(), neighborsWithDist.end(), compareNeighborWithDist); + for (auto &neighborWithDist: neighborsWithDist) { + if (neighbors.size() >= idealClusterSize) + break; - centerAln.label = ++cLabel; - unsigned int neighborIdx = 0; - while (neighborIdx < neighbors.size()) { - unsigned int neighborAlnIdx = neighbors[neighborIdx++]; - if (centerAlnIdx == neighborAlnIdx) - continue; + if (!qFoundChainKeys.insert(searchResult.alnVec[neighborWithDist.neighbor].qChain.chainKey).second) + break; - ChainToChainAln &neighborAln = searchResult.alnVec[neighborAlnIdx]; - neighborAln.label = cLabel; - getNeighbors(neighborAlnIdx, neighborsOfCurrNeighbor); - if (neighborsOfCurrNeighbor.size() < MIN_PTS) - continue; + if (!dbFoundChainKeys.insert(searchResult.alnVec[neighborWithDist.neighbor].dbChain.chainKey).second) + break; - for (auto neighbor : neighborsOfCurrNeighbor) { - if (std::find(neighbors.begin(), neighbors.end(), neighbor) == neighbors.end()) - neighbors.emplace_back(neighbor); - } + neighbors.emplace_back(neighborWithDist.neighbor); } - if (neighbors.size() > idealClusterSize || checkChainRedundancy()) - getNearestNeighbors(centerAlnIdx); - - // too small cluster - if (neighbors.size() < maxClusterSize) + if (neighbors.size() < prevMaxClusterSize) continue; - // new Biggest cluster - if (neighbors.size() > maxClusterSize) { - maxClusterSize = neighbors.size(); - currClusters.clear(); + if (neighbors.size() > prevMaxClusterSize) { + prevMaxClusterSize = neighbors.size(); + finalClusters.clear(); } - currClusters.emplace_back(neighbors); + SORT_SERIAL(neighbors.begin(), neighbors.end()); + finalClusters.insert(neighbors); } - - if (!finalClusters.empty() && currClusters.empty()) - return finishDBSCAN(); - - if (maxClusterSize < prevMaxClusterSize) - return finishDBSCAN(); - - if (maxClusterSize > prevMaxClusterSize) { - finalClusters.clear(); - prevMaxClusterSize = maxClusterSize; - } - - finalClusters.insert(currClusters.begin(), currClusters.end()); - eps += learningRate; - return runDBSCAN(); + return finishClustering(); } void fillDistMap() { @@ -426,40 +371,15 @@ class DBSCANCluster { distMap.clear(); for (size_t i=0; i < searchResult.alnVec.size(); i++) { ChainToChainAln &prevAln = searchResult.alnVec[i]; + for (size_t j = i+1; j < searchResult.alnVec.size(); j++) { ChainToChainAln &currAln = searchResult.alnVec[j]; dist = prevAln.getDistance(currAln); - maxDist = std::max(maxDist, dist); distMap.insert({{i,j}, dist}); } } } - void getNeighbors(unsigned int centerIdx, std::vector &neighborVec) { - neighborVec.clear(); - neighborVec.emplace_back(centerIdx); - for (size_t neighborIdx = 0; neighborIdx < searchResult.alnVec.size(); neighborIdx++) { - - if (neighborIdx == centerIdx) - continue; - - if ((centerIdx < neighborIdx ? distMap[{centerIdx, neighborIdx}] : distMap[{neighborIdx, centerIdx}]) >= eps) - continue; - - neighborVec.emplace_back(neighborIdx); - } -// return; - } - - void initializeAlnLabels() { - for (auto &aln : searchResult.alnVec) { - aln.label = UNCLUSTERED; - } - cLabel = UNCLUSTERED; - maxClusterSize = 0; - currClusters.clear(); - } - bool checkChainRedundancy() { qFoundChainKeys.clear(); dbFoundChainKeys.clear(); @@ -474,35 +394,36 @@ class DBSCANCluster { return false; } - unsigned int checkClusteringNecessity() { + bool checkClusteringNecessity() { if (searchResult.alnVec.empty()) return UNCLUSTERED; + for (size_t alnIdx=0; alnIdx= std::max(qBestBitScore[qKey], dbBestBitScore[dbKey]) * BIT_SCORE_MARGIN) { - alnIdx ++; + alnIdx++; continue; } searchResult.alnVec.erase(searchResult.alnVec.begin() + alnIdx); } -// return; - } - - void getNearestNeighbors(unsigned int centerIdx) { - qFoundChainKeys.clear(); - dbFoundChainKeys.clear(); - neighborsWithDist.clear(); - - for (auto neighborIdx: neighbors) { - if (neighborIdx == centerIdx) { - neighborsWithDist.emplace_back(neighborIdx, 0.0); - continue; - } - neighborsWithDist.emplace_back(neighborIdx, neighborIdx < centerIdx ? distMap[{neighborIdx, centerIdx}] : distMap[{centerIdx, neighborIdx}]); - } - SORT_SERIAL(neighborsWithDist.begin(), neighborsWithDist.end(), compareNeighborWithDist); - neighbors.clear(); - for (auto neighborWithDist : neighborsWithDist) { - if (!qFoundChainKeys.insert(searchResult.alnVec[neighborWithDist.neighbor].qChain.chainKey).second) - break; - - if (!dbFoundChainKeys.insert(searchResult.alnVec[neighborWithDist.neighbor].dbChain.chainKey).second) - break; - - neighbors.emplace_back(neighborWithDist.neighbor); - } + qBestBitScore.clear(); + dbBestBitScore.clear(); // return; } }; @@ -579,15 +476,20 @@ class ComplexScorer { } void getSearchResults(unsigned int qComplexId, std::vector &qChainKeys, chainKeyToComplexId_t &dbChainKeyToComplexIdLookup, complexIdToChainKeys_t &dbComplexIdToChainKeysLookup, std::vector &searchResults) { + hasBacktrace = false; unsigned int qResLen = getQueryResidueLength(qChainKeys); - if (qResLen == 0) return; + if (qResLen == 0) + return; + paredSearchResult = SearchResult(qChainKeys, qResLen); // for each chain from the query Complex for (auto qChainKey: qChainKeys) { unsigned int qKey = alnDbr.getId(qChainKey); if (qKey == NOT_AVAILABLE_CHAIN_KEY) continue; char *data = alnDbr.getData(qKey, thread_idx); - if (*data == '\0') continue; + if (*data == '\0') + continue; + qAlnResult = Matcher::parseAlignmentRecord(data); size_t qDbId = qCaDbr->sequenceReader->getId(qChainKey); char *qCaData = qCaDbr->sequenceReader->getData(qDbId, thread_idx); @@ -596,6 +498,7 @@ class ComplexScorer { float *queryCaData = qCoords.read(qCaData, qAlnResult.qLen, qCaLength); qChain = Chain(qComplexId, qChainKey); tmAligner->initQuery(queryCaData, &queryCaData[qLen], &queryCaData[qLen * 2], NULL, qLen); + // for each alignment from the query chain while (*data != '\0') { char dbKeyBuffer[255 + 1]; @@ -603,10 +506,11 @@ class ComplexScorer { const auto dbChainKey = (unsigned int) strtoul(dbKeyBuffer, NULL, 10); const unsigned int dbComplexId = dbChainKeyToComplexIdLookup.at(dbChainKey); dbAlnResult = Matcher::parseAlignmentRecord(data); - if (dbAlnResult.backtrace.empty()) { - Debug(Debug::ERROR) << "Backtraces are required. Please run search with '-a' option.\n"; - EXIT(EXIT_FAILURE); - } + data = Util::skipLine(data); + if (dbAlnResult.backtrace.empty()) + continue; + + hasBacktrace = true; size_t tCaId = tCaDbr->sequenceReader->getId(dbChainKey); char *tCaData = tCaDbr->sequenceReader->getData(tCaId, thread_idx); size_t tCaLength = tCaDbr->sequenceReader->getEntryLen(tCaId); @@ -617,9 +521,13 @@ class ComplexScorer { currAln = ChainToChainAln(qChain, dbChain, queryCaData, targetCaData, dbAlnResult, tmResult); currAlns.emplace_back(currAln); currAln.free(); - data = Util::skipLine(data); } // while end } // for end + // When alignments have no backtrace + if (!hasBacktrace) { + Debug(Debug::ERROR) << "Backtraces are required. Please run search with '-a' option.\n"; + EXIT(EXIT_FAILURE); + } if (currAlns.empty()) return; @@ -629,6 +537,7 @@ class ComplexScorer { std::vector currDbChainKeys = dbComplexIdToChainKeysLookup.at(currDbComplexId); unsigned int currDbResLen = getDbResidueLength(currDbChainKeys); paredSearchResult.resetDbComplex(currDbChainKeys, currDbResLen); + for (auto &aln: currAlns) { if (aln.dbChain.complexId == currDbComplexId) { paredSearchResult.alnVec.emplace_back(aln); @@ -649,7 +558,7 @@ class ComplexScorer { currAlns.clear(); paredSearchResult.filterAlnVec(minAssignedChainsRatio); paredSearchResult.standardize(); - if (!paredSearchResult.alnVec.empty() && currDbChainKeys.size() >= MULTIPLE_CHAINED_COMPOLEX) + if (!paredSearchResult.alnVec.empty()) // && currDbChainKeys.size() >= NOT_SINGLE_CHAIN searchResults.emplace_back(paredSearchResult); paredSearchResult.alnVec.clear(); @@ -661,28 +570,27 @@ class ComplexScorer { maxResLen = std::max(searchResult.qChainKeys.size(), searchResult.dbChainKeys.size()) * maxChainLen; tmAligner = new TMaligner(maxResLen, false, true); } - unsigned int currLabel; - DBSCANCluster dbscanCluster = DBSCANCluster(searchResult, minAssignedChainsRatio); - currLabel = dbscanCluster.getAlnClusters(); - if (currLabel == UNCLUSTERED) return; + + finalClusters.clear(); + NearestNeighborsCluster nnCluster = NearestNeighborsCluster(searchResult, finalClusters, minAssignedChainsRatio); + if (!nnCluster.getAlnClusters()) { + finalClusters.clear(); + return; + } + assignment = Assignment(searchResult.qResidueLen, searchResult.dbResidueLen); - for (auto &currAln: searchResult.alnVec) { - if (currAln.label == UNCLUSTERED) - continue; + for (auto &cluster: finalClusters) { - if (currAln.label != currLabel) { - assignment.getTmScore(*tmAligner); - assignment.updateResultToWriteLines(); - assignments.emplace_back(assignment); - assignment.reset(); - currLabel = currAln.label; + for (auto alnIdx: cluster) { + assignment.appendChainToChainAln(searchResult.alnVec[alnIdx]); } - assignment.appendChainToChainAln(currAln); + + assignment.getTmScore(*tmAligner); + assignment.updateResultToWriteLines(); + assignments.emplace_back(assignment); + assignment.reset(); } - assignment.getTmScore(*tmAligner); - assignment.updateResultToWriteLines(); - assignments.emplace_back(assignment); - assignment.reset(); + finalClusters.clear(); } void free() { @@ -711,6 +619,8 @@ class ComplexScorer { std::vector currAlns; Assignment assignment; SearchResult paredSearchResult; + std::set finalClusters; + bool hasBacktrace; unsigned int getQueryResidueLength(std::vector &qChainKeys) { unsigned int qResidueLen = 0; @@ -833,8 +743,9 @@ int scorecomplex(int argc, const char **argv, const Command &command) { for (size_t qCompIdx = 0; qCompIdx < qComplexIndices.size(); qCompIdx++) { unsigned int qComplexId = qComplexIndices[qCompIdx]; std::vector &qChainKeys = qComplexIdToChainKeysMap.at(qComplexId); - if (qChainKeys.size() < MULTIPLE_CHAINED_COMPOLEX) + if (qChainKeys.size() < MULTIPLE_CHAINED_COMPLEX) continue; + complexScorer.getSearchResults(qComplexId, qChainKeys, dbChainKeyToComplexIdMap, dbComplexIdToChainKeysMap, searchResults); // for each db complex for (size_t dbId = 0; dbId < searchResults.size(); dbId++) {