diff --git a/integration-tests/src/test/java/com/transferwise/tasks/testapp/TaskProcessingIntTest.java b/integration-tests/src/test/java/com/transferwise/tasks/testapp/TaskProcessingIntTest.java index f8f4a0e7..bdb2bb3c 100644 --- a/integration-tests/src/test/java/com/transferwise/tasks/testapp/TaskProcessingIntTest.java +++ b/integration-tests/src/test/java/com/transferwise/tasks/testapp/TaskProcessingIntTest.java @@ -13,6 +13,7 @@ import com.transferwise.tasks.ITaskDataSerializer; import com.transferwise.tasks.ITasksService; import com.transferwise.tasks.ITasksService.AddTaskRequest; +import com.transferwise.tasks.TasksProperties; import com.transferwise.tasks.dao.ITaskDao; import com.transferwise.tasks.dao.ITaskDao.InsertTaskRequest; import com.transferwise.tasks.domain.IBaseTask; @@ -48,6 +49,8 @@ import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -70,6 +73,10 @@ public class TaskProcessingIntTest extends BaseIntTest { protected ITaskDataSerializer taskDataSerializer; @Autowired protected GlobalProcessingState globalProcessingState; + @Autowired + protected TasksProperties tasksProperties; + @Autowired + protected KafkaProducer kafkaProducer; private KafkaTasksExecutionTriggerer kafkaTasksExecutionTriggerer; @@ -152,8 +159,12 @@ void allUniqueTasksWillGetProcessed(int scenario) throws Exception { log.info("Tasks execution took {} ms", end - start); KafkaTasksExecutionTriggerer.ConsumerBucket consumerBucket = kafkaTasksExecutionTriggerer.getConsumerBucket("default"); - assertEquals(0, consumerBucket.getOffsetsCompletedCount()); - assertEquals(0, consumerBucket.getOffsetsCount()); + kafkaTasksExecutionTriggerer.stopTasksProcessing("default").get(); + kafkaTasksExecutionTriggerer.startTasksProcessing("default"); + var originalOffset = consumerBucket.getOffsetsCompletedCount(); + + assertEquals(originalOffset, consumerBucket.getOffsetsCompletedCount()); + assertEquals(originalOffset, consumerBucket.getOffsetsCount()); assertEquals(0, consumerBucket.getUnprocessedFetchedRecordsCount()); await().until(() -> consumerBucket.getOffsetsToBeCommitedCount() == 0); @@ -433,6 +444,48 @@ public void freeSpace(IBaseTask task) { })); } + @Test + void taskProcessingWillHandlePoisonPillAttack() { + // given: + int tasksToFire = 10; + AtomicInteger counter = new AtomicInteger(); + + testTaskHandlerAdapter.setProcessor((ISyncTaskProcessor) task -> { + counter.incrementAndGet(); + return new ProcessResult().setResultCode(ResultCode.DONE); + }); + + // when: + for (int i = 0; i < tasksToFire; i++) { + publishPosionPill(); + addTask(); + publishPosionPill(); + } + + // then: + await().until(() -> transactionsHelper.withTransaction().asNew().call(() -> { + try { + return testTasksService.getFinishedTasks("test", null).size() == tasksToFire + && counter.get() == tasksToFire; + } catch (Throwable t) { + log.error(t.getMessage(), t); + } + return false; + })); + } + + @SneakyThrows + private void publishPosionPill() { + final var topicName = "twTasks." + tasksProperties.getGroupId() + ".executeTask.default"; + kafkaProducer.send(new ProducerRecord<>(topicName, "poison-pill")).get(); + } + + private void addTask() { + transactionsHelper.withTransaction().asNew().call(() -> + tasksService.addTask(new ITasksService.AddTaskRequest().setType("test").setData(taskDataSerializer.serialize("foo"))) + ); + } + private int counterSum(String name) { return meterRegistry.find(name) .counters() diff --git a/integration-tests/src/test/java/com/transferwise/tasks/testapp/config/TestConfiguration.java b/integration-tests/src/test/java/com/transferwise/tasks/testapp/config/TestConfiguration.java index 2b8c4136..f8ac4a6e 100644 --- a/integration-tests/src/test/java/com/transferwise/tasks/testapp/config/TestConfiguration.java +++ b/integration-tests/src/test/java/com/transferwise/tasks/testapp/config/TestConfiguration.java @@ -1,6 +1,7 @@ package com.transferwise.tasks.testapp.config; import com.transferwise.common.context.TwContextClockHolder; +import com.transferwise.tasks.TasksProperties; import com.transferwise.tasks.buckets.BucketProperties; import com.transferwise.tasks.buckets.IBucketsManager; import com.transferwise.tasks.domain.ITask; @@ -17,7 +18,9 @@ import java.time.ZonedDateTime; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import lombok.extern.slf4j.Slf4j; import org.apache.kafka.clients.CommonClientConfigs; import org.apache.kafka.clients.admin.AdminClient; @@ -25,6 +28,9 @@ import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.CooperativeStickyAssignor; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.StringSerializer; import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.kafka.KafkaProperties; @@ -141,6 +147,31 @@ public IKafkaListenerConsumerPropertiesProvider twTasksKafkaListenerSpringKafkaC }; } + @Bean + public KafkaProducer kafkaTaskTriggererProducer(TasksProperties tasksProperties) { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, tasksProperties.getTriggering().getKafka().getBootstrapServers()); + + configs.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + configs.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + + configs.put(ProducerConfig.ACKS_CONFIG, "all"); + configs.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, 5); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, "5000"); + configs.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true"); + configs.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, "5000"); + configs.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, "10000"); + configs.put(ProducerConfig.LINGER_MS_CONFIG, "5"); + configs.put(ProducerConfig.CLIENT_ID_CONFIG, tasksProperties.getGroupId() + ".tw-tasks-triggerer"); + configs.put(ProducerConfig.RECONNECT_BACKOFF_MAX_MS_CONFIG, "5000"); + configs.put(ProducerConfig.RECONNECT_BACKOFF_MS_CONFIG, "100"); + configs.put(ProducerConfig.METADATA_MAX_AGE_CONFIG, "120000"); + + configs.putAll(tasksProperties.getTriggering().getKafka().getProperties()); + + return new KafkaProducer<>(configs); + } + @Bean ITaskRegistrationDecorator jambiRegistrationInterceptor() { return new JambiTaskRegistrationDecorator(); diff --git a/integration-tests/src/test/java/com/transferwise/tasks/triggering/KafkaTasksExecutionTriggererIntTest.java b/integration-tests/src/test/java/com/transferwise/tasks/triggering/KafkaTasksExecutionTriggererIntTest.java index 35c7685d..f70bad26 100644 --- a/integration-tests/src/test/java/com/transferwise/tasks/triggering/KafkaTasksExecutionTriggererIntTest.java +++ b/integration-tests/src/test/java/com/transferwise/tasks/triggering/KafkaTasksExecutionTriggererIntTest.java @@ -23,6 +23,8 @@ import lombok.SneakyThrows; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.common.serialization.StringDeserializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; @@ -34,12 +36,15 @@ class KafkaTasksExecutionTriggererIntTest extends BaseIntTest { public static final String PARTITION_KEY = "7a1a43c9-35af-4bea-9349-a1f344c8185c"; private static final String BUCKET_ID = "manualStart"; + private static final String TASK_TYPE = "test"; private KafkaConsumer kafkaConsumer; @Autowired protected ITaskDataSerializer taskDataSerializer; @Autowired private TasksProperties tasksProperties; + @Autowired + protected KafkaProducer kafkaProducer; @BeforeEach @SneakyThrows @@ -68,7 +73,6 @@ void cleanup() { @Test void shouldUsePartitionKeyStrategyWhenCustomStrategyDefinedInProcessor() { final var data = "Hello World!"; - final var taskType = "test"; final var taskId = UuidUtils.generatePrefixCombUuid(); testTaskHandlerAdapter.setProcessor(resultRegisteringSyncTaskProcessor) @@ -82,12 +86,12 @@ void shouldUsePartitionKeyStrategyWhenCustomStrategyDefinedInProcessor() { final var taskRequest = new AddTaskRequest() .setTaskId(taskId) .setData(taskDataSerializer.serialize(data)) - .setType(taskType) + .setType(TASK_TYPE) .setUniqueKey(uniqueKey.toString()) .setRunAfterTime(ZonedDateTime.now().plusHours(1)); transactionsHelper.withTransaction().asNew().call(() -> testTasksService.addTask(taskRequest)); - await().until(() -> testTasksService.getWaitingTasks(taskType, null).size() > 0); + await().until(() -> testTasksService.getWaitingTasks(TASK_TYPE, null).size() > 0); assertTrue(transactionsHelper.withTransaction().asNew().call(() -> testTasksService.resumeTask(new ITasksService.ResumeTaskRequest().setTaskId(taskId).setVersion(0)) @@ -106,6 +110,48 @@ void shouldUsePartitionKeyStrategyWhenCustomStrategyDefinedInProcessor() { Assertions.assertTrue(keys.contains(PARTITION_KEY)); } + @Test + void handlesPoisonPills() { + // setup: + testTaskHandlerAdapter.setProcessor(resultRegisteringSyncTaskProcessor) + .setProcessingPolicy(new SimpleTaskProcessingPolicy() + .setProcessingBucket(BUCKET_ID) + .setMaxProcessingDuration(Duration.of(1, ChronoUnit.HOURS)) + .setPartitionKeyStrategy(new TestPartitionKeyStrategy())); + + + // when + int tasksToFire = 10; + for (int i = 0; i < tasksToFire; i++) { + publishPosionPill(); + addTask(); + publishPosionPill(); + } + testTasksService.startTasksProcessing(BUCKET_ID); + + await().until( + () -> resultRegisteringSyncTaskProcessor.getTaskResults().size() == tasksToFire + ); + + } + + @SneakyThrows + private void addTask() { + UUID taskId = UuidUtils.generatePrefixCombUuid(); + final var taskRequest = new AddTaskRequest() + .setTaskId(taskId) + .setData(taskDataSerializer.serialize("Hello World!")) + .setType(TASK_TYPE); + + transactionsHelper.withTransaction().asNew().call(() -> testTasksService.addTask(taskRequest)); + } + + @SneakyThrows + private void publishPosionPill() { + final var topicName = "twTasks." + tasksProperties.getGroupId() + ".executeTask." + BUCKET_ID; + kafkaProducer.send(new ProducerRecord<>(topicName, "poison-pill")).get(); + } + static class TestPartitionKeyStrategy implements IPartitionKeyStrategy { private static final UUID KEY = UUID.fromString(PARTITION_KEY); diff --git a/tw-tasks-core/src/main/java/com/transferwise/tasks/triggering/KafkaTasksExecutionTriggerer.java b/tw-tasks-core/src/main/java/com/transferwise/tasks/triggering/KafkaTasksExecutionTriggerer.java index 69760124..ddfaca11 100644 --- a/tw-tasks-core/src/main/java/com/transferwise/tasks/triggering/KafkaTasksExecutionTriggerer.java +++ b/tw-tasks-core/src/main/java/com/transferwise/tasks/triggering/KafkaTasksExecutionTriggerer.java @@ -280,7 +280,17 @@ public void poll(String bucketId) { log.debug("Received Kafka message from topic '{}' partition {} offset {}.", consumerRecord.topic(), consumerRecord.partition(), offset); - BaseTask task = JsonUtils.fromJson(objectMapper, consumerRecord.value(), BaseTask.class); + BaseTask task; + try { + task = JsonUtils.fromJson(objectMapper, consumerRecord.value(), BaseTask.class); + } catch (Exception e) { + log.error("Received malformed task trigger in bucket {} [from topic '{}' partition {} offset {}].", + bucketId, consumerRecord.topic(), consumerRecord.partition(), offset, e); + consumerBucket.decrementUnprocessedFetchedRecordsCount(); + releaseCompletedOffset(consumerBucket, topicPartition, offset); + continue; + } + mdcService.with(() -> { mdcService.put(task);