Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track credit consumption in AI bot #1809

Merged
merged 11 commits into from
Nov 22, 2024
75 changes: 75 additions & 0 deletions packages/ai-bot/lib/ai-cost.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import {
getCurrentActiveSubscription,
getUserByMatrixUserId,
spendCredits,
} from '@cardstack/billing/billing-queries';
import { PgAdapter, TransactionManager } from '@cardstack/postgres';
import { logger, retry } from '@cardstack/runtime-common';
import * as Sentry from '@sentry/node';

let log = logger('ai-bot');

export async function saveUsageCost(
pgAdapter: PgAdapter,
matrixUserId: string,
generationId: string,
) {
try {
// Generation data is sometimes not immediately available, so we retry a couple of times until we are able to get the cost
let costInUsd = await retry(() => fetchGenerationCost(generationId), {
retries: 10,
delayMs: 500,
});

let creditsConsumed = Math.round(costInUsd / 0.001);
jurgenwerk marked this conversation as resolved.
Show resolved Hide resolved

let user = await getUserByMatrixUserId(pgAdapter, matrixUserId);

// This check is for the transition period where we don't have subscriptions fully rolled out yet.
// When we have assurance that all users who use the bot have subscriptions, we can remove this subscription check.
let subscription = await getCurrentActiveSubscription(pgAdapter, user!.id);
if (!subscription) {
log.info(
`user ${matrixUserId} has no subscription, skipping credit usage tracking`,
);
return Promise.resolve();
}

if (!user) {
throw new Error(
`should not happen: user with matrix id ${matrixUserId} not found in the users table`,
);
}

let txManager = new TransactionManager(pgAdapter);

await txManager.withTransaction(async () => {
await spendCredits(pgAdapter, user!.id, creditsConsumed);

// TODO: send a signal to the host app to update credits balance displayed in the UI
});
} catch (err) {
log.error(
`Failed to track AI usage (matrixUserId: ${matrixUserId}, generationId: ${generationId}):`,
err,
);
Sentry.captureException(err);
// Don't throw, because we don't want to crash the bot over this
}
}

async function fetchGenerationCost(generationId: string) {
let response = await (
await fetch(`https://openrouter.ai/api/v1/generation?id=${generationId}`, {
headers: {
Authorization: `Bearer ${process.env.OPENROUTER_API_KEY}`,
},
})
).json();

if (response.error && response.error.includes('not found')) {
return null;
}

return response.data.total_cost;
}
2 changes: 1 addition & 1 deletion packages/ai-bot/lib/send-response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ export class Responder {
}
}

async onError(error: OpenAIError) {
async onError(error: OpenAIError | string) {
Sentry.captureException(error);
return await sendError(
this.client,
Expand Down
57 changes: 52 additions & 5 deletions packages/ai-bot/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,39 @@ import { MatrixClient } from './lib/matrix';
import type { MatrixEvent as DiscreteMatrixEvent } from 'https://cardstack.com/base/matrix-event';
import * as Sentry from '@sentry/node';

import { saveUsageCost } from './lib/ai-cost';
import { PgAdapter } from '@cardstack/postgres';

let log = logger('ai-bot');

let trackAiUsageCostPromises = new Map<string, Promise<void>>();

class Assistant {
private openai: OpenAI;
private client: MatrixClient;
private pgAdapter: PgAdapter;
id: string;

constructor(client: MatrixClient, id: string) {
this.openai = new OpenAI({
baseURL: 'https://openrouter.ai/api/v1', // We use openrouter so that we can track usage cost in $
baseURL: 'https://openrouter.ai/api/v1',
apiKey: process.env.OPENROUTER_API_KEY,
});
this.id = id;
this.client = client;
this.pgAdapter = new PgAdapter();
}

async trackAiUsageCost(matrixUserId: string, generationId: string) {
if (trackAiUsageCostPromises.has(matrixUserId)) {
return;
}
trackAiUsageCostPromises.set(
matrixUserId,
saveUsageCost(this.pgAdapter, matrixUserId, generationId).finally(() => {
trackAiUsageCostPromises.delete(matrixUserId);
}),
);
}

getResponse(history: DiscreteMatrixEvent[]) {
Expand Down Expand Up @@ -133,6 +152,7 @@ Common issues are:
async function (event, room, toStartOfTimeline) {
try {
let eventBody = event.getContent().body;
let senderMatrixUserId = event.getSender()!;
if (!room) {
return;
}
Expand All @@ -150,15 +170,15 @@ Common issues are:
return; // don't respond to card fragments, we just gather these in our history
}

if (event.getSender() === aiBotUserId) {
if (senderMatrixUserId === aiBotUserId) {
return;
}
log.info(
'(%s) (Room: "%s" %s) (Message: %s %s)',
event.getType(),
room?.name,
room?.roomId,
event.getSender(),
senderMatrixUserId,
eventBody,
);

Expand All @@ -179,16 +199,33 @@ Common issues are:
const responder = new Responder(client, room.roomId);
await responder.initialize();

// Do not generate new responses if previous ones' cost is still being reported
let pendingCreditsConsumptionPromise = trackAiUsageCostPromises.get(
senderMatrixUserId!,
);
if (pendingCreditsConsumptionPromise) {
try {
await pendingCreditsConsumptionPromise;
} catch (e) {
log.error(e);
return responder.onError(
'There was an error saving your Boxel credits usage. Try again or contact support if the problem persists.',
);
}
}

if (historyError) {
responder.finalize(
'There was an error processing chat history. Please open another session.',
);
return;
}

let generationId: string | undefined;
const runner = assistant
.getResponse(history)
.on('chunk', async (chunk, _snapshot) => {
generationId = chunk.id;
await responder.onChunk(chunk);
})
.on('content', async (_delta, snapshot) => {
Expand All @@ -200,9 +237,19 @@ Common issues are:
.on('error', async (error) => {
await responder.onError(error);
});

// We also need to catch the error when getting the final content
let finalContent = await runner.finalContent().catch(responder.onError);
await responder.finalize(finalContent);
let finalContent;
try {
finalContent = await runner.finalContent();
await responder.finalize(finalContent);
} catch (error) {
await responder.onError(error);
} finally {
if (generationId) {
assistant.trackAiUsageCost(senderMatrixUserId, generationId);
}
}

if (shouldSetRoomTitle(eventList, aiBotUserId, event)) {
return await assistant.setTitle(room.roomId, history, event);
Expand Down
6 changes: 4 additions & 2 deletions packages/ai-bot/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{
"name": "@cardstack/ai-bot",
"dependencies": {
"@cardstack/runtime-common": "workspace:^",
"@cardstack/runtime-common": "workspace:*",
"@cardstack/postgres": "workspace:*",
"@cardstack/billing": "workspace:*",
"@sentry/node": "^8.31.0",
"@types/node": "^18.18.5",
"@types/stream-chain": "^2.0.1",
Expand All @@ -21,7 +23,7 @@
},
"scripts": {
"lint": "eslint . --cache --ext ts",
"start": "NODE_NO_WARNINGS=1 ts-node --transpileOnly main",
"start": "NODE_NO_WARNINGS=1 PGDATABASE=boxel PGPORT=5435 ts-node --transpileOnly main",
"test": "NODE_NO_WARNINGS=1 qunit --require ts-node/register/transpile-only tests/index.ts",
"get-chat": "NODE_NO_WARNINGS=1 ts-node --transpileOnly scripts/get_chat.ts"
},
Expand Down
64 changes: 64 additions & 0 deletions packages/billing/billing-queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,67 @@ export async function expireRemainingPlanAllowanceInSubscriptionCycle(
subscriptionCycleId,
});
}

export async function spendCredits(
dbAdapter: DBAdapter,
userId: string,
creditsToSpend: number,
) {
let subscription = await getCurrentActiveSubscription(dbAdapter, userId);
if (!subscription) {
throw new Error('active subscription not found');
}
let subscriptionCycle = await getMostRecentSubscriptionCycle(
dbAdapter,
subscription.id,
);
if (!subscriptionCycle) {
throw new Error('subscription cycle not found');
}
let availablePlanAllowanceCredits = await sumUpCreditsLedger(dbAdapter, {
creditType: [
'plan_allowance',
'plan_allowance_used',
'plan_allowance_expired',
],
userId,
});

if (availablePlanAllowanceCredits >= creditsToSpend) {
await addToCreditsLedger(dbAdapter, {
userId,
creditAmount: -creditsToSpend,
creditType: 'plan_allowance_used',
subscriptionCycleId: subscriptionCycle.id,
});
} else {
// If user does not have enough plan allowance credits to cover the spend, try to also use extra credits
let availableExtraCredits = await sumUpCreditsLedger(dbAdapter, {
creditType: ['extra_credit', 'extra_credit_used'],
userId,
});
let planAllowanceToSpend = availablePlanAllowanceCredits; // Spend all plan allowance credits first
let extraCreditsToSpend = creditsToSpend - planAllowanceToSpend;
if (extraCreditsToSpend > availableExtraCredits) {
extraCreditsToSpend = availableExtraCredits;
}

if (planAllowanceToSpend > 0) {
await addToCreditsLedger(dbAdapter, {
userId,
creditAmount: -planAllowanceToSpend,
creditType: 'plan_allowance_used',
subscriptionCycleId: subscriptionCycle.id,
});
}

if (extraCreditsToSpend > 0) {
await addToCreditsLedger(dbAdapter, {
userId,
creditAmount: -extraCreditsToSpend,
creditType: 'extra_credit_used',
subscriptionCycleId: subscriptionCycle.id,
});
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be worth tracking the generation ID.

Copy link
Contributor Author

@jurgenwerk jurgenwerk Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think now with this simple variant this may not be necessary but later on when we introduce jobs and queues, we could save the generation id there (in the serialized job) and have it available for inspection if needed.

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { type DBAdapter } from '@cardstack/runtime-common';
import {
addToCreditsLedger,
getCurrentActiveSubscription,
getMostRecentSubscriptionCycle,
getUserByStripeId,
insertStripeEvent,
markStripeEventAsProcessed,
Expand Down Expand Up @@ -50,11 +52,27 @@ export async function handleCheckoutSessionCompleted(
);
}

let subscription = await getCurrentActiveSubscription(dbAdapter, user.id);
if (!subscription) {
throw new Error(
`User ${user.id} has no subscription, cannot add extra credits`,
);
}
let subscriptionCycle = await getMostRecentSubscriptionCycle(
dbAdapter,
subscription!.id,
);
if (!subscriptionCycle) {
throw new Error(
`User ${user.id} has no subscription cycle, cannot add extra credits`,
);
}

await addToCreditsLedger(dbAdapter, {
userId: user.id,
creditAmount: creditReloadAmount,
creditType: 'extra_credit',
subscriptionCycleId: null,
subscriptionCycleId: subscriptionCycle.id,
});
}
});
Expand Down
3 changes: 0 additions & 3 deletions packages/billing/stripe-webhook-handlers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ export default async function stripeWebhookHandler(

let type = event.type;

// For adding extra credits, we should listen for charge.succeeded, and for
// subsciptions, we should listen for invoice.payment_succeeded (I discovered this when I was
// testing which webhooks arrive for both types of payments)
switch (type) {
// These handlers should eventually become jobs which workers will process asynchronously
case 'invoice.payment_succeeded':
Expand Down
Loading
Loading