diff --git a/apps/backend/apps/client/src/group/group.service.spec.ts b/apps/backend/apps/client/src/group/group.service.spec.ts index f174fae7ef..91ccc6a95d 100644 --- a/apps/backend/apps/client/src/group/group.service.spec.ts +++ b/apps/backend/apps/client/src/group/group.service.spec.ts @@ -1,7 +1,7 @@ import { CACHE_MANAGER } from '@nestjs/cache-manager' import { ConfigService } from '@nestjs/config' import { Test, type TestingModule } from '@nestjs/testing' -import { Prisma } from '@prisma/client' +import { Prisma, PrismaClient } from '@prisma/client' import type { Cache } from 'cache-manager' import { expect } from 'chai' import * as chai from 'chai' @@ -12,21 +12,29 @@ import { ConflictFoundException, EntityNotExistException } from '@libs/exception' -import { PrismaService } from '@libs/prisma' +import { PrismaService, type FlatTransactionClient } from '@libs/prisma' +import { transactionExtension } from '@libs/prisma' import { GroupService } from './group.service' import type { UserGroupData } from './interface/user-group-data.interface' chai.use(chaiExclude) - describe('GroupService', () => { let service: GroupService let cache: Cache - let prisma: PrismaService - beforeEach(async () => { + let tx: FlatTransactionClient + + const prisma = new PrismaClient().$extends(transactionExtension) + + beforeEach(async function () { + // TODO: CI 테스트에서 timeout이 걸리는 문제를 우회하기 위해서 timeout을 0으로 설정 (timeout disabled) + // local에서는 timeout을 disable 하지 않아도 테스트가 정상적으로 동작함 (default setting: 2000ms) + this.timeout(0) + //transaction client + tx = await prisma.$begin() const module: TestingModule = await Test.createTestingModule({ providers: [ GroupService, - PrismaService, + { provide: PrismaService, useValue: tx }, ConfigService, { provide: CACHE_MANAGER, @@ -39,7 +47,6 @@ describe('GroupService', () => { }).compile() service = module.get(GroupService) cache = module.get(CACHE_MANAGER) - prisma = module.get(PrismaService) }) it('should be defined', () => { @@ -166,9 +173,8 @@ describe('GroupService', () => { describe('joinGroupById', () => { let groupId: number const userId = 4 - beforeEach(async () => { - const group = await prisma.group.create({ + const group = await tx.group.create({ data: { groupName: 'test', description: 'test', @@ -182,26 +188,7 @@ describe('GroupService', () => { }) afterEach(async () => { - try { - await prisma.userGroup.delete({ - where: { - // eslint-disable-next-line @typescript-eslint/naming-convention - userId_groupId: { userId, groupId } - } - }) - } catch { - /* 삭제할 내용이 없는 경우 예외 무시 */ - } - - try { - await prisma.group.delete({ - where: { - id: groupId - } - }) - } catch { - /* 삭제할 내용 없을 경우 예외 무시 */ - } + await tx.$rollback() }) it('should return {isJoined: true} when group not set as requireApprovalBeforeJoin', async () => { @@ -225,7 +212,7 @@ describe('GroupService', () => { }) it('should return {isJoined: false} when group set as requireApprovalBeforeJoin', async () => { - await prisma.group.update({ + await tx.group.update({ where: { id: groupId }, @@ -250,7 +237,7 @@ describe('GroupService', () => { }) it('should throw ConflictFoundException when user is already group memeber', async () => { - await prisma.userGroup.create({ + await tx.userGroup.create({ data: { userId, groupId, @@ -270,7 +257,7 @@ describe('GroupService', () => { { userId, expiresAt: Date.now() + JOIN_GROUP_REQUEST_EXPIRE_TIME } ]) - await prisma.group.update({ + await tx.group.update({ where: { id: groupId }, @@ -291,9 +278,8 @@ describe('GroupService', () => { describe('leaveGroup', () => { const groupId = 3 const userId = 4 - beforeEach(async () => { - await prisma.userGroup.createMany({ + await tx.userGroup.createMany({ data: [ { userId, @@ -310,18 +296,7 @@ describe('GroupService', () => { }) afterEach(async () => { - try { - await prisma.userGroup.deleteMany({ - where: { - OR: [ - { AND: [{ userId }, { groupId }] }, - { AND: [{ userId: 5 }, { groupId }] } - ] - } - }) - } catch { - return - } + await tx.$rollback() }) it('should return deleted userGroup when valid userId and groupId passed', async () => { diff --git a/apps/backend/libs/prisma/src/index.ts b/apps/backend/libs/prisma/src/index.ts index 242be587a0..7f13878b82 100644 --- a/apps/backend/libs/prisma/src/index.ts +++ b/apps/backend/libs/prisma/src/index.ts @@ -1,2 +1,3 @@ export * from './prisma.module' export * from './prisma.service' +export * from './transaction.extension' diff --git a/apps/backend/libs/prisma/src/transaction.extension.ts b/apps/backend/libs/prisma/src/transaction.extension.ts new file mode 100644 index 0000000000..5dbd210f22 --- /dev/null +++ b/apps/backend/libs/prisma/src/transaction.extension.ts @@ -0,0 +1,71 @@ +import { Prisma } from '@prisma/client' +import { PrismaService } from './prisma.service' + +export type FlatTransactionClient = Prisma.TransactionClient & { + $commit: () => Promise + $rollback: () => Promise +} + +const ROLLBACK = { [Symbol.for('prisma.client.extension.rollback')]: true } + +export const transactionExtension = Prisma.defineExtension({ + client: { + async $begin() { + const prisma = Prisma.getExtensionContext(this) + let setTxClient: (txClient: Prisma.TransactionClient) => void + let commit: () => void + let rollback: () => void + + // a promise for getting the tx inner client + const txClient = new Promise((res) => { + setTxClient = res + }) + + // a promise for controlling the transaction + const txPromise = new Promise((_res, _rej) => { + commit = () => _res(undefined) + rollback = () => _rej(ROLLBACK) + }) + + // opening a transaction to control externally + if ( + '$transaction' in prisma && + typeof prisma.$transaction === 'function' + ) { + const tx = prisma + .$transaction((txClient) => { + setTxClient(txClient as unknown as Prisma.TransactionClient) + return txPromise + }) + .catch((e) => { + if (e === ROLLBACK) { + return + } + throw e + }) + + // return a proxy TransactionClient with `$commit` and `$rollback` methods + return new Proxy(await txClient, { + get(target, prop) { + if (prop === '$commit') { + return () => { + commit() + return tx + } + } + if (prop === '$rollback') { + return () => { + rollback() + return tx + } + } + return target[prop as keyof typeof target] + } + }) as FlatTransactionClient + } + + throw new Error('Transactions are not supported by this client') + }, + getPaginator: PrismaService.prototype.getPaginator + } +})