Skip to content

Commit

Permalink
use a CQL query to narrow down the EC2 instances returned by listCont…
Browse files Browse the repository at this point in the history
…ainerInstances when searching by instance ID

this should dramatically reduce the number of ListContainerInstances,
DescribeContainerInstances, ListTasks, and DescribeTasks calls made by
the Draining lambda.
  • Loading branch information
bpholt committed Dec 6, 2023
1 parent e4abe83 commit 2523c49
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import com.dwolla.aws.ecs.*

abstract class TestEcsAlg extends EcsAlg[IO, List] {
override def listClusterArns: List[ClusterArn] = ???
override def listContainerInstances(cluster: ClusterArn): List[ContainerInstance] = ???
override def listContainerInstances(cluster: ClusterArn, filter: Option[CQLQuery]): List[ContainerInstance] = ???
override def findEc2Instance(ec2InstanceId: InstanceId): IO[Option[(ClusterArn, ContainerInstance)]] = IO.raiseError(new NotImplementedError)
override def isTaskDefinitionRunningOnInstance(cluster: ClusterArn, ci: ContainerInstance, taskDefinition: TaskDefinitionArn): IO[Boolean] = IO.raiseError(new NotImplementedError)
override def drainInstanceImpl(cluster: ClusterArn, ci: ContainerInstance): IO[Unit] = IO.raiseError(new NotImplementedError)
Expand Down
29 changes: 23 additions & 6 deletions core-tests/src/test/scala/com/dwolla/aws/ecs/EcsAlgSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import cats.*
import cats.data.*
import cats.effect.*
import cats.syntax.all.*
import com.amazonaws.ec2.InstanceId
import com.amazonaws.ecs.ContainerInstanceStatus.DRAINING
import com.amazonaws.ecs.{BoxedInteger, ContainerInstanceField, DescribeContainerInstancesResponse, DescribeTasksResponse, DesiredStatus, ECS, Failure, LaunchType, ListClustersResponse, ListContainerInstancesResponse, ListTasksResponse, Task, TaskField, UpdateContainerInstancesStateResponse, ContainerInstance as AwsContainerInstance}
import com.dwolla.*
Expand All @@ -26,6 +27,7 @@ class EcsAlgSpec
given [F[_] : Applicative]: LoggerFactory[F] = NoOpFactory[F]

def fakeECS(arbCluster: ArbitraryCluster): ECS[IO] = new ECS.Default[IO](new NotImplementedError().raiseError) {
private val ec2InstanceIdCQL = """ec2InstanceId == (i-.+)""".r
private lazy val listClustersResponses: Map[NextPageToken, ListClustersResponse] =
ArbitraryPagination.paginateWith[Chunk, ArbitraryCluster, ClusterWithInstances, ClusterArn](arbCluster) {
case ClusterWithInstances((c, _)) => c.clusterArn
Expand All @@ -37,17 +39,32 @@ class EcsAlgSpec
}
.toMap

private lazy val listContainerInstancesResponses: Map[Option[ClusterArn], Map[NextPageToken, ListContainerInstancesResponse]] =
private lazy val listContainerInstancesResponses: Map[Option[ClusterArn], Option[CQLQuery] => Map[NextPageToken, ListContainerInstancesResponse]] =
arbCluster
.value
.flatMap(_.toList)
.map { cwi =>
val clusterArn: Option[ClusterArn] = cwi.value._1.clusterArn.some
val pages = ArbitraryPagination.paginate(cwi.value._2).view.mapValues {
case (c, n) =>
ListContainerInstancesResponse(c.map(_.containerInstanceId.value).toList.some, n.value)
val pages: Option[CQLQuery] => Map[NextPageToken, ListContainerInstancesResponse] = maybeQuery => {
val maybeInstanceId = maybeQuery.map(_.value).collect {
case ec2InstanceIdCQL(instanceId) => InstanceId(instanceId)
}
.toMap
ArbitraryPagination.paginate(cwi.value._2).view.mapValues {
case (c, n) =>
val containerInstances =
c
.filter { ciwtp =>
// if incoming query is empty, return true
// if incoming query matches ec2InstanceIdCQL regex and ciwtp contains a matching InstanceID, return true
maybeQuery.isEmpty || maybeInstanceId.contains(ciwtp.ec2InstanceId.value)
}
.map(_.containerInstanceId.value)
.toList
.some
ListContainerInstancesResponse(containerInstances, n.value)
}
.toMap
}

clusterArn -> pages
}
Expand Down Expand Up @@ -122,12 +139,12 @@ class EcsAlgSpec
maxResults: Option[BoxedInteger],
status: Option[com.amazonaws.ecs.ContainerInstanceStatus]): IO[ListContainerInstancesResponse] =
rejectParameters("listContainerInstances")(
filter.as("filter"),
maxResults.as("maxResults"),
status.as("status"),
).as {
listContainerInstancesResponses
.get(cluster.map(ClusterArn(_)))
.map(_(filter.map(CQLQuery(_))))
.flatMap(_.get(NextPageToken(nextToken)))
.getOrElse(ListContainerInstancesResponse())
}
Expand Down
12 changes: 7 additions & 5 deletions core/src/main/scala/com/dwolla/aws/ecs/EcsAlg.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import com.dwolla.aws.TraceableValueInstances.given

abstract class EcsAlg[F[_] : Applicative, G[_]] {
def listClusterArns: G[ClusterArn]
def listContainerInstances(cluster: ClusterArn): G[ContainerInstance]
def listContainerInstances(cluster: ClusterArn,
filter: Option[CQLQuery] = None,
): G[ContainerInstance]
def findEc2Instance(ec2InstanceId: InstanceId): F[Option[(ClusterArn, ContainerInstance)]]
def drainInstance(cluster: ClusterArn, ci: ContainerInstance): F[Unit] =
drainInstanceImpl(cluster, ci).unlessA(ci.status == ContainerInstanceStatus.Draining)
Expand All @@ -44,12 +46,13 @@ object EcsAlg {
.map(ClusterArn(_))
}

override def listContainerInstances(cluster: ClusterArn): Stream[F, ContainerInstance] =
override def listContainerInstances(cluster: ClusterArn,
filter: Option[CQLQuery]): Stream[F, ContainerInstance] =
Trace[Stream[F, *]].span("EcsAlg.listContainerInstances") {
Trace[Stream[F, *]].put("cluster" -> cluster) >>
Pagination.offsetUnfoldChunkEval { (nextToken: Option[String]) =>
ecs
.listContainerInstances(cluster.value.some, nextToken = nextToken)
.listContainerInstances(cluster.value.some, nextToken = nextToken, filter = filter.map(_.value))
.map { resp =>
resp.containerInstanceArns.toChunk.map(ContainerInstanceId(_)) -> resp.nextToken
}
Expand Down Expand Up @@ -92,8 +95,7 @@ object EcsAlg {
Trace[F].put("ec2InstanceId" -> ec2InstanceId) >>
LoggerFactory[F].create.flatMap { case given Logger[F] =>
listClusterArns
// TODO listContainerInstances could use a CQL expression to narrow the search
.mproduct(listContainerInstances(_).filter(_.ec2InstanceId == ec2InstanceId))
.mproduct(listContainerInstances(_, filter = CQLQuery(s"ec2InstanceId == $ec2InstanceId").some))
.compile
.last
.flatTap { ec2Instance =>
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/com/dwolla/aws/ecs/model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ object TaskCount extends NewtypeWrapped[Long] {
given Order[TaskCount] = Order[Long].contramap(_.value)
}

type CQLQuery = CQLQuery.Type
object CQLQuery extends NewtypeWrapped[String]

case class Cluster(region: AwsRegion, accountId: AccountId, name: ClusterName) {
val clusterArn: ClusterArn = ClusterArn(s"arn:aws:ecs:$region:$accountId:cluster/$name")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object TestApp extends ResourceApp.Simple {
.flatMap { ecs =>
ecs.listClusterArns
.filter(_.value.contains("Production"))
.flatMap(ecs.listContainerInstances)
.flatMap(ecs.listContainerInstances(_))
}
.evalMap(c => IO.println(c))
.compile
Expand Down

0 comments on commit 2523c49

Please sign in to comment.