Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring Stripe Customer Management #8

Merged
merged 15 commits into from
Apr 30, 2024
7 changes: 3 additions & 4 deletions dbos-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,19 @@ database:
port: 5432
username: 'postgres'
password: ${PGPASSWORD}
app_db_name: 'cloudsub'
app_db_name: 'cloud_account'
connectionTimeoutMillis: 3000
app_db_client: 'knex'
migrate:
- npx knex migrate:latest
rollback:
- npx knex migrate:rollback
application:
STRIPE_WEBHOOK_SECRET: ${STRIPE_WEBHOOK_SECRET}
STRIPE_DBOS_PRO_PRICE: ${STRIPE_DBOS_PRO_PRICE}
env:
DBOS_DOMAIN: ${DBOS_DOMAIN}
STRIPE_SECRET_KEY: ${STRIPE_SECRET_KEY}
DBOS_DEPLOY_REFRESH_TOKEN: ${DBOS_DEPLOY_REFRESH_TOKEN}
STRIPE_DBOS_PRO_PRICE: ${STRIPE_DBOS_PRO_PRICE}
STRIPE_WEBHOOK_SECRET: ${STRIPE_WEBHOOK_SECRET}
http:
cors_middleware: true
credentials: true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
const { Knex } = require("knex");

exports.up = async function(knex) {
await knex.schema.createTable('subscriptions', table => {
table.text('auth0_user_id').primary();
await knex.schema.createTable('accounts', table => {
table.text('auth0_subject_id').primary();
table.text('email').notNullable();
table.text('stripe_customer_id').notNullable();
table.text('dbos_plan').notNullable().defaultTo('free');
table.bigInteger('created_at')
.notNullable()
.defaultTo(knex.raw('(EXTRACT(EPOCH FROM now())*1000)::bigint'));
table.bigInteger('updated_at')
.notNullable()
.defaultTo(knex.raw('(EXTRACT(EPOCH FROM now())*1000)::bigint'));
table.index('stripe_customer_id');
});
};

exports.down = async function(knex) {
return knex.schema.dropTable('subscriptions');
return knex.schema.dropTable('accounts');
};
69 changes: 0 additions & 69 deletions scripts/auth0_post_login.js

This file was deleted.

12 changes: 6 additions & 6 deletions scripts/dbos_deploy.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# This script is used to automatically deploy this subscription app to DBOS Cloud
import os
from utils import (login, run_subprocess)
from utils import (login, run_subprocess, generate_password)
from config import config

script_dir = os.path.dirname(os.path.abspath(__file__))
app_dir = os.path.join(script_dir, "..")

def deploy(path: str):
output = run_subprocess(['npx', 'dbos-cloud', 'database', 'status', config.db_name], path, check=False)
output = run_subprocess(['npx', 'dbos-cloud', 'db', 'status', config.db_name], path, check=False)
if "error" in output:
raise Exception(f"Database {config.db_name} errored!")

# run_subprocess(['npx', 'dbos-cloud', 'applications', 'register', '--database', DB_NAME], path, check=False)
run_subprocess(['npx', 'dbos-cloud', 'applications', 'deploy'], path)
# Provision a database
run_subprocess(['npx', 'dbos-cloud', 'db', 'provision', config.db_name, '-U', config.deploy_username, '-W', generate_password()], path)
run_subprocess(['npx', 'dbos-cloud', 'app', 'register', '--database', config.db_name], path, check=False)
run_subprocess(['npx', 'dbos-cloud', 'app', 'deploy'], path)

if __name__ == "__main__":
login(app_dir, is_deploy=True)
Expand Down
31 changes: 29 additions & 2 deletions scripts/staging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import json
import os
import time
from utils import (login, run_subprocess)

import requests
from utils import (login, run_subprocess, get_credentials)
from config import config
import stripe

Expand All @@ -22,16 +24,41 @@ def test_endpoints(path: str):
if json_data['SubscriptionPlan'] != "free":
raise Exception("Free tier check failed")

# Test the subscribe endpoint
credentials = get_credentials(path)
token = credentials['token']
url = f"https://subscribe-dbos.{config.dbos_domain}/subscribe"
headers = {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json'
}
data = {
'plan': 'dbospro'
}
res = requests.post(url, headers=headers, data=json.dumps(data))
assert res.status_code == 200, f"Cloud subscribe endpoint failed: {res.status_code} - {res.text}"

# Test customer portal endpoint
url = f"https://subscribe-dbos.{config.dbos_domain}/create-customer-portal"
headers = {
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json'
}
res = requests.post(url, headers=headers)
assert res.status_code == 200, f"Cloud create-customer-portal endpoint failed: {res.status_code} - {res.text}"

# Look up customer ID
customers = stripe.Customer.list(email=config.test_email, limit=1)
if len(customers) == 0:
raise Exception("No Stripe customer found for test email")
customer_id = customers.data[0].id

# Create a subscription that uses the default test payment
# Create a subscription that sets a trial that ends in 1 day.
subscription = stripe.Subscription.create(
customer=customer_id,
items=[{"price": config.stripe_pro_price}],
trial_period_days=1,
trial_settings={"end_behavior": {"missing_payment_method": "cancel"}},
)

time.sleep(30) # Wait for subscription to take effect
Expand Down
16 changes: 14 additions & 2 deletions scripts/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import os
import random
import string
import subprocess

import requests

from config import config

def run_subprocess(command, path: str, check: bool = True, silent: bool = False):
Expand All @@ -23,3 +23,15 @@ def login(path: str, is_deploy: bool = False):
# Automated login using the refresh token
refresh_token = config.deploy_refresh_token if is_deploy else config.test_refresh_token
run_subprocess(['npx', 'dbos-cloud', 'login', '--with-refresh-token', refresh_token], path, check=True)

def get_credentials(path: str):
credentials_path = os.path.join(path, '.dbos', 'credentials')
if not os.path.exists(credentials_path):
raise Exception(f'Could not find credentials file {credentials_path}')
with open(credentials_path, 'r') as f:
return json.load(f)

def generate_password():
characters = string.ascii_letters + string.digits
random_string = ''.join(random.choice(characters) for _ in range(16))
return random_string
11 changes: 11 additions & 0 deletions src/operations.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,26 @@ import request from "supertest";

describe("cors-tests", () => {
let testRuntime: TestingRuntime;
const auth0TestID = "testauth0123";
const stripeTestID = "teststripe123";
const testEmail = "[email protected]";

beforeAll(async () => {
testRuntime = await createTestingRuntime([CloudSubscription, Utils]);
await testRuntime.queryUserDB(`DELETE FROM accounts WHERE auth0_subject_id='${auth0TestID}';`);
});

afterAll(async () => {
await testRuntime.destroy();
});

test("account-management", async () => {
// Check our transactions are correct
await expect(testRuntime.invoke(Utils).recordStripeCustomer(auth0TestID, stripeTestID, testEmail)).resolves.toBeFalsy(); // No error
await expect(testRuntime.invoke(Utils).findStripeCustomerID(auth0TestID)).resolves.toBe(stripeTestID);
await expect(testRuntime.invoke(Utils).findAuth0UserID(stripeTestID)).resolves.toBe(auth0TestID);
});

test("subscribe-cors", async () => {
const req = {
plan: "dbospro",
Expand Down
66 changes: 21 additions & 45 deletions src/operations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import { HandlerContext, ArgSource, ArgSources, PostApi, DBOSResponseError, Requ
import Stripe from 'stripe';
import jwt from "koa-jwt";
import { koaJwtSecret } from "jwks-rsa";
import { DBOSLoginDomain, stripe, Utils } from './utils';
import { DBOSLoginDomain, Utils } from './utils';
export { Utils } from './utils';

const DBOSProPlanString = "dbospro";
const dbosJWT = jwt({
const auth0JwtVerifier = jwt({
secret: koaJwtSecret({
jwksUri: `https://${DBOSLoginDomain}/.well-known/jwks.json`,
cache: true,
Expand All @@ -21,34 +21,28 @@ const dbosJWT = jwt({

// These endpoints can only be called with an authenticated user on DBOS cloud
@Authentication(Utils.userAuthMiddleware)
@KoaMiddleware(dbosJWT)
@KoaMiddleware(auth0JwtVerifier)
export class CloudSubscription {
@RequiredRole(['user'])
@PostApi('/create-customer-portal')
static async createCustomerPortal(ctxt: HandlerContext) {
const authUser = ctxt.authenticatedUser;
const sessionURL = await ctxt.invoke(Utils).createPortal(authUser);
@PostApi('/subscribe')
static async subscribePlan(ctxt: HandlerContext, @ArgSource(ArgSources.BODY) plan: string) {
if (plan !== DBOSProPlanString) { throw new DBOSResponseError("Invalid DBOS Plan", 400); }
const auth0UserID = ctxt.authenticatedUser;
const userEmail = ctxt.koaContext.state.user["https://dbos.dev/email"] as string;
const sessionURL = await ctxt.invoke(Utils).createSubscription(auth0UserID, userEmail).then(x => x.getResult());
if (!sessionURL) {
ctxt.logger.error("Failed to create a customer portal!");
throw new DBOSResponseError("Failed to create customer portal!", 500);
throw new DBOSResponseError("Failed to create a checkout session!");
}
return { url: sessionURL };
}

// This function redirects user to a subscription page
@RequiredRole(['user'])
@PostApi('/subscribe')
static async subscribePlan(ctxt: HandlerContext, @ArgSource(ArgSources.BODY) plan: string) {
// Validate argument
if (plan !== DBOSProPlanString) {
ctxt.logger.error(`Invalid DBOS plan: ${plan}`);
throw new DBOSResponseError("Invalid DBOS Plan", 400);
}

const authUser = ctxt.authenticatedUser;
const sessionURL = await ctxt.invoke(Utils).createCheckout(authUser);
@PostApi('/create-customer-portal')
static async createCustomerPortal(ctxt: HandlerContext) {
const auth0User = ctxt.authenticatedUser;
const sessionURL = await ctxt.invoke(Utils).createStripeCustomerPortal(auth0User).then(x => x.getResult());
if (!sessionURL) {
throw new Error("Failed to create a checkout session!");
throw new DBOSResponseError("Failed to create customer portal!", 500);
}
return { url: sessionURL };
}
Expand All @@ -58,42 +52,24 @@ export class CloudSubscription {
export class StripeWebhook {
@PostApi('/stripe_webhook')
static async stripeWebhook(ctxt: HandlerContext) {
// Make sure the request is actually from Stripe.
// Verify the request is actually from Stripe.
const req = ctxt.koaContext.request;
const sigHeader = req.headers['stripe-signature'];
if (typeof sigHeader !== 'string') {
throw new DBOSResponseError("Invalid stripe request", 400);
}

const payload: string = req.rawBody;
let event: Stripe.Event;
try {
event = stripe.webhooks.constructEvent(payload, sigHeader, ctxt.getConfig("STRIPE_WEBHOOK_SECRET") as string);
} catch (err) {
ctxt.logger.error(err);
throw new DBOSResponseError("Webhook Error", 400);
}

// Fetch auth0 credential every 6 hours.
const ts = Date.now();
const uuidStr = 'authtoken-' + (ts - (ts % 21600000)).toString();
await ctxt.invoke(Utils, uuidStr).retrieveCloudCredential();
const event = Utils.verifyStripeEvent(req.rawBody as string, req.headers['stripe-signature']);

// Handle the event.
// Use event ID as the idempotency key for the workflow, making sure once-and-only-once execution.
// Invoke the workflow but don't wait for it to finish. Fast response to Stripe.
// Invoke the workflow asynchronously and quickly response to Stripe.
// Use event.id as the workflow idempotency key to guarantee exactly once processing.
switch (event.type) {
case 'customer.subscription.created':
case 'customer.subscription.deleted':
case 'customer.subscription.updated': {
const subscription = event.data.object as Stripe.Subscription;
await ctxt.invoke(Utils, event.id).subscriptionWorkflow(subscription.id, subscription.customer as string);
await ctxt.invoke(Utils, event.id).stripeWebhookWorkflow(subscription.id, subscription.customer as string);
break;
}
case 'checkout.session.completed': {
const checkout = event.data.object as Stripe.Checkout.Session;
if (checkout.mode === 'subscription') {
await ctxt.invoke(Utils, event.id).subscriptionWorkflow(checkout.subscription as string, checkout.customer as string);
await ctxt.invoke(Utils, event.id).stripeWebhookWorkflow(checkout.subscription as string, checkout.customer as string);
}
break;
}
Expand Down
Loading
Loading