diff --git a/packages/core/src/command/command.ts b/packages/core/src/command/command.ts index 21ea01d402..c4cf7d708c 100644 --- a/packages/core/src/command/command.ts +++ b/packages/core/src/command/command.ts @@ -1,4 +1,4 @@ -import { Awaitable, coerce, Dict, isNullable, Logger, remove, Schema } from '@koishijs/utils' +import { Awaitable, coerce, Dict, isNullable, isGeneratorFunction, Logger, remove, Schema } from '@koishijs/utils' import { segment } from '@satorijs/core' import { Disposable } from 'cordis' import { Context } from '../context' @@ -26,8 +26,10 @@ export namespace Command { options?: Dict } + export type Plain = void | string | segment + export type Action - = (argv: Argv, ...args: A) => Awaitable + = (argv: Argv, ...args: A) => Awaitable | Generator | AsyncGenerator export type Usage = string | ((session: Session) => Awaitable) @@ -246,7 +248,24 @@ export class Command async () => { - return await action.call(this, argv, ...args) + if (isGeneratorFunction(action)) { + const result = action.call(this, argv, ...args) + let ids: string[] = [] + while (true) { + const effect = await result.next(ids) + ids = [] + // return + if (effect.done) { + return effect.value + } + // yield + if (!isNullable(effect.value)) { + ids = await argv.session.send(effect.value) + } + } + } else { + return action.call(this, argv, ...args) as Command.Plain + } }) queue.push(fallback) diff --git a/packages/core/tests/command.spec.ts b/packages/core/tests/command.spec.ts index 5f234afd77..5d1f4060d6 100644 --- a/packages/core/tests/command.spec.ts +++ b/packages/core/tests/command.spec.ts @@ -1,4 +1,5 @@ import { App, Command, Logger, Next } from 'koishi' +import { sleep } from '@koishijs/utils' import { inspect } from 'util' import { expect, use } from 'chai' import shape from 'chai-shape' @@ -175,12 +176,14 @@ describe('Command API', () => { }) }) - describe('Execute Commands', () => { + describe('Execute Commands', async () => { const app = new App() app.plugin(mock) const session = app.mock.session({}) + const client = app.mock.client('123') const warn = jest.spyOn(logger, 'warn') const next = jest.fn(Next.compose) + await app.start() let command: Command beforeEach(() => { @@ -290,6 +293,36 @@ describe('Command API', () => { expect(warn.mock.calls).to.have.length(0) expect(next.mock.calls).to.have.length(0) }) + + it('generator 1 (sync)', async () => { + command.action(function* () { + yield '1' + yield '2' + return '3' + }) + await client.shouldReply('test', ['1', '2', '3']) + }) + + it('generator 2 (sync without yield)', async () => { + command.action(function* () { return '1' }) + await client.shouldReply('test', '1') + }) + + it('generator 3 (sync without return)', async () => { + command.action(function* () { yield '1' }) + await client.shouldReply('test', '1') + }) + + it('generator 4 (async)', async () => { + command.action(async function* () { + yield '1' + yield '2' + await sleep(100) + yield '3' + return '4' + }) + await client.shouldReply('test', ['1', '2', '3', '4']) + }) }) describe('Bypass Middleware', async () => { diff --git a/packages/utils/src/misc.ts b/packages/utils/src/misc.ts index dd24738229..60f94b8aee 100644 --- a/packages/utils/src/misc.ts +++ b/packages/utils/src/misc.ts @@ -2,6 +2,10 @@ export function isInteger(source: any) { return typeof source === 'number' && Math.floor(source) === source } +export function isGeneratorFunction(fn: any): fn is (...args: any[]) => Generator | AsyncGenerator { + return ['GeneratorFunction', 'AsyncGeneratorFunction'].includes(fn.constructor.name) +} + export async function sleep(ms: number): Promise { return new Promise(resolve => setTimeout(resolve, ms)) }