diff --git a/api/api.go b/api/api.go index 4ce19f4..87d3e92 100644 --- a/api/api.go +++ b/api/api.go @@ -153,6 +153,12 @@ func (a *API) initRouter() http.Handler { // pending organization invitations log.Infow("new route", "method", "GET", "path", organizationPendingMembersEndpoint) r.Get(organizationPendingMembersEndpoint, a.pendingOrganizationMembersHandler) + // handle stripe checkout session + log.Infow("new route", "method", "POST", "path", subscriptionsCheckout) + r.Post(subscriptionsCheckout, a.createSubscriptionCheckoutHandler) + // get stripe checkout session info + log.Infow("new route", "method", "GET", "path", subscriptionsCheckoutSession) + r.Get(subscriptionsCheckoutSession, a.checkoutSessionHandler) }) // Public routes @@ -201,6 +207,7 @@ func (a *API) initRouter() http.Handler { // get subscription info log.Infow("new route", "method", "GET", "path", planInfoEndpoint) r.Get(planInfoEndpoint, a.planInfoHandler) + // handle stripe webhook log.Infow("new route", "method", "POST", "path", subscriptionsWebhook) r.Post(subscriptionsWebhook, a.handleWebhook) }) diff --git a/api/docs.md b/api/docs.md index 4e96787..b8ca007 100644 --- a/api/docs.md +++ b/api/docs.md @@ -33,8 +33,11 @@ - [🤠 Available organization members roles](#-available-organization-members-roles) - [🏛️ Available organization types](#-available-organization-types) - [🏦 Plans](#-plans) - - [🛒 Get Available Plans](#-get-plans) - - [🛍️ Get Plan Info](#-get-plan-info) + - [📋 Get Available Plans](#-get-plans) + - [📄 Get Plan Info](#-get-plan-info) +- [🔰 Subscriptions](#-subscriptions) + - [🛒 Create Checkout session](#-create-checkout-session) + - [🛍️ Get Checkout session info](#-get-checkout-session-info) @@ -324,8 +327,7 @@ This endpoint only returns the addresses of the organizations where the current "subscription":{ "PlanID":3, "StartDate":"2024-11-07T15:25:49.218Z", - "EndDate":"0001-01-01T00:00:00Z", - "RenewalDate":"0001-01-01T00:00:00Z", + "RenewalDate":"2025-11-07T15:25:49.218Z", "Active":true, "MaxCensusSize":10 }, @@ -911,8 +913,62 @@ This request can be made only by organization admins. * **Errors** +| HTTP Status | Error code | Message | +|:---:|:---:|:---| +| `400` | `40004` | `malformed JSON body` | +| `404` | `40009` | `plan not found` | +| `500` | `50001` | `internal server error` | + + +## 🔰 Subscriptions + +### 🛒 Create Checkout session + +* **Path** `/subscriptions/checkout/` +* **Method** `POST` +* **Request Body** +```json +{ + "lookupKey": 1, // PLan's corresponging DB ID + "returnURL": "https://example.com/return", + "address": "user@mail.com", + "amount": 1000, // The desired maxCensusSize +} +``` + +* **Response** +```json +{ + "id": "cs_test_a1b2c3d4e5f6g7h8i9j0", + // ... rest of stripe session attributes +} +``` + +* **Errors** + | HTTP Status | Error code | Message | |:---:|:---:|:---| | `400` | `40010` | `malformed URL parameter` | | `400` | `40023` | `plan not found` | | `500` | `50002` | `internal server error` | + +### 🛍️ Get Checkout session info + +* **Path** `/subscriptions/checkout/{sessionID}` +* **Method** `GET` +* **Response** +```json +{ + "status": "complete", // session status + "customer_email": "customer@example.com", + "subscription_status": "active" +} +``` + +* **Errors** + +| HTTP Status | Error code | Message | +|:---:|:---:|:---| +| `400` | `40010` | `malformed URL parameter` | +| `400` | `40023` | `session not found` | +| `500` | `50002` | `internal server error` | \ No newline at end of file diff --git a/api/errors_definition.go b/api/errors_definition.go index 904a229..10747e2 100644 --- a/api/errors_definition.go +++ b/api/errors_definition.go @@ -54,4 +54,5 @@ var ( ErrGenericInternalServerError = Error{Code: 50002, HTTPstatus: http.StatusInternalServerError, Err: fmt.Errorf("internal server error")} ErrCouldNotCreateFaucetPackage = Error{Code: 50003, HTTPstatus: http.StatusInternalServerError, Err: fmt.Errorf("could not create faucet package")} ErrVochainRequestFailed = Error{Code: 50004, HTTPstatus: http.StatusInternalServerError, Err: fmt.Errorf("vochain request failed")} + ErrStripeError = Error{Code: 50005, HTTPstatus: http.StatusInternalServerError, Err: fmt.Errorf("stripe error")} ) diff --git a/api/organizations.go b/api/organizations.go index dddb6c6..a6fcc31 100644 --- a/api/organizations.go +++ b/api/organizations.go @@ -496,11 +496,6 @@ func (a *API) getOrganizationSubscriptionHandler(w http.ResponseWriter, r *http. ErrNoOrganizationSubscription.Write(w) return } - if !org.Subscription.Active || - (org.Subscription.EndDate.After(time.Now()) && org.Subscription.StartDate.Before(time.Now())) { - ErrOganizationSubscriptionIncative.Write(w) - return - } // get the subscription from the database plan, err := a.db.Plan(org.Subscription.PlanID) if err != nil { diff --git a/api/routes.go b/api/routes.go index 05ead71..126f9d4 100644 --- a/api/routes.go +++ b/api/routes.go @@ -66,4 +66,8 @@ const ( planInfoEndpoint = "/plans/{planID}" // POST /subscriptions/webhook to receive the subscription webhook from stripe subscriptionsWebhook = "/subscriptions/webhook" + // POST /subscriptions/checkout to create a new subscription + subscriptionsCheckout = "/subscriptions/checkout" + // GET /subscriptions/checkout/{sessionID} to get the checkout session information + subscriptionsCheckoutSession = "/subscriptions/checkout/{sessionID}" ) diff --git a/api/stripe.go b/api/stripe.go index 36b4997..6adfd2c 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -1,25 +1,32 @@ package api import ( + "encoding/json" + "fmt" "io" "net/http" "time" + "github.com/go-chi/chi/v5" + "github.com/stripe/stripe-go/v81" "github.com/vocdoni/saas-backend/db" + stripeService "github.com/vocdoni/saas-backend/stripe" "go.vocdoni.io/dvote/log" ) // handleWebhook handles the incoming webhook event from Stripe. -// It takes the API data and signature as input parameters and returns the session ID and an error (if any). -// The request body and Stripe-Signature header are passed to ConstructEvent, along with the webhook signing key. -// If the event type is "customer.subscription.created", it unmarshals the event data into a CheckoutSession struct -// and returns the session ID. Otherwise, it returns an empty string. +// It processes various subscription-related events (created, updated, deleted) +// and updates the organization's subscription status accordingly. +// The webhook verifies the Stripe signature and handles different event types: +// - customer.subscription.created: Creates a new subscription for an organization +// - customer.subscription.updated: Updates an existing subscription +// - customer.subscription.deleted: Reverts to the default plan +// If any error occurs during processing, it returns an appropriate HTTP status code. func (a *API) handleWebhook(w http.ResponseWriter, r *http.Request) { const MaxBodyBytes = int64(65536) r.Body = http.MaxBytesReader(w, r.Body, MaxBodyBytes) payload, err := io.ReadAll(r.Body) if err != nil { - log.Errorf("stripe webhook: Error reading request body: %s\n", err.Error()) w.WriteHeader(http.StatusBadRequest) return @@ -35,55 +42,186 @@ func (a *API) handleWebhook(w http.ResponseWriter, r *http.Request) { // Unmarshal the event data into an appropriate struct depending on its Type switch event.Type { case "customer.subscription.created": - customer, subscription, err := a.stripe.GetInfoFromEvent(*event) + log.Infof("received stripe event Type: %s", event.Type) + stripeSubscriptionInfo, org, err := a.getSubscriptionOrgInfo(event) if err != nil { - log.Errorf("stripe webhook: error getting info from event: %s\n", err.Error()) - w.WriteHeader(http.StatusBadRequest) - return - } - address := subscription.Metadata["address"] - if len(address) == 0 { - log.Errorf("subscription %s does not contain an address in metadata", subscription.ID) - w.WriteHeader(http.StatusBadRequest) - return - } - org, _, err := a.db.Organization(address, false) - if err != nil || org == nil { log.Errorf("could not update subscription %s, a corresponding organization with address %s was not found.", - subscription.ID, address) - log.Errorf("please do manually for creator %s \n Error: %s", customer.Email, err.Error()) + stripeSubscriptionInfo.ID, stripeSubscriptionInfo.OrganizationAddress) + log.Errorf("please do manually for creator %s \n Error: %s", stripeSubscriptionInfo.CustomerEmail, err.Error()) w.WriteHeader(http.StatusBadRequest) return } - dbSubscription, err := a.db.PlanByStripeId(subscription.Items.Data[0].Plan.Product.ID) + dbSubscription, err := a.db.PlanByStripeId(stripeSubscriptionInfo.ProductID) if err != nil || dbSubscription == nil { log.Errorf("could not update subscription %s, a corresponding subscription was not found.", - subscription.ID) + stripeSubscriptionInfo.ID) log.Errorf("please do manually: %s", err.Error()) w.WriteHeader(http.StatusBadRequest) return } - startDate := time.Unix(subscription.CurrentPeriodStart, 0) - endDate := time.Unix(subscription.CurrentPeriodEnd, 0) - renewalDate := time.Unix(subscription.BillingCycleAnchor, 0) organizationSubscription := &db.OrganizationSubscription{ PlanID: dbSubscription.ID, - StartDate: startDate, - EndDate: endDate, - RenewalDate: renewalDate, - Active: subscription.Status == "active", - MaxCensusSize: int(subscription.Items.Data[0].Quantity), - Email: customer.Email, + StartDate: stripeSubscriptionInfo.StartDate, + RenewalDate: stripeSubscriptionInfo.EndDate, + Active: stripeSubscriptionInfo.Status == "active", + MaxCensusSize: stripeSubscriptionInfo.Quantity, + Email: stripeSubscriptionInfo.CustomerEmail, } // TODO will only worked for new subscriptions if err := a.db.SetOrganizationSubscription(org.Address, organizationSubscription); err != nil { - log.Errorf("could not update subscription %s for organization %s: %s", subscription.ID, org.Address, err.Error()) + log.Errorf("could not update subscription %s for organization %s: %s", stripeSubscriptionInfo.ID, org.Address, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + log.Infof("stripe webhook: subscription %s for organization %s processed successfully", stripeSubscriptionInfo.ID, org.Address) + case "customer.subscription.updated", "customer.subscription.deleted": + log.Infof("received stripe event Type: %s", event.Type) + stripeSubscriptionInfo, org, err := a.getSubscriptionOrgInfo(event) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + orgPlan, err := a.db.Plan(org.Subscription.PlanID) + if err != nil || orgPlan == nil { + log.Errorf("could not update subscription %s", stripeSubscriptionInfo.ID) + log.Errorf("a corresponding plan with id %d for organization with address %s was not found", + org.Subscription.PlanID, stripeSubscriptionInfo.OrganizationAddress) + log.Errorf("please do manually for creator %s \n Error: %s", stripeSubscriptionInfo.CustomerEmail, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + if stripeSubscriptionInfo.Status == "canceled" && stripeSubscriptionInfo.ProductID == orgPlan.StripeID { + // replace organization subscription with the default plan + defaultPlan, err := a.db.DefaultPlan() + if err != nil || defaultPlan == nil { + ErrNoDefaultPLan.WithErr((err)).Write(w) + return + } + orgSubscription := &db.OrganizationSubscription{ + PlanID: defaultPlan.ID, + StartDate: time.Now(), + Active: true, + MaxCensusSize: defaultPlan.Organization.MaxCensus, + } + if err := a.db.SetOrganizationSubscription(org.Address, orgSubscription); err != nil { + log.Errorf("could not cancel subscription %s for organization %s: %s", stripeSubscriptionInfo.ID, org.Address, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + } else if stripeSubscriptionInfo.Status == "active" && !org.Subscription.Active { + org.Subscription.Active = true + if err := a.db.SetOrganization(org); err != nil { + log.Errorf("could not activate organizations %s subscription to active: %s", org.Address, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + } + log.Infof("stripe webhook: subscription %s for organization %s processed as %s successfully", + stripeSubscriptionInfo.ID, org.Address, stripeSubscriptionInfo.Status) + default: + log.Infof("received stripe event Type: %s", event.Type) + customer, subscription, err := a.stripe.GetInfoFromEvent(*event) + if err != nil { + log.Errorf("could not decode event for customer with email with address %s was not found. "+ + "Error: %s", customer.Email, err.Error()) w.WriteHeader(http.StatusBadRequest) return } - log.Debugf("stripe webhook: subscription %s for organization %s processed successfully", subscription.ID, org.Address) + if subscription != nil { + stripeSubscriptionInfo, err := a.stripe.GetSubscriptionInfoFromEvent(*event) + if err != nil { + log.Errorf("could not decode event for subscription %s for customer %s. Error: %s", + stripeSubscriptionInfo.ID, customer.Email, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + log.Infof("stripe webhook: event subscription %s for organization %s and customer %s received", + stripeSubscriptionInfo.ID, stripeSubscriptionInfo.OrganizationAddress, customer.Email) + + } else { + log.Infof("stripe webhook: subscription %s for customer %s received as %s successfully", + subscription.ID, customer.Email, subscription.Status) + } } w.WriteHeader(http.StatusOK) } + +// createSubscriptionCheckoutHandler handles requests to create a new Stripe checkout session +// for subscription purchases. +func (a *API) createSubscriptionCheckoutHandler(w http.ResponseWriter, r *http.Request) { + checkout := &SubscriptionCheckout{} + if err := json.NewDecoder(r.Body).Decode(checkout); err != nil { + ErrMalformedBody.Write(w) + return + } + + if checkout.Amount == 0 || checkout.Address == "" { + ErrMalformedBody.Withf("Missing required fields").Write(w) + return + } + + // TODO check if the user has another active paid subscription + + plan, err := a.db.Plan(checkout.LookupKey) + if err != nil { + ErrMalformedURLParam.Withf("Plan not found: %v", err).Write(w) + return + } + + session, err := a.stripe.CreateSubscriptionCheckoutSession( + plan.StripePriceID, checkout.ReturnURL, checkout.Address, checkout.Locale, checkout.Amount) + if err != nil { + ErrStripeError.Withf("Cannot create session: %v", err).Write(w) + return + } + + data := &struct { + ClientSecret string `json:"clientSecret"` + SessionID string `json:"sessionID"` + }{ + ClientSecret: session.ClientSecret, + SessionID: session.ID, + } + httpWriteJSON(w, data) +} + +// checkoutSessionHandler retrieves the status of a Stripe checkout session. +func (a *API) checkoutSessionHandler(w http.ResponseWriter, r *http.Request) { + sessionID := chi.URLParam(r, "sessionID") + if sessionID == "" { + ErrMalformedURLParam.Withf("sessionID is required").Write(w) + return + } + status, err := a.stripe.RetrieveCheckoutSession(sessionID) + if err != nil { + ErrStripeError.Withf("Cannot get session: %v", err).Write(w) + return + } + + httpWriteJSON(w, status) +} + +// getSubscriptionOrgInfo is a helper function that retrieves the subscription information from +// the subscription event and the Organization information from the database. +func (a *API) getSubscriptionOrgInfo(event *stripe.Event) (*stripeService.StripeSubscriptionInfo, *db.Organization, error) { + stripeSubscriptionInfo, err := a.stripe.GetSubscriptionInfoFromEvent(*event) + if err != nil { + return nil, nil, fmt.Errorf("could not decode event for subscription: %s", err.Error()) + } + org, _, err := a.db.Organization(stripeSubscriptionInfo.CustomerEmail, false) + if err != nil || org == nil { + log.Errorf("could not update subscription %s, a corresponding organization with address %s was not found.", + stripeSubscriptionInfo.ID, stripeSubscriptionInfo.OrganizationAddress) + log.Errorf("please do manually for creator %s \n Error: %s", stripeSubscriptionInfo.CustomerEmail, err.Error()) + if org == nil { + return nil, nil, fmt.Errorf("no organization found with address %s", stripeSubscriptionInfo.OrganizationAddress) + } else { + return nil, nil, fmt.Errorf("could not retrieve organization with address %s: %s", + stripeSubscriptionInfo.OrganizationAddress, err.Error()) + } + } + + return stripeSubscriptionInfo, org, nil +} diff --git a/api/types.go b/api/types.go index 3804bb1..f0d0a72 100644 --- a/api/types.go +++ b/api/types.go @@ -183,3 +183,12 @@ type OrganizationSubscriptionInfo struct { Usage *db.OrganizationCounters `json:"usage"` Plan *db.Plan `json:"plan"` } + +// SubscriptionCheckout represents the details required for a subscription checkout process. +type SubscriptionCheckout struct { + LookupKey uint64 `json:"lookupKey"` + ReturnURL string `json:"returnURL"` + Amount int64 `json:"amount"` + Address string `json:"address"` + Locale string `json:"locale"` +} diff --git a/db/organizations_test.go b/db/organizations_test.go index 6d38bf2..e9f5127 100644 --- a/db/organizations_test.go +++ b/db/organizations_test.go @@ -196,13 +196,11 @@ func TestAddOrganizationPlan(t *testing.T) { // add a subscription to the organization subscriptionName := "testPlan" startDate := time.Now() - endDate := startDate.AddDate(1, 0, 0) active := true stripeID := "stripeID" orgSubscription := &OrganizationSubscription{ PlanID: 100, StartDate: startDate, - EndDate: endDate, Active: true, } // using a non existing subscription should fail diff --git a/db/types.go b/db/types.go index 71d1df8..1f0dd16 100644 --- a/db/types.go +++ b/db/types.go @@ -96,6 +96,7 @@ type Plan struct { ID uint64 `json:"id" bson:"_id"` Name string `json:"name" bson:"name"` StripeID string `json:"stripeID" bson:"stripeID"` + StripePriceID string `json:"stripePriceID" bson:"stripePriceID"` StartingPrice int64 `json:"startingPrice" bson:"startingPrice"` Default bool `json:"default" bson:"default"` Organization PlanLimits `json:"organization" bson:"organization"` @@ -112,7 +113,6 @@ type PlanTier struct { type OrganizationSubscription struct { PlanID uint64 `json:"planID" bson:"planID"` StartDate time.Time `json:"startDate" bson:"startDate"` - EndDate time.Time `json:"endDate" bson:"endDate"` RenewalDate time.Time `json:"renewalDate" bson:"renewalDate"` Active bool `json:"active" bson:"active"` MaxCensusSize int `json:"maxCensusSize" bson:"maxCensusSize"` diff --git a/stripe/stripe.go b/stripe/stripe.go index b30f080..8ce15f8 100644 --- a/stripe/stripe.go +++ b/stripe/stripe.go @@ -3,8 +3,10 @@ package stripe import ( "encoding/json" "fmt" + "time" "github.com/stripe/stripe-go/v81" + "github.com/stripe/stripe-go/v81/checkout/session" "github.com/stripe/stripe-go/v81/customer" "github.com/stripe/stripe-go/v81/price" "github.com/stripe/stripe-go/v81/product" @@ -13,6 +15,7 @@ import ( "go.vocdoni.io/dvote/log" ) +// ProductsIDs contains the Stripe product IDs for different subscription tiers var ProductsIDs = []string{ "prod_R3LTVsjklmuQAL", // Essential "prod_R0kTryoMNl8I19", // Premium @@ -20,6 +23,26 @@ var ProductsIDs = []string{ "prod_RHurAb3OjkgJRy", // Custom } +// ReturnStatus represents the response structure for checkout session status +type ReturnStatus struct { + Status string `json:"status"` + CustomerEmail string `json:"customer_email"` + SubscriptionStatus string `json:"subscription_status"` +} + +// StripeSubscriptionInfo represents the information related to a Stripe subscription +// that are relevant for the application. +type StripeSubscriptionInfo struct { + ID string + Status string + ProductID string + Quantity int + OrganizationAddress string + CustomerEmail string + StartDate time.Time + EndDate time.Time +} + // StripeClient is a client for interacting with the Stripe API. // It holds the necessary configuration such as the webhook secret. type StripeClient struct { @@ -36,6 +59,7 @@ func New(apiSecret, webhookSecret string) *StripeClient { } // DecodeEvent decodes a Stripe webhook event from the given payload and signature header. +// It verifies the webhook signature and returns the decoded event or an error if validation fails. func (s *StripeClient) DecodeEvent(payload []byte, signatureHeader string) (*stripe.Event, error) { event := stripe.Event{} if err := json.Unmarshal(payload, &event); err != nil { @@ -52,6 +76,8 @@ func (s *StripeClient) DecodeEvent(payload []byte, signatureHeader string) (*str } // GetInfoFromEvent processes a Stripe event to extract customer and subscription information. +// It unmarshals the event data and retrieves the associated customer details. +// Returns the customer and subscription objects, or an error if processing fails. func (s *StripeClient) GetInfoFromEvent(event stripe.Event) (*stripe.Customer, *stripe.Subscription, error) { var subscription stripe.Subscription err := json.Unmarshal(event.Data.Raw, &subscription) @@ -69,6 +95,37 @@ func (s *StripeClient) GetInfoFromEvent(event stripe.Event) (*stripe.Customer, * return customer, &subscription, nil } +// GetSubscriptionInfoFromEvent processes a Stripe event to extract subscription information. +// It unmarshals the event data and retrieves the associated customer and subscription details. +func (s *StripeClient) GetSubscriptionInfoFromEvent(event stripe.Event) (*StripeSubscriptionInfo, error) { + customer, subscription, err := s.GetInfoFromEvent(event) + if err != nil { + return &StripeSubscriptionInfo{}, fmt.Errorf("error getting info from event: %s\n", err.Error()) + } + address := subscription.Metadata["address"] + if len(address) == 0 { + return &StripeSubscriptionInfo{}, fmt.Errorf("subscription %s does not contain an address in metadata", subscription.ID) + } + + if len(subscription.Items.Data) == 0 { + return &StripeSubscriptionInfo{}, fmt.Errorf("subscription %s does not contain any items", subscription.ID) + } + + return &StripeSubscriptionInfo{ + ID: subscription.ID, + Status: string(subscription.Status), + ProductID: subscription.Items.Data[0].Plan.Product.ID, + Quantity: int(subscription.Items.Data[0].Quantity), + OrganizationAddress: address, + CustomerEmail: customer.Email, + StartDate: time.Unix(subscription.CurrentPeriodStart, 0), + EndDate: time.Unix(subscription.CurrentPeriodEnd, 0), + }, nil +} + +// GetPriceByID retrieves a Stripe price object by its ID. +// It searches for an active price with the given lookup key. +// Returns nil if no matching price is found. func (s *StripeClient) GetPriceByID(priceID string) *stripe.Price { params := &stripe.PriceSearchParams{ SearchParams: stripe.SearchParams{ @@ -82,6 +139,9 @@ func (s *StripeClient) GetPriceByID(priceID string) *stripe.Price { return nil } +// GetProductByID retrieves a Stripe product by its ID. +// It expands the default price and its tiers in the response. +// Returns the product object and any error encountered. func (s *StripeClient) GetProductByID(productID string) (*stripe.Product, error) { params := &stripe.ProductParams{} params.AddExpand("default_price") @@ -93,6 +153,9 @@ func (s *StripeClient) GetProductByID(productID string) (*stripe.Product, error) return product, nil } +// GetPrices retrieves multiple Stripe prices by their IDs. +// It returns a slice of Price objects for all valid price IDs. +// Invalid or non-existent price IDs are silently skipped. func (s *StripeClient) GetPrices(priceIDs []string) []*stripe.Price { var prices []*stripe.Price for _, priceID := range priceIDs { @@ -103,6 +166,9 @@ func (s *StripeClient) GetPrices(priceIDs []string) []*stripe.Price { return prices } +// GetPlans retrieves and constructs a list of subscription plans from Stripe products. +// It processes product metadata to extract organization limits, voting types, and features. +// Returns a slice of Plan objects and any error encountered during processing. func (s *StripeClient) GetPlans() ([]*db.Plan, error) { var plans []*db.Plan for i, productID := range ProductsIDs { @@ -138,7 +204,8 @@ func (s *StripeClient) GetPlans() ([]*db.Plan, error) { ID: uint64(i), Name: product.Name, StartingPrice: startingPrice, - StripeID: price.ID, + StripeID: productID, + StripePriceID: price.ID, Default: price.Metadata["Default"] == "true", Organization: organizationData, VotingTypes: votingTypesData, @@ -151,3 +218,73 @@ func (s *StripeClient) GetPlans() ([]*db.Plan, error) { } return plans, nil } + +// CreateSubscriptionCheckoutSession creates a new Stripe checkout session for a subscription. +// It configures the session with the specified price, amount return URL, and subscription metadata. +// The priceID is that is provided corrsponds to the subscription tier selected by the user. +// Returns the created checkout session and any error encountered. +// Overview of stripe checkout mechanics: https://docs.stripe.com/checkout/custom/quickstart +// API description https://docs.stripe.com/api/checkout/sessions +func (s *StripeClient) CreateSubscriptionCheckoutSession( + priceID, returnURL, address, locale string, amount int64, +) (*stripe.CheckoutSession, error) { + if len(locale) == 0 { + locale = "auto" + } + checkoutParams := &stripe.CheckoutSessionParams{ + // Subscription mode + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + LineItems: []*stripe.CheckoutSessionLineItemParams{ + { + Price: stripe.String(priceID), + Quantity: stripe.Int64(amount), + }, + }, + // UI mode is set to embedded, since the client is integrated in our UI + UIMode: stripe.String(string(stripe.CheckoutSessionUIModeEmbedded)), + // Automatic tax calculation is enabled + AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{ + Enabled: stripe.Bool(true), + }, + // We store in the metadata the address of the organization + SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{ + Metadata: map[string]string{ + "address": address, + }, + }, + // The locale is being used to configure the language of the embedded client + Locale: stripe.String(locale), + } + + // The returnURL is used to redirect the user after the payment is completed + if len(returnURL) > 0 { + checkoutParams.ReturnURL = stripe.String(returnURL + "/{CHECKOUT_SESSION_ID}") + } else { + checkoutParams.RedirectOnCompletion = stripe.String("never") + } + session, err := session.New(checkoutParams) + if err != nil { + return nil, err + } + + return session, nil +} + +// RetrieveCheckoutSession retrieves a checkout session from Stripe by session ID. +// It returns a ReturnStatus object and an error if any. +// The ReturnStatus object contains information about the session status, +// customer email, and subscription status. +func (s *StripeClient) RetrieveCheckoutSession(sessionID string) (*ReturnStatus, error) { + params := &stripe.CheckoutSessionParams{} + params.AddExpand("line_items") + sess, err := session.Get(sessionID, params) + if err != nil { + return nil, err + } + data := &ReturnStatus{ + Status: string(sess.Status), + CustomerEmail: sess.CustomerDetails.Email, + SubscriptionStatus: string(sess.Subscription.Status), + } + return data, nil +}