Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions internal/api/v1beta1connect/billing_checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package v1beta1connect

import (
"context"
"errors"

"connectrpc.com/connect"
"github.com/raystack/frontier/billing/checkout"
"github.com/raystack/frontier/billing/product"
"github.com/raystack/frontier/billing/subscription"
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
"google.golang.org/protobuf/types/known/timestamppb"
)

type CheckoutService interface {
Expand All @@ -19,6 +21,83 @@ type CheckoutService interface {
CreateSessionForCustomerPortal(ctx context.Context, ch checkout.Checkout) (checkout.Checkout, error)
}

func (h *ConnectHandler) CreateCheckout(ctx context.Context, request *connect.Request[frontierv1beta1.CreateCheckoutRequest]) (*connect.Response[frontierv1beta1.CreateCheckoutResponse], error) {
// check if setup requested
if request.Msg.GetSetupBody() != nil && request.Msg.GetSetupBody().GetPaymentMethod() {
newCheckout, err := h.checkoutService.CreateSessionForPaymentMethod(ctx, checkout.Checkout{
CustomerID: request.Msg.GetBillingId(),
SuccessUrl: request.Msg.GetSuccessUrl(),
CancelUrl: request.Msg.GetCancelUrl(),
})
if err != nil {
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}

return connect.NewResponse(&frontierv1beta1.CreateCheckoutResponse{
CheckoutSession: transformCheckoutToPB(newCheckout),
}), nil
}

// check if customer portal requested
if request.Msg.GetSetupBody() != nil && request.Msg.GetSetupBody().GetCustomerPortal() {
newCheckout, err := h.checkoutService.CreateSessionForCustomerPortal(ctx, checkout.Checkout{
CustomerID: request.Msg.GetBillingId(),
SuccessUrl: request.Msg.GetSuccessUrl(),
CancelUrl: request.Msg.GetCancelUrl(),
})
if err != nil {
if errors.Is(err, checkout.ErrKycCompleted) {
return nil, connect.NewError(connect.CodeFailedPrecondition, ErrPortalChangesKycCompleted)
}
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}

return connect.NewResponse(&frontierv1beta1.CreateCheckoutResponse{
CheckoutSession: transformCheckoutToPB(newCheckout),
}), nil
}

// check if checkout requested (subscription or product)
if request.Msg.GetSubscriptionBody() == nil && request.Msg.GetProductBody() == nil {
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest)
}
planID := ""
var skipTrial bool
var cancelAfterTrial bool
if request.Msg.GetSubscriptionBody() != nil {
planID = request.Msg.GetSubscriptionBody().GetPlan()
skipTrial = request.Msg.GetSubscriptionBody().GetSkipTrial()
cancelAfterTrial = request.Msg.GetSubscriptionBody().GetCancelAfterTrial()
}

var featureID string
var quantity int64
if request.Msg.GetProductBody() != nil {
featureID = request.Msg.GetProductBody().GetProduct()
quantity = request.Msg.GetProductBody().GetQuantity()
}
newCheckout, err := h.checkoutService.Create(ctx, checkout.Checkout{
CustomerID: request.Msg.GetBillingId(),
SuccessUrl: request.Msg.GetSuccessUrl(),
CancelUrl: request.Msg.GetCancelUrl(),
PlanID: planID,
ProductID: featureID,
Quantity: quantity,
SkipTrial: skipTrial,
CancelAfterTrial: cancelAfterTrial,
})
if err != nil {
if errors.Is(err, product.ErrPerSeatLimitReached) {
return nil, connect.NewError(connect.CodeInvalidArgument, ErrPerSeatLimitReached)
}
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}

return connect.NewResponse(&frontierv1beta1.CreateCheckoutResponse{
CheckoutSession: transformCheckoutToPB(newCheckout),
}), nil
}

func (h *ConnectHandler) DelegatedCheckout(ctx context.Context, request *connect.Request[frontierv1beta1.DelegatedCheckoutRequest]) (*connect.Response[frontierv1beta1.DelegatedCheckoutResponse], error) {
var planID string
var skipTrial bool
Expand Down Expand Up @@ -67,3 +146,54 @@ func (h *ConnectHandler) DelegatedCheckout(ctx context.Context, request *connect
Product: productPb,
}), nil
}

func (h *ConnectHandler) ListCheckouts(ctx context.Context, request *connect.Request[frontierv1beta1.ListCheckoutsRequest]) (*connect.Response[frontierv1beta1.ListCheckoutsResponse], error) {
if request.Msg.GetOrgId() == "" {
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest)
}

var checkouts []*frontierv1beta1.CheckoutSession
checkoutList, err := h.checkoutService.List(ctx, checkout.Filter{
CustomerID: request.Msg.GetBillingId(),
})
if err != nil {
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}
for _, v := range checkoutList {
checkouts = append(checkouts, transformCheckoutToPB(v))
}

return connect.NewResponse(&frontierv1beta1.ListCheckoutsResponse{
CheckoutSessions: checkouts,
}), nil
}

func (h *ConnectHandler) GetCheckout(ctx context.Context, request *connect.Request[frontierv1beta1.GetCheckoutRequest]) (*connect.Response[frontierv1beta1.GetCheckoutResponse], error) {
if request.Msg.GetOrgId() == "" || request.Msg.GetId() == "" {
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest)
}

ch, err := h.checkoutService.GetByID(ctx, request.Msg.GetId())
if err != nil {
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}

return connect.NewResponse(&frontierv1beta1.GetCheckoutResponse{
CheckoutSession: transformCheckoutToPB(ch),
}), nil
}

func transformCheckoutToPB(ch checkout.Checkout) *frontierv1beta1.CheckoutSession {
return &frontierv1beta1.CheckoutSession{
Id: ch.ID,
CheckoutUrl: ch.CheckoutUrl,
SuccessUrl: ch.SuccessUrl,
CancelUrl: ch.CancelUrl,
State: ch.State,
Plan: ch.PlanID,
Product: ch.ProductID,
CreatedAt: timestamppb.New(ch.CreatedAt),
UpdatedAt: timestamppb.New(ch.UpdatedAt),
ExpireAt: timestamppb.New(ch.ExpireAt),
}
}
Loading
Loading