Skip to content
This repository has been archived by the owner on Dec 1, 2023. It is now read-only.

Commit

Permalink
Add Stripe
Browse files Browse the repository at this point in the history
Implements a 'Stripe' Lightning Flow component. This uses the
Stripe client and defines the necessary HTTP endpoints for
creating and completing a checkout session using Stripe as the
payment processor.
  • Loading branch information
alecmerdler committed Dec 14, 2022
1 parent b16b532 commit 43d61c2
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 7 deletions.
85 changes: 81 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import os
import uuid
from datetime import datetime, timedelta
Expand All @@ -6,6 +7,7 @@

import requests
from fastapi import HTTPException
from fastapi.responses import RedirectResponse
from lightning import LightningApp, LightningFlow
from lightning.app.api.http_methods import Delete, Get, Post
from lightning.app.frontend import StaticWebFrontend
Expand All @@ -20,6 +22,7 @@
from echo.components.database.server import Database
from echo.components.fileserver import FileServer
from echo.components.loadbalancing.loadbalancer import LoadBalancer
from echo.components.payment.stripe import Order, Stripe
from echo.components.recognizer import SpeechRecognizer
from echo.components.youtuber import YouTuber
from echo.constants import SHARED_STORAGE_DRIVE_ID
Expand Down Expand Up @@ -47,6 +50,11 @@

RECOGNIZER_ATTRIBUTE_PREFIX = "recognizer_"

PRICE_PER_MINUTE_DEFAULT = 1.00
STRIPE_PRODUCT_NAME = f"Audio/video transcription ($1.00 per minute)"
STRIPE_PRODUCT_DESCRIPTION = "Transcribe audio and video files"
STRIPE_PRODUCT_IMAGES = []

RECOGNIZER_MIN_REPLICAS_DEFAULT = 1
RECOGNIZER_MAX_IDLE_SECONDS_PER_WORK_DEFAULT = 120
RECOGNIZER_MAX_PENDING_CALLS_PER_WORK_DEFAULT = 10
Expand Down Expand Up @@ -161,6 +169,8 @@ def __init__(self):
self.database_cloud_compute = os.environ.get("ECHO_DATABASE_CLOUD_COMPUTE", DATABASE_CLOUD_COMPUTE_DEFAULT)
self._loadbalancer_auth_token = os.environ.get("ECHO_LOADBALANCER_AUTH_TOKEN", None)

self.price_per_minute = int(os.environ.get("ECHO_PRICE_PER_MINUTE", PRICE_PER_MINUTE_DEFAULT))

# Need to wait for database to be ready before initializing clients
self._echo_db_client = None
self._segment_db_client = None
Expand Down Expand Up @@ -196,11 +206,28 @@ def __init__(self):
dummy_run_kwargs={"echo": dummy_echo, "db_url": None},
)

self.stripe_enabled = os.environ.get("ECHO_STRIPE_API_KEY", None) is not None
if self.stripe_enabled:
logger.info("Enabling Stripe payment processing")

self.stripe = Stripe(
api_key=os.environ.get("ECHO_STRIPE_API_KEY", None),
product_name=STRIPE_PRODUCT_NAME,
product_description=STRIPE_PRODUCT_DESCRIPTION,
product_images=STRIPE_PRODUCT_IMAGES,
determine_price=self.determine_price,
determine_quantity=self.determine_quantity,
on_checkout_completed=self.on_checkout_completed,
)

def run(self):
# Run child components
self.database.run()
self.fileserver.run()

if self.stripe_enabled:
self.stripe.run()

if self.database.alive() and self._echo_db_client is None:
self._echo_db_client = DatabaseClient(model=Echo, db_url=self.database.db_url)
self._segment_db_client = DatabaseClient(model=Segment, db_url=self.database.db_url)
Expand Down Expand Up @@ -234,9 +261,6 @@ def create_echo(self, echo: Echo) -> Echo:
if echo.source_youtube_url is not None:
self.youtuber.run(youtube_url=echo.source_youtube_url, echo_id=echo.id, fileserver_url=self.fileserver.url)

# Run speech recognition for the Echo
self.recognizer.run(echo=echo, db_url=self.database.url)

return echo

def list_echoes(self, config: ListEchoesConfig) -> List[Echo]:
Expand Down Expand Up @@ -289,6 +313,52 @@ def login(self):

return LoginResponse(user_id=new_user_id)

def determine_price(self, echo_id: str) -> int:
if self._echo_db_client is None:
logger.warn("Database client not initialized!")
return None

echo = self._echo_db_client.get_echo(echo_id)
if echo is None:
return None

if echo.source_youtube_url is not None:
video_length = youtube_video_length(echo.source_youtube_url)

return math.ceil(video_length / 60) * self.price_per_minute * 100

# FIXME(alecmerdler): Determine real audio file length for price...
return 100

def determine_quantity(self, echo_id: str) -> int:
if self._echo_db_client is None:
logger.warn("Database client not initialized!")
return None

echo = self._echo_db_client.get_echo(echo_id)
if echo is None:
return None

if echo.source_youtube_url is not None:
video_length = youtube_video_length(echo.source_youtube_url)

return math.ceil(video_length / 60)

# FIXME(alecmerdler): Determine real audio file length for price...
return 100

def on_checkout_completed(self, order: Order):
if self._echo_db_client is None:
logger.warn("Database client not initialized!")
return None

echo = self._echo_db_client.get_echo(order.id)
if echo is None:
return None

# Trigger recognizer to process the file once the user has paid
self.recognizer.run(echo=echo, db_url=self.database.db_url)

def validate_echo(self, echo: Echo) -> ValidateEchoResponse:
# Guard against disabled source types
if echo.source_youtube_url is not None:
Expand Down Expand Up @@ -327,7 +397,13 @@ def handle_create_echo(self, echo: Echo) -> Echo:
if not validation.valid:
raise HTTPException(status_code=400, detail=validation.reason)

return self.create_echo(echo)
echo = self.create_echo(echo)

# If payment processing is disabled, start processing the Echo immediately
if not self.stripe_enabled:
self.recognizer.run(echo=echo, db_url=self.database.url)

return echo

def handle_list_echoes(self, user_id: str) -> List[Echo]:
return self.list_echoes(ListEchoesConfig(user_id=user_id))
Expand Down Expand Up @@ -369,6 +445,7 @@ def configure_api(self):
Post("/api/validate", method=self.handle_validate_echo, timeout=REST_API_TIMEOUT_SECONDS),
Get("/api/login", method=self.handle_login, timeout=REST_API_TIMEOUT_SECONDS),
Post("/api/scale", method=self.handle_scale, timeout=REST_API_TIMEOUT_SECONDS),
*self.stripe.configure_api(),
]

def configure_commands(self):
Expand Down
Empty file.
188 changes: 188 additions & 0 deletions echo/components/payment/stripe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import os
from dataclasses import dataclass
from datetime import date, datetime
from typing import Callable, List

import stripe
from fastapi import Header, Request, Response
from fastapi.responses import RedirectResponse
from lightning import LightningFlow
from lightning.app.api import Get, Post
from lightning.app.utilities.app_helpers import Logger
from stripe.error import SignatureVerificationError

logger = Logger(__name__)

create_checkout_session_path = "/_stripe/create-checkout-session"
stripe_events_path = "/_stripe/stripe-events"


DEFAULT_UNIT_PRICE = 100


@dataclass
class Order:
"""Order represents a single order from a customer."""

id: str
customer_email: str
created_at: date


class Stripe(LightningFlow):
"""Stripe is a `LightningFlow` that provides payment processing using Stripe Checkout."""

def __init__(
self,
api_key: str,
product_name: str,
product_description: str = "",
product_images: List[str] = [],
# TODO(alecmerdler): Determine if `determine_price` or `determine_quantity` makes a better API...
unit_price=DEFAULT_UNIT_PRICE,
determine_price: Callable[[str], int] = None,
determine_quantity: Callable[[str], int] = None,
on_checkout_completed: Callable[[Order], None] = None,
success_url: str = None,
cancel_url: str = None,
):
super().__init__()

if determine_price is None:
raise ValueError("Must provide a `determine_price` function.")
if on_checkout_completed is None:
raise ValueError("Must provide an `on_checkout_completed` function.")

stripe.api_key = api_key

self.app_url = os.environ.get("LIGHTNING_APP_EXTERNAL_URL", "http://localhost:7501")
self.success_url = success_url or f"{self.app_url}?success=true"
self.cancel_url = cancel_url or f"{self.app_url}?canceled=true"

self.create_checkout_session_url = self.app_url + create_checkout_session_path
self.product_name = product_name
self.product_description = product_description
self.product_images = product_images
self.unit_price = unit_price

# These will be created dynamically using the Stripe API at runtime
self._endpoint_secret = os.environ.get("DEVELOPMENT_STRIPE_WEBHOOK_SECRET", None)
self._product = None
self._price = None
self._on_checkout_completed = on_checkout_completed
self._determine_price = determine_price
self._determine_quantity = determine_quantity

def handle_create_checkout_session(self, id: str):
# TODO(alecmerdler): Determine if `determine_price` or `determine_quantity` makes a better API...
price = self.unit_price
# price = self._determine_price(id)
quantity = self._determine_quantity(id)

try:
checkout_session = stripe.checkout.Session.create(
line_items=[
{
"quantity": quantity,
"price_data": {
"currency": "usd",
"product": self._product["id"],
"unit_amount": price,
},
},
],
metadata={
"order_id": id,
},
mode="payment",
success_url=self.success_url,
cancel_url=self.cancel_url,
)
except Exception as e:
logger.error(e)

return Response(status_code=500)

return RedirectResponse(checkout_session.url, status_code=303)

def handle_stripe_events(self, request: Request, stripe_signature: str = Header(None)):
event = None
payload = request._body

try:
event = stripe.Webhook.construct_event(payload, stripe_signature, self._endpoint_secret)
except ValueError as e:
logger.error(e)

return Response(status_code=400)
except SignatureVerificationError as e:
logger.error(e)

return Response(status_code=400)

if event["type"] == "checkout.session.completed":
session = stripe.checkout.Session.retrieve(
event["data"]["object"]["id"],
expand=["line_items"],
)

order = Order(
id=session["metadata"]["order_id"],
customer_email=session["customer_details"]["email"],
created_at=datetime.fromtimestamp(session["created"]),
)

logger.debug("Received order from Stripe: " + str(order))

self._on_checkout_completed(order)

return None

def configure_api(self):
return [
Get(create_checkout_session_path, self.handle_create_checkout_session),
Post(stripe_events_path, self.handle_stripe_events),
]

def create_product(self):
products = stripe.Product.list()
for product in products["data"]:
if product["name"] == self.product_name:
return product

logger.info("Creating new Product in Stripe: " + self.product_name)

return stripe.Product.create(
name=self.product_name,
description=self.product_description,
images=self.product_images,
)

def create_webhook_endpoint(self):
endpoints = stripe.WebhookEndpoint.list()
for endpoint in endpoints["data"]:
if endpoint["url"] == self.app_url + stripe_events_path:
# NOTE: We have to recreate the webhook because you can only fetch the secret from the API at creation.
logger.info("Deleting existing Webhook Endpoint in Stripe: " + self.app_url)

stripe.WebhookEndpoint.delete(endpoint["id"])

if self.app_url == "http://localhost:7501":
logger.error(
"Cannot create Webhook Endpoint in Stripe for local development. Use `stripe listen` command instead."
)

return None

logger.info("Creating new Webhook Endpoint in Stripe: " + self.app_url)

return stripe.WebhookEndpoint.create(
url=f"{self.app_url}/_stripe/stripe-events", enabled_events=["checkout.session.completed"]
)

def run(self):
self._product = self.create_product()

if self._endpoint_secret is None:
endpoint = self.create_webhook_endpoint()
self._endpoint_secret = endpoint["secret"]
7 changes: 6 additions & 1 deletion echo/ui/src/hooks/useCreateEcho.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export default function useCreateEcho() {
const { userId } = useAuth();

const fileserverURL = lightningState?.works["fileserver"]["vars"]["_url"];
const checkoutURL = lightningState?.flows["stripe"]["vars"]["create_checkout_session_url"];

return useMutation<Echo, unknown, CreateEchoArgs>(
async ({ echoID, sourceFile, mediaType, displayName, sourceYouTubeURL }) => {
Expand All @@ -44,8 +45,12 @@ export default function useCreateEcho() {
});
},
{
onSuccess: () => {
onSuccess: echo => {
queryClient.invalidateQueries("listEchoes");

if (process.env.REACT_APP_STRIPE_ENABLED === "true") {
window.open(`${checkoutURL}?id=${echo.id}`, "_blank");
}
},
},
);
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
lightning==1.8.3.post2
git+https://github.com/Lightning-AI/lightning.git@add_support_for_request
sqlmodel==0.0.8
pydantic==1.10.2
uvicorn==0.18.3
fastapi==0.85.0
fastapi==0.88.0
python-multipart==0.0.5
humps==0.2.2
pytube==12.1.0
Expand All @@ -11,3 +11,4 @@ python-magic==0.4.27
pysrt==1.1.2
ffmpeg-python==0.2.0
youtube_dl==2021.12.17
stripe==5.0.0

0 comments on commit 43d61c2

Please sign in to comment.