diff --git a/src/main/scala/AssignTraceIds.scala b/src/main/scala/AssignTraceIds.scala index 3e560953..5192939d 100644 --- a/src/main/scala/AssignTraceIds.scala +++ b/src/main/scala/AssignTraceIds.scala @@ -20,9 +20,15 @@ import at.ac.oeaw.imba.gerlich.gerlib.graph.{ } import at.ac.oeaw.imba.gerlich.gerlib.imaging.* import at.ac.oeaw.imba.gerlich.gerlib.imaging.instances.all.given -import at.ac.oeaw.imba.gerlich.gerlib.io.csv.ColumnNames.SpotChannelColumnName +import at.ac.oeaw.imba.gerlich.gerlib.io.csv.ColumnNames.{ + NucleusDesignationColumnName, + SpotChannelColumnName, +} import at.ac.oeaw.imba.gerlich.gerlib.io.csv.instances.all.given -import at.ac.oeaw.imba.gerlich.gerlib.io.csv.readCsvToCaseClasses +import at.ac.oeaw.imba.gerlich.gerlib.io.csv.{ + readCsvToCaseClasses, + writeCaseClassesToCsv, +} import at.ac.oeaw.imba.gerlich.gerlib.numeric.* import at.ac.oeaw.imba.gerlich.gerlib.numeric.instances.all.given import at.ac.oeaw.imba.gerlich.gerlib.syntax.all.* @@ -34,13 +40,17 @@ import at.ac.oeaw.imba.gerlich.looptrace.cli.ScoptCliReaders import at.ac.oeaw.imba.gerlich.looptrace.csv.ColumnNames.{ MergeContributorsColumnNameForAssessedRecord, RoiIndexColumnName, + TraceIdColumnName, + TracePartnersColumName, } import at.ac.oeaw.imba.gerlich.looptrace.csv.getCsvRowDecoderForImagingChannel import at.ac.oeaw.imba.gerlich.looptrace.csv.instances.all.given +import at.ac.oeaw.imba.gerlich.looptrace.instances.all.given import at.ac.oeaw.imba.gerlich.looptrace.internal.BuildInfo import at.ac.oeaw.imba.gerlich.looptrace.roi.MergeAndSplitRoiTools.IndexedDetectedSpot import at.ac.oeaw.imba.gerlich.looptrace.space.BoundingBox import at.ac.oeaw.imba.gerlich.looptrace.ImagingRoundsConfiguration.ProximityGroup +import at.ac.oeaw.imba.gerlich.gerlib.cell.NuclearDesignation /** Assign trace IDs to regional spots, considering the potential to group some together for downstream analytical purposes. */ object AssignTraceIds extends ScoptCliReaders, StrictLogging: @@ -90,6 +100,11 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging: } } + private def checkTraceId(offLimits: NonEmptySet[TraceId])(tid: TraceId): Unit = + if (offLimits contains tid) { + throw new Exception(s"Trace ID is already a ROI index and can't be used: ${tid.show_}") + } + private def definePairwiseDistanceThresholds( rules: NonEmptyList[TraceIdDefinitionAndFiltrationRule], ): Map[(ImagingTimepoint, ImagingTimepoint), DistanceThreshold] = @@ -148,114 +163,171 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging: // Ensure each record gets a node, and add the discovered edges. buildSimpleGraph(records.map(_.index).toList.toSet, edgeEndpoints) + // Start trace IDs with 1 more than max ROI ID/index. + private def getInitialTraceId(roiIds: NonEmptyList[RoiIndex]): TraceId = + val maxRoiId = roiIds.toList.max(using summon[Order[RoiIndex]].toOrdering) + TraceId.unsafe(NonnegativeInt(1) + maxRoiId.get) + private def labelRecords( rules: NonEmptyList[TraceIdDefinitionAndFiltrationRule], discardIfNotInGroupOfInterest: Boolean, - )(maybeRecords: List[InputRecord]): Option[NonEmptyList[(InputRecord, TraceId, Option[NonEmptySet[RoiIndex]])]] = + )(records: NonEmptyList[InputRecord]): NonEmptyList[OutputRecord] = /* Necessary imports and type aliases */ import AtLeast2.syntax.{ remove, toNes, toSet } - import at.ac.oeaw.imba.gerlich.looptrace.instances.all.given // SimpleShow instances for domain types type TimepointExpectationLookup = NonEmptyMap[ImagingTimepoint, TraceIdDefinitionAndFiltrationRule] - maybeRecords.toNel.map{ records => - val lookupRecord: NonEmptyMap[RoiIndex, InputRecord] = records.map(r => r.index -> r).toNem - val lookupRule: TimepointExpectationLookup = - // Provide a way to get the expected group members and requirement stringency for a given timepoint. - given orderForKeyValuePairs[V]: Order[(ImagingTimepoint, V)] = Order.by(_._1) - given semigroup: Semigroup[TimepointExpectationLookup] = - Semigroup.instance{ (x, y) => - val collisions = x.keys & y.keys - if collisions.isEmpty then x ++ y - else throw new Exception(s"${collisions.size} key collision(s) between lookups to combine: $collisions") - } - rules.reduceMap{ r => r.mergeGroup.members.toNes.map(_ -> r).toNonEmptyList.toNem } - val initTraceId = - // Start trace IDs with 1 more than max ROI ID/index. - TraceId.unsafe(NonnegativeInt(1) + records.foldLeft(records.head.index){ (i, r) => i max r.index }.get) - val traceIdsOffLimits = - // Don't use any ROI index/ID as a trace ID. - records.map(_.index.get).map(TraceId.unsafe).toNes - computeNeighborsGraph(rules)(records) - .strongComponentTraverser() - .map(_.nodes.map(_.outer) // Get ROI IDs. - .toList.toNel // each component as a nonempty list - .getOrElse{ throw new Exception("Empty component!") }) // protected against by definition of component - .toList.toNel // We want a nonempty list of components to accumulate errors - .getOrElse{ throw new Exception("No components!") } // protected against by initial .toNel call on input ROIs - .traverse(_.traverse{ i => lookupRecord.apply(i).toValidNel(i) }) - .fold( - badIds => - // guarded against by construction of the lookup from records input - throw new Exception(s"${badIds.length} ROI IDs couldn't be looked up! Here's one: ${badIds.head.show_}"), - _.toList.foldRight(initTraceId -> List.empty[(InputRecord, TraceId, Option[NonEmptySet[RoiIndex]])]){ - case (recGroup, (currId, acc)) => - if (traceIdsOffLimits contains currId) { - // guarded against by starting with max ROI index and always incrementing the currId - throw new Exception(s"Trace ID is already a ROI index and can't be used: ${currId.show_}") - } - val newRecs: List[(InputRecord, TraceId, Option[NonEmptySet[RoiIndex]])] = - AtLeast2.either(recGroup.map(_.index).toList.toSet).fold( - Function.const{ // singleton - given Eq[RoiPartnersRequirementType] = Eq.fromUniversalEquals - val useRecord = lookupRule - .apply(recGroup.head.timepoint) - .fold(!discardIfNotInGroupOfInterest)(_.requirement === RoiPartnersRequirementType.Lackadaisical) - if useRecord then List((recGroup.head, currId, None)) - else List() - }, - multiIds => - val useGroup: Boolean = - recGroup // at least two ROIs in group/component - .toList - .flatMap{ r => lookupRule.apply(r.timepoint) } - .toNel - .fold(!discardIfNotInGroupOfInterest){ rules => - given Eq[TraceIdDefinitionAndFiltrationRule] = Eq.fromUniversalEquals - val nUniqueRules = rules.toList.toSet.size - if nUniqueRules =!= 1 - then throw new Exception( - s"$nUniqueRules unique merge rules (not 1) for single group!" - ) - else - val rule = rules.head - val expectedTimes = rule.mergeGroup.members.toSet - val observedTimes = recGroup.map(_.timepoint).toList.toSet - rule.requirement match { - case RoiPartnersRequirementType.Conjunctive => - observedTimes === expectedTimes - case RoiPartnersRequirementType.Disjunctive => - observedTimes subsetOf expectedTimes - case RoiPartnersRequirementType.Lackadaisical => - true - } - } - if useGroup - then recGroup.toList.map(rec => (rec, currId, multiIds.remove(rec.index).some)) - else List() - ) - val newTid = TraceId.unsafe(NonnegativeInt(1) + currId.get) - (newTid, newRecs ::: acc) - } - ) - ._2 - .toNel - .getOrElse{ - // guarded against by checking for empty input up front - throw new Exception("Wound up with empty results despite nonempty input!") + val lookupRecord: NonEmptyMap[RoiIndex, InputRecord] = records.map(r => r.index -> r).toNem + val lookupRule: TimepointExpectationLookup = + // Provide a way to get the expected group members and requirement stringency for a given timepoint. + given orderForKeyValuePairs[V]: Order[(ImagingTimepoint, V)] = Order.by(_._1) + given semigroup: Semigroup[TimepointExpectationLookup] = + Semigroup.instance{ (x, y) => + val collisions = x.keys & y.keys + if collisions.isEmpty then x ++ y + else throw new Exception(s"${collisions.size} key collision(s) between lookups to combine: $collisions") } - } + rules.reduceMap{ r => r.mergeGroup.members.toNes.map(_ -> r).toNonEmptyList.toNem } + val initTraceId = getInitialTraceId(records.map(_.index)) + val traceIdsOffLimits = + // Don't use any ROI index/ID as a trace ID. + records.map(_.index.get).map(TraceId.unsafe).toNes + computeNeighborsGraph(rules)(records) + .strongComponentTraverser() + .map(_.nodes.map(_.outer) // Get ROI IDs. + .toList.toNel // each component as a nonempty list + .getOrElse{ throw new Exception("Empty component!") }) // protected against by definition of component + .toList.toNel // We want a nonempty list of components to accumulate errors + .getOrElse{ throw new Exception("No components!") } // protected against by initial .toNel call on input ROIs + .traverse(_.traverse{ i => lookupRecord.apply(i).toValidNel(i) }) + .fold( + badIds => + // guarded against by construction of the lookup from records input + throw new Exception(s"${badIds.length} ROI IDs couldn't be looked up! Here's one: ${badIds.head.show_}"), + _.toList.foldRight(initTraceId -> List.empty[OutputRecord]){ + case (recGroup, (currId, acc)) => + checkTraceId(traceIdsOffLimits)(currId) + val newRecs: List[OutputRecord] = + AtLeast2.either(recGroup.map(_.index).toList.toSet).fold( + Function.const{ // singleton + given Eq[RoiPartnersRequirementType] = Eq.fromUniversalEquals + val useRecord = lookupRule + .apply(recGroup.head.timepoint) + .fold(!discardIfNotInGroupOfInterest)(_.requirement === RoiPartnersRequirementType.Lackadaisical) + if useRecord then List((recGroup.head, currId, None)) + else List() + }, + multiIds => + val useGroup: Boolean = + recGroup // at least two ROIs in group/component + .toList + .flatMap{ r => lookupRule.apply(r.timepoint) } + .toNel + .fold(!discardIfNotInGroupOfInterest){ rules => + given Eq[TraceIdDefinitionAndFiltrationRule] = Eq.fromUniversalEquals + val nUniqueRules = rules.toList.toSet.size + if nUniqueRules =!= 1 + then throw new Exception( + s"$nUniqueRules unique merge rules (not 1) for single group!" + ) + else + val rule = rules.head + val expectedTimes = rule.mergeGroup.members.toSet + val observedTimes = recGroup.map(_.timepoint).toList.toSet + rule.requirement match { + case RoiPartnersRequirementType.Conjunctive => + observedTimes === expectedTimes + case RoiPartnersRequirementType.Disjunctive => + observedTimes subsetOf expectedTimes + case RoiPartnersRequirementType.Lackadaisical => + true + } + } + if useGroup + then recGroup.toList.map(rec => (rec, currId, multiIds.remove(rec.index).some)) + else List() + ) + val newTid = TraceId.unsafe(NonnegativeInt(1) + currId.get) + (newTid, newRecs ::: acc) + } + ) + ._2 + .toNel + .getOrElse{ + // guarded against by checking for empty input up front + throw new Exception("Wound up with empty results despite nonempty input!") + } def workflow(roundsConfig: ImagingRoundsConfiguration, roisFile: os.Path, outputFile: os.Path): Unit = { - val readRois: IO[List[InputRecord]] = - import InputRecord.given - import fs2.data.text.utf8.* - given CsvRowDecoder[ImagingChannel, String] = - getCsvRowDecoderForImagingChannel(SpotChannelColumnName) - readCsvToCaseClasses(roisFile) - logger.info(s"Reading ROIs file: $roisFile") + import InputRecord.given + import fs2.data.text.utf8.* + given CsvRowDecoder[ImagingChannel, String] = + getCsvRowDecoderForImagingChannel(SpotChannelColumnName) + + IO{ logger.info(s"Reading ROIs file: $roisFile") } + .flatMap{ Function.const(readCsvToCaseClasses[InputRecord](roisFile)) } + .map(_.toNel match { + case None => + logger.error(s"No input record parsed from ROIs file ($roisFile)!") + Option.empty[NonEmptyList[OutputRecord]] + case Some(records) => + roundsConfig.mergeRules match { + case None => + val initTraceId = getInitialTraceId(records.map(_.index)) + val traceIdsOffLimits = records.map(r => TraceId(r.index.get)).toNes + records.zipWithIndex.map{ (r, i) => + val newTid = TraceId.unsafe(NonnegativeInt.unsafe(i) + initTraceId.get) + checkTraceId(traceIdsOffLimits)(newTid) + (r, newTid, None) + }.some + case Some(rules) => + labelRecords(rules, roundsConfig.discardRoisNotInGroupsOfInterest)(records).some + } + }) + .flatMap(_ match { + case None => IO{ logger.error("No output to write!") } + case Some(records) => + import OutputRecord.given + given CsvRowEncoder[ImagingChannel, String] = + // for derivation of CsvRowEncoder[ImagingContext, String] + SpotChannelColumnName.toNamedEncoder + logger.info(s"Writing output file: $outputFile") + fs2.Stream + .emits(records.toList) + .through(writeCaseClassesToCsv[OutputRecord](outputFile)) + .compile + .drain + }) + .unsafeRunSync() + logger.info("Done!") } + private type OutputRecord = (InputRecord, TraceId, Option[NonEmptySet[RoiIndex]]) + + object OutputRecord: + given encOutRec(using + encRoiId: CellEncoder[RoiIndex], + encContext: CsvRowEncoder[ImagingContext, String], + encCentroid: CsvRowEncoder[Centroid[Double], String], + encBox: CsvRowEncoder[BoundingBox, String], + encTid: CellEncoder[TraceId], + ): CsvRowEncoder[OutputRecord, String] with + override def apply(elem: OutputRecord): RowF[Some, String] = + val (inrec, tid, maybePartners) = elem + val idRow = RoiIndexColumnName.write(inrec.index) + val ctxRow = encContext(inrec.context) + val centerRow = encCentroid(inrec.centroid) + val boxRow = encBox(inrec.box) + val tidRow = TraceIdColumnName.write(tid) + val partnersRow = + TracePartnersColumName.write(maybePartners.fold(Set())(_.toSortedSet.toSet)) + val nucRow = RowF( + values = NonEmptyList.one(inrec.maybeNucleusNumber.fold("")(_.get.show_)), + headers = Some(NonEmptyList.one(NucleusDesignationColumnName.value)), + ) + idRow |+| ctxRow |+| centerRow |+| boxRow |+| nucRow |+| tidRow |+| partnersRow + end OutputRecord + final case class InputRecord( index: RoiIndex, context: ImagingContext, @@ -273,7 +345,6 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging: decContext: CsvRowDecoder[ImagingContext, String], decCentroid: CsvRowDecoder[Centroid[Double], String], decBox: CsvRowDecoder[BoundingBox, String], - decNuc: CellDecoder[NucleusNumber] ): CsvRowDecoder[InputRecord, String] = new: override def apply(row: RowF[Some, String]): DecoderResult[InputRecord] = val spotNel = summon[CsvRowDecoder[IndexedDetectedSpot, String]](row) @@ -281,7 +352,15 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging: .toValidatedNel val mergeInputsNel: ValidatedNel[String, Set[RoiIndex]] = MergeContributorsColumnNameForAssessedRecord.from(row) - val nucNel: ValidatedNel[String, Option[NucleusNumber]] = ??? + val nucNel: ValidatedNel[String, Option[NucleusNumber]] = + val key = NucleusDesignationColumnName.value + row.apply(key) + .toRight(s"Missing header/field '$key'") + .flatMap{ + case "" => None.asRight + case s => NucleusNumber.parse(s).map(_.some) + } + .toValidatedNel (spotNel, mergeInputsNel, nucNel) .mapN{ (spot, maybeMergeIndices, maybeNucNum) => InputRecord(spot.index, spot.context, spot.centroid, spot.box, maybeMergeIndices, maybeNucNum) diff --git a/src/main/scala/csv/ColumnNames.scala b/src/main/scala/csv/ColumnNames.scala index c53eaf9d..1facb926 100644 --- a/src/main/scala/csv/ColumnNames.scala +++ b/src/main/scala/csv/ColumnNames.scala @@ -48,6 +48,8 @@ object ColumnNames: val TraceIdColumnName: ColumnName[TraceId] = ColumnName("traceId") + val TracePartnersColumName: ColumnName[Set[RoiIndex]] = ColumnName("tracePartners") + private def coarseDriftColumnName[A <: EuclideanAxis](a: A): ColumnName[CoarseDriftComponent[A]] = val coarseDriftColumnSuffix: String = "DriftCoarsePixels" ColumnName(a match {