From f2911fd3823c4f9e06f802a3526e9d0c684fe0d1 Mon Sep 17 00:00:00 2001 From: Kush Sharma Date: Sun, 20 Oct 2024 10:12:02 +0530 Subject: [PATCH] feat: generate invoice for overdraft credits It's necessary to configure the product name to calculate per unit price before the overdraft credits can be invoiced in `billing.customer.credit_overdraft_product`. If not set, invoice reconcilation and creation is skipped. If there is already an open(unpaid) invoice for overdraft credits, no new invoice will be created for that customer. Reconcilation of these invoice happens on same cadence as invoices gets synced from the billing provider. Signed-off-by: Kush Sharma --- billing/config.go | 4 + billing/credit/credit.go | 14 +- billing/customer/customer.go | 3 + billing/invoice/invoice.go | 55 ++- billing/invoice/service.go | 419 ++++++++++++++++-- billing/plan/plan.go | 2 + billing/usage/service.go | 3 +- cmd/serve.go | 3 +- config/sample.config.yaml | 3 + docs/docs/reference/configurations.md | 3 + go.mod | 1 + go.sum | 3 + internal/api/v1beta1/billing_invoice.go | 2 +- .../postgres/billing_customer_repository.go | 8 + .../billing_customer_repository_test.go | 74 +++- .../postgres/billing_invoice_repository.go | 53 ++- .../billing_transactions_repository.go | 8 +- ...41015201506_billing_invoice_items.down.sql | 1 + ...0241015201506_billing_invoice_items.up.sql | 1 + pkg/db/config.go | 4 +- pkg/db/db.go | 70 ++- pkg/utils/pointers.go | 15 + 22 files changed, 675 insertions(+), 74 deletions(-) create mode 100644 internal/store/postgres/migrations/20241015201506_billing_invoice_items.down.sql create mode 100644 internal/store/postgres/migrations/20241015201506_billing_invoice_items.up.sql create mode 100644 pkg/utils/pointers.go diff --git a/billing/config.go b/billing/config.go index 40e8d626e..bed8a3c92 100644 --- a/billing/config.go +++ b/billing/config.go @@ -30,6 +30,10 @@ type AccountConfig struct { DefaultPlan string `yaml:"default_plan" mapstructure:"default_plan"` DefaultOffline bool `yaml:"default_offline" mapstructure:"default_offline"` OnboardCreditsWithOrg int64 `yaml:"onboard_credits_with_org" mapstructure:"onboard_credits_with_org"` + + // CreditOverdraftProduct helps identify the product pricing per unit amount for the overdraft + // credits being invoiced + CreditOverdraftProduct string `yaml:"credit_overdraft_product" mapstructure:"credit_overdraft_product"` } type PlanChangeConfig struct { diff --git a/billing/credit/credit.go b/billing/credit/credit.go index 822367fd6..feef225ab 100644 --- a/billing/credit/credit.go +++ b/billing/credit/credit.go @@ -2,6 +2,7 @@ package credit import ( "errors" + "strings" "time" "github.com/google/uuid" @@ -20,10 +21,11 @@ var ( // TxNamespaceUUID is the namespace for generating transaction UUIDs deterministically TxNamespaceUUID = uuid.MustParse("967416d0-716e-4308-b58f-2468ac14f20a") - SourceSystemBuyEvent = "system.buy" - SourceSystemAwardedEvent = "system.awarded" - SourceSystemOnboardEvent = "system.starter" - SourceSystemRevertEvent = "system.revert" + SourceSystemBuyEvent = "system.buy" + SourceSystemAwardedEvent = "system.awarded" + SourceSystemOnboardEvent = "system.starter" + SourceSystemRevertEvent = "system.revert" + SourceSystemOverdraftEvent = "system.overdraft" ) type TransactionType string @@ -71,3 +73,7 @@ type Filter struct { StartRange time.Time EndRange time.Time } + +func TxUUID(tags ...string) string { + return uuid.NewSHA1(TxNamespaceUUID, []byte(strings.Join(tags, ":"))).String() +} diff --git a/billing/customer/customer.go b/billing/customer/customer.go index 5d9a16072..008215676 100644 --- a/billing/customer/customer.go +++ b/billing/customer/customer.go @@ -82,6 +82,9 @@ type Filter struct { OrgID string ProviderID string State State + + Online *bool + AllowedOverdraft *bool } type PaymentMethod struct { diff --git a/billing/invoice/invoice.go b/billing/invoice/invoice.go index ff4e652ae..b4df23ca7 100644 --- a/billing/invoice/invoice.go +++ b/billing/invoice/invoice.go @@ -14,11 +14,31 @@ var ( ErrInvalidDetail = fmt.Errorf("invalid invoice detail") ) +const ( + ItemTypeMetadataKey = "item_type" + ReconciledMetadataKey = "reconciled" + + GenerateForCreditLockKey = "generate_for_credit" +) + +type State string + +func (s State) String() string { + return string(s) +} + +const ( + DraftState State = "draft" + OpenState State = "open" + PaidState State = "paid" +) + type Invoice struct { - ID string - CustomerID string - ProviderID string - State string + ID string + CustomerID string + ProviderID string + // State could be one of draft, open, paid, uncollectible, void + State State Currency string Amount int64 HostedURL string @@ -28,12 +48,39 @@ type Invoice struct { PeriodStartAt time.Time PeriodEndAt time.Time + Items []Item Metadata metadata.Metadata } +type ItemType string + +func (t ItemType) String() string { + return string(t) +} + +const ( + // CreditItemType is used to charge for the credits used in the system + // as overdraft + CreditItemType ItemType = "credit" +) + +type Item struct { + ID string `json:"id"` + ProviderID string `json:"provider_id"` + // Name is the item name + Name string `json:"name"` + // Type is the item type + Type ItemType `json:"type"` + // UnitAmount is per unit cost + UnitAmount int64 `json:"unit_amount"` + // Quantity is the number of units + Quantity int64 `json:"quantity"` +} + type Filter struct { CustomerID string NonZeroOnly bool + State State Pagination *pagination.Pagination } diff --git a/billing/invoice/service.go b/billing/invoice/service.go index 759d9b0eb..295451e82 100644 --- a/billing/invoice/service.go +++ b/billing/invoice/service.go @@ -8,6 +8,12 @@ import ( "sync" "time" + "github.com/raystack/frontier/pkg/db" + + "github.com/google/uuid" + "github.com/raystack/frontier/billing/credit" + "github.com/raystack/frontier/billing/product" + "github.com/robfig/cron/v3" "github.com/stripe/stripe-go/v79" @@ -36,31 +42,60 @@ type CustomerService interface { List(ctx context.Context, filter customer.Filter) ([]customer.Customer, error) } +type CreditService interface { + GetBalance(ctx context.Context, accountID string) (int64, error) + Add(ctx context.Context, cred credit.Credit) error +} + +type ProductService interface { + GetByID(ctx context.Context, id string) (product.Product, error) +} + +type Locker interface { + TryLock(ctx context.Context, id string) (*db.Lock, error) +} + type Service struct { stripeClient *client.API repository Repository customerService CustomerService + creditService CreditService + productService ProductService + locker Locker syncJob *cron.Cron mu sync.Mutex syncDelay time.Duration + + stripeAutoTax bool + creditOverdraftProduct string + creditOverdraftUnitAmount int64 + creditOverdraftInvoiceCurrency string + creditOverdraftInvoiceDOM int } func NewService(stripeClient *client.API, invoiceRepository Repository, - customerService CustomerService, cfg billing.Config) *Service { + customerService CustomerService, creditService CreditService, productService ProductService, + locker Locker, cfg billing.Config) *Service { return &Service{ - stripeClient: stripeClient, - repository: invoiceRepository, - customerService: customerService, - syncDelay: cfg.RefreshInterval.Invoice, + stripeClient: stripeClient, + repository: invoiceRepository, + customerService: customerService, + creditService: creditService, + productService: productService, + locker: locker, + syncDelay: cfg.RefreshInterval.Invoice, + stripeAutoTax: cfg.StripeAutoTax, + creditOverdraftProduct: cfg.AccountConfig.CreditOverdraftProduct, + creditOverdraftInvoiceDOM: 1, // 1st day of month } } func (s *Service) Init(ctx context.Context) error { + logger := grpczap.Extract(ctx) if s.syncJob != nil { s.syncJob.Stop() } - s.syncJob = cron.New(cron.WithChain( cron.SkipIfStillRunning(cron.DefaultLogger), cron.Recover(cron.DefaultLogger), @@ -74,6 +109,29 @@ func (s *Service) Init(ctx context.Context) error { return err } s.syncJob.Start() + + if s.creditOverdraftProduct != "" { + creditProduct, err := s.productService.GetByID(ctx, s.creditOverdraftProduct) + if err != nil { + return fmt.Errorf("failed to get credit overdraft product: %w", err) + } + if creditProduct.Behavior != product.CreditBehavior { + return errors.New("credit overdraft product must have credit behavior") + } + // get first price + if len(creditProduct.Prices) == 0 { + return errors.New("credit overdraft product must have at least one price") + } + creditPrice := creditProduct.Prices[0] + if creditPrice.Currency == "" { + return errors.New("credit overdraft product price must have a currency") + } + s.creditOverdraftInvoiceCurrency = creditPrice.Currency + s.creditOverdraftUnitAmount = int64(float64(creditPrice.Amount) / float64(creditProduct.Config.CreditAmount)) + logger.Info("credit overdraft product details", + zap.Int64("unit_amount", s.creditOverdraftUnitAmount), + zap.String("currency", s.creditOverdraftInvoiceCurrency)) + } return nil } @@ -91,25 +149,35 @@ func (s *Service) backgroundSync(ctx context.Context) { defer record() } logger := grpczap.Extract(ctx) - customers, err := s.customerService.List(ctx, customer.Filter{}) + customers, err := s.customerService.List(ctx, customer.Filter{ + Online: utils.Bool(true), + }) if err != nil { logger.Error("invoice.backgroundSync", zap.Error(err)) return } - for _, customer := range customers { + for _, customr := range customers { if ctx.Err() != nil { // stop processing if context is done break } - if !customer.IsActive() || customer.ProviderID == "" { + if !customr.IsActive() { continue } - if err := s.SyncWithProvider(ctx, customer); err != nil { + if err := s.SyncWithProvider(ctx, customr); err != nil { logger.Error("invoice.SyncWithProvider", zap.Error(err)) } time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) } + if err := s.Reconcile(ctx); err != nil { + logger.Error("invoice.Reconcile", zap.Error(err)) + } + if now := time.Now().UTC(); now.Day() == s.creditOverdraftInvoiceDOM { + if err := s.GenerateForCredits(ctx); err != nil { + logger.Error("invoice.GenerateForCredits", zap.Error(err)) + } + } logger.Info("invoice.backgroundSync finished", zap.Duration("duration", time.Since(start))) } @@ -130,6 +198,9 @@ func (s *Service) SyncWithProvider(ctx context.Context, customr customer.Custome ListParams: stripe.ListParams{ Context: ctx, }, + Expand: []*string{ + stripe.String("data.lines"), + }, }) for stripeInvoices.Next() { stripeInvoice := stripeInvoices.Invoice() @@ -139,30 +210,12 @@ func (s *Service) SyncWithProvider(ctx context.Context, customr customer.Custome return i.ProviderID == stripeInvoice.ID }) if ok { - // already present in our system, update it if needed - updateNeeded := false - if existingInvoice.State != string(stripeInvoice.Status) { - existingInvoice.State = string(stripeInvoice.Status) - updateNeeded = true - } - if stripeInvoice.EffectiveAt != 0 && existingInvoice.EffectiveAt != utils.AsTimeFromEpoch(stripeInvoice.EffectiveAt) { - existingInvoice.EffectiveAt = utils.AsTimeFromEpoch(stripeInvoice.EffectiveAt) - updateNeeded = true - } - if stripeInvoice.HostedInvoiceURL != "" && existingInvoice.HostedURL != stripeInvoice.HostedInvoiceURL { - existingInvoice.HostedURL = stripeInvoice.HostedInvoiceURL - updateNeeded = true - } - - if updateNeeded { - if _, err := s.repository.UpdateByID(ctx, existingInvoice); err != nil { - errs = append(errs, fmt.Errorf("failed to update invoice %s: %w", existingInvoice.ID, err)) - } - } + err = s.upsert(ctx, customr.ID, &existingInvoice, stripeInvoice) } else { - if _, err := s.repository.Create(ctx, stripeInvoiceToInvoice(customr.ID, stripeInvoice)); err != nil { - errs = append(errs, fmt.Errorf("failed to create invoice for customer %s: %w", customr.ID, err)) - } + err = s.upsert(ctx, customr.ID, nil, stripeInvoice) + } + if err != nil { + errs = append(errs, err) } // add jitter @@ -177,6 +230,37 @@ func (s *Service) SyncWithProvider(ctx context.Context, customr customer.Custome return nil } +func (s *Service) upsert(ctx context.Context, customerID string, + existingInvoice *Invoice, stripeInvoice *stripe.Invoice) error { + if existingInvoice != nil { + // already present in our system, update it if needed + updateNeeded := false + if existingInvoice.State != State(stripeInvoice.Status) { + existingInvoice.State = State(stripeInvoice.Status) + updateNeeded = true + } + if stripeInvoice.EffectiveAt != 0 && existingInvoice.EffectiveAt != utils.AsTimeFromEpoch(stripeInvoice.EffectiveAt) { + existingInvoice.EffectiveAt = utils.AsTimeFromEpoch(stripeInvoice.EffectiveAt) + updateNeeded = true + } + if stripeInvoice.HostedInvoiceURL != "" && existingInvoice.HostedURL != stripeInvoice.HostedInvoiceURL { + existingInvoice.HostedURL = stripeInvoice.HostedInvoiceURL + updateNeeded = true + } + + if updateNeeded { + if _, err := s.repository.UpdateByID(ctx, *existingInvoice); err != nil { + return fmt.Errorf("failed to update invoice %s: %w", existingInvoice.ID, err) + } + } + } else { + if _, err := s.repository.Create(ctx, stripeInvoiceToInvoice(customerID, stripeInvoice)); err != nil { + return fmt.Errorf("failed to create invoice for customer %s: %w", customerID, err) + } + } + return nil +} + // ListAll should only be called by admin users func (s *Service) ListAll(ctx context.Context, filter Filter) ([]Invoice, error) { return s.repository.List(ctx, filter) @@ -190,6 +274,8 @@ func (s *Service) List(ctx context.Context, filter Filter) ([]Invoice, error) { return s.repository.List(ctx, filter) } +// GetUpcoming returns the upcoming invoice for the customer based on the +// active subscription plan. If no upcoming invoice is found, it returns empty. func (s *Service) GetUpcoming(ctx context.Context, customerID string) (Invoice, error) { logger := grpczap.Extract(ctx) custmr, err := s.customerService.GetByID(ctx, customerID) @@ -238,11 +324,27 @@ func stripeInvoiceToInvoice(customerID string, stripeInvoice *stripe.Invoice) In if stripeInvoice.PeriodEnd != 0 { periodEndAt = time.Unix(stripeInvoice.PeriodEnd, 0) } + var items []Item + if stripeInvoice.Lines != nil { + for _, line := range stripeInvoice.Lines.Data { + item := Item{ + ID: uuid.New().String(), + ProviderID: line.ID, + Name: line.Description, + Type: ItemType(line.Metadata[ItemTypeMetadataKey]), + Quantity: line.Quantity, + } + if line.Price != nil { + item.UnitAmount = line.Price.UnitAmount + } + items = append(items, item) + } + } return Invoice{ ID: "", ProviderID: stripeInvoice.ID, CustomerID: customerID, - State: string(stripeInvoice.Status), + State: State(stripeInvoice.Status), Currency: string(stripeInvoice.Currency), Amount: stripeInvoice.Total, HostedURL: stripeInvoice.HostedInvoiceURL, @@ -252,6 +354,7 @@ func stripeInvoiceToInvoice(customerID string, stripeInvoice *stripe.Invoice) In CreatedAt: createdAt, PeriodStartAt: periodStartAt, PeriodEndAt: periodEndAt, + Items: items, } } @@ -269,3 +372,251 @@ func (s *Service) DeleteByCustomer(ctx context.Context, c customer.Customer) err } return nil } + +// GenerateForCredits finds all customers which has credit min limit lower than +// 0, that is, allows for negative balance and generates an invoice for them. +// Invoices will be paid asynchronously by the customer but system need to +// reconcile the token balance once it's paid. +func (s *Service) GenerateForCredits(ctx context.Context) error { + var errs []error + logger := grpczap.Extract(ctx) + if s.creditOverdraftUnitAmount == 0 || s.creditOverdraftInvoiceCurrency == "" { + // do not process if credit overdraft details not set + return nil + } + + // ensure only one of this job is running at a time + lock, err := s.locker.TryLock(ctx, GenerateForCreditLockKey) + if err != nil { + if errors.Is(err, db.ErrLockBusy) { + // someone else has the lock, return + return nil + } + return err + } + defer func() { + unlockErr := lock.Unlock(ctx) + if unlockErr != nil { + logger.Error("failed to unlock", zap.Error(unlockErr), zap.String("key", GenerateForCreditLockKey)) + } + }() + + customers, err := s.customerService.List(ctx, customer.Filter{ + Online: utils.Bool(true), + AllowedOverdraft: utils.Bool(true), + }) + if err != nil { + return err + } + for _, c := range customers { + if ctx.Err() != nil { + // stop processing if context is done + break + } + + balance, err := s.creditService.GetBalance(ctx, c.ID) + if err != nil { + errs = append(errs, fmt.Errorf("failed to get balance for customer %s: %w", c.ID, err)) + continue + } + if balance >= 0 { + continue + } + + // check if there is already an invoice open for this balance + invoices, err := s.List(ctx, Filter{ + CustomerID: c.ID, + }) + if err != nil { + errs = append(errs, fmt.Errorf("failed to list invoices for customer %s: %w", c.ID, err)) + continue + } + // check if invoice line items are of type credit + // if yes, don't create a new invoice + var alreadyInvoiced bool + for _, i := range invoices { + if i.State == DraftState || i.State == OpenState { + for _, item := range i.Items { + if item.Type == CreditItemType { + alreadyInvoiced = true + } + } + } + } + if alreadyInvoiced { + continue + } + + // create invoice for the credit overdraft + items := []Item{ + { + Name: "Credit Overdraft", + Type: CreditItemType, + UnitAmount: s.creditOverdraftUnitAmount, + Quantity: abs(balance), + }, + } + newStripeInvoice, err := s.CreateInProvider(ctx, c, items, s.creditOverdraftInvoiceCurrency) + if err != nil { + errs = append(errs, fmt.Errorf("failed to create invoice for customer %s: %w", c.ID, err)) + continue + } + // sync back new invoice + if err := s.upsert(ctx, c.ID, nil, newStripeInvoice); err != nil { + errs = append(errs, fmt.Errorf("failed to sync invoice for customer %s: %w", c.ID, err)) + continue + } + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} + +func abs(x int64) int64 { + if x < 0 { + return -x + } + return x +} + +// CreateInProvider creates a custom invoice with items in the provider. +// Once created the invoice object will be synced back within system using +// regular syncer/webhook loop. +func (s *Service) CreateInProvider(ctx context.Context, custmr customer.Customer, + items []Item, currency string) (*stripe.Invoice, error) { + stripeInvoice, err := s.stripeClient.Invoices.New(&stripe.InvoiceParams{ + Params: stripe.Params{ + Context: ctx, + }, + Customer: stripe.String(custmr.ProviderID), + AutoAdvance: stripe.Bool(true), + Description: stripe.String("Invoice for the underpayment of credit utilization"), + AutomaticTax: &stripe.InvoiceAutomaticTaxParams{ + Enabled: stripe.Bool(s.stripeAutoTax), + }, + Currency: stripe.String(currency), + PendingInvoiceItemsBehavior: stripe.String("include"), + Metadata: map[string]string{ + "org_id": custmr.OrgID, + "managed_by": "frontier", + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to create invoice: %w", err) + } + + // create line item for the invoice + for _, item := range items { + _, err = s.stripeClient.InvoiceItems.New(&stripe.InvoiceItemParams{ + Params: stripe.Params{ + Context: ctx, + }, + Customer: stripe.String(custmr.ProviderID), + Currency: stripe.String(custmr.Currency), + Invoice: stripe.String(stripeInvoice.ID), + UnitAmount: &item.UnitAmount, + Quantity: &item.Quantity, + Metadata: map[string]string{ + "org_id": custmr.OrgID, + "managed_by": "frontier", + // type is used to identify the item type in the invoice + // this is useful when reconciling the invoice items for payments and + // avoid creating duplicate invoices + ItemTypeMetadataKey: item.Type.String(), + }, + Description: stripe.String(item.Name), + }) + if err != nil { + return nil, fmt.Errorf("failed to create invoice item: %w", err) + } + } + return stripeInvoice, nil +} + +// Reconcile checks all paid invoices and reconciles them with the system. +// If the invoice was created for credit overdraft, it will credit the customer +// account with the amount of the invoice. +func (s *Service) Reconcile(ctx context.Context) error { + if s.creditOverdraftUnitAmount == 0 { + // do not process if credit overdraft details not set as currently + // we only reconcile credit overdraft invoices + return nil + } + + invoices, err := s.ListAll(ctx, Filter{ + State: PaidState, + NonZeroOnly: true, + }) + if err != nil { + return err + } + var errs []error + for _, i := range invoices { + if ctx.Err() != nil { + // stop processing if context is done + break + } + + // check if already reconciled + if i.Metadata != nil && i.Metadata[ReconciledMetadataKey] == true { + continue + } + + if err := s.reconcileCreditInvoice(ctx, i); err != nil { + errs = append(errs, fmt.Errorf("failed to reconcile invoice %s: %w", i.ID, err)) + continue + } + + // mark invoices reconciled to avoid processing them in future + if i.Metadata == nil { + i.Metadata = make(map[string]any) + } + i.Metadata[ReconciledMetadataKey] = true + if _, err := s.repository.UpdateByID(ctx, i); err != nil { + errs = append(errs, fmt.Errorf("failed to update invoice metadata %s: %w", i.ID, err)) + continue + } + } + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} + +func (s *Service) reconcileCreditInvoice(ctx context.Context, i Invoice) error { + if i.State != PaidState { + return nil + } + var creditItems []Item + for _, item := range i.Items { + if item.Type == CreditItemType { + creditItems = append(creditItems, item) + } + } + if len(creditItems) == 0 { + return nil + } + for _, item := range creditItems { + // credit the customer account + if err := s.creditService.Add(ctx, credit.Credit{ + ID: credit.TxUUID(i.ID, item.ProviderID), + CustomerID: i.CustomerID, + Amount: item.Quantity, + Source: credit.SourceSystemOverdraftEvent, + Description: "Paid for credit overdraft invoice", + Metadata: map[string]any{ + "invoice_id": i.ID, + "overdraft": true, + "item": item.ProviderID, + }, + }); err != nil { + if errors.Is(err, credit.ErrAlreadyApplied) { + continue + } + return fmt.Errorf("failed to credit customer %s: %w", i.CustomerID, err) + } + } + return nil +} diff --git a/billing/plan/plan.go b/billing/plan/plan.go index 6fb1f3e98..e04e4eff6 100644 --- a/billing/plan/plan.go +++ b/billing/plan/plan.go @@ -29,6 +29,8 @@ type Plan struct { // Interval is the interval at which the plan is billed // e.g. day, week, month, year + // This is just used to group related product prices and has no + // immediate effect on the billing engine Interval string `json:"interval" yaml:"interval"` // OnStartCredits is the number of credits that are awarded when a subscription is started diff --git a/billing/usage/service.go b/billing/usage/service.go index 489afb9a3..3effe1a57 100644 --- a/billing/usage/service.go +++ b/billing/usage/service.go @@ -6,7 +6,6 @@ import ( "fmt" "strings" - "github.com/google/uuid" "github.com/raystack/frontier/billing/credit" ) @@ -70,7 +69,7 @@ func (s Service) Revert(ctx context.Context, customerID, usageID string, amount // Revert the usage if err := s.creditService.Add(ctx, credit.Credit{ - ID: uuid.NewSHA1(credit.TxNamespaceUUID, []byte(fmt.Sprintf("%s:%s", usageID, customerID))).String(), + ID: credit.TxUUID(usageID, customerID), CustomerID: customerID, Amount: amount, Description: fmt.Sprintf("Revert: %s", creditTx.Description), diff --git a/cmd/serve.go b/cmd/serve.go index 39362f575..b262753eb 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -446,7 +446,8 @@ func buildAPIDependencies( customerService, planService, subscriptionService, productService, creditService, organizationService, authnService) - invoiceService := invoice.NewService(stripeClient, postgres.NewBillingInvoiceRepository(dbc), customerService, cfg.Billing) + invoiceService := invoice.NewService(stripeClient, postgres.NewBillingInvoiceRepository(dbc), + customerService, creditService, productService, dbc, cfg.Billing) usageService := usage.NewService(creditService) diff --git a/config/sample.config.yaml b/config/sample.config.yaml index 9981f01cf..260da9c1a 100644 --- a/config/sample.config.yaml +++ b/config/sample.config.yaml @@ -192,6 +192,9 @@ billing: default_offline: false # free credits to be added to the customer account when created as a part of the org onboard_credits_with_org: 0 + # credit_overdraft_product is the product name that should be used to calculate per unit cost + # of the overdraft credits, it uses the first price available for the product + credit_overdraft_product: "" # plan change configuration applied when a user changes their subscription plan plan_change: # proration_behavior can be one of "create_prorations", "none", "always_invoice" diff --git a/docs/docs/reference/configurations.md b/docs/docs/reference/configurations.md index f323703e3..ad4158ea3 100644 --- a/docs/docs/reference/configurations.md +++ b/docs/docs/reference/configurations.md @@ -187,6 +187,9 @@ billing: default_offline: false # free credits to be added to the customer account when created as a part of the org onboard_credits_with_org: 0 + # credit_overdraft_product is the product name that should be used to calculate per unit cost + # of the overdraft credits, it uses the first price available for the product + credit_overdraft_product: "" # plan change configuration applied when a user changes their subscription plan plan_change: # proration_behavior can be one of "create_prorations", "none", "always_invoice" diff --git a/go.mod b/go.mod index d220febd8..24cf894d6 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/authzed/authzed-go v0.11.2-0.20240507202708-8b150c491e4a github.com/authzed/grpcutil v0.0.0-20240123092924-129dc0a6a6e1 github.com/authzed/spicedb v1.33.1 + github.com/cespare/xxhash v1.1.0 github.com/coreos/go-oidc/v3 v3.5.0 github.com/doug-martin/goqu/v9 v9.18.0 github.com/envoyproxy/protoc-gen-validate v1.0.4 diff --git a/go.sum b/go.sum index 3db3ab890..0da4d2544 100644 --- a/go.sum +++ b/go.sum @@ -514,6 +514,7 @@ github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMo github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= +github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= @@ -721,6 +722,7 @@ github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6 github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d h1:S2NE3iHSwP0XV47EEXL8mWmRdEfGscSJ+7EgePNgt0s= github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= +github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -2141,6 +2143,7 @@ github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= diff --git a/internal/api/v1beta1/billing_invoice.go b/internal/api/v1beta1/billing_invoice.go index c62919002..c56a7026b 100644 --- a/internal/api/v1beta1/billing_invoice.go +++ b/internal/api/v1beta1/billing_invoice.go @@ -88,7 +88,7 @@ func transformInvoiceToPB(i invoice.Invoice) (*frontierv1beta1.Invoice, error) { Id: i.ID, CustomerId: i.CustomerID, ProviderId: i.ProviderID, - State: i.State, + State: i.State.String(), Currency: i.Currency, Amount: i.Amount, HostedUrl: i.HostedURL, diff --git a/internal/store/postgres/billing_customer_repository.go b/internal/store/postgres/billing_customer_repository.go index fa7e82923..61d493fcf 100644 --- a/internal/store/postgres/billing_customer_repository.go +++ b/internal/store/postgres/billing_customer_repository.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/raystack/frontier/pkg/utils" + "github.com/doug-martin/goqu/v9" "github.com/jmoiron/sqlx/types" "github.com/raystack/frontier/billing/customer" @@ -201,6 +203,12 @@ func (r BillingCustomerRepository) List(ctx context.Context, flt customer.Filter "provider_id": flt.ProviderID, }) } + if utils.BoolValue(flt.Online) { + stmt = stmt.Where(goqu.L("(provider_id IS NOT NULL AND provider_id != '')")) + } + if utils.BoolValue(flt.AllowedOverdraft) { + stmt = stmt.Where(goqu.L("credit_min < 0")) + } query, params, err := stmt.ToSQL() if err != nil { return nil, fmt.Errorf("%w: %s", parseErr, err) diff --git a/internal/store/postgres/billing_customer_repository_test.go b/internal/store/postgres/billing_customer_repository_test.go index ff877c5ed..a8cd1491a 100644 --- a/internal/store/postgres/billing_customer_repository_test.go +++ b/internal/store/postgres/billing_customer_repository_test.go @@ -3,9 +3,12 @@ package postgres_test import ( "context" "fmt" + "strings" "testing" "time" + "golang.org/x/exp/slices" + "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/pkg/utils" "github.com/stretchr/testify/assert" @@ -174,10 +177,13 @@ func (s *BillingCustomerRepositoryTestSuite) TestList() { Description string Expected []customer.Customer ErrString string + filter customer.Filter } sampleID1 := uuid.New().String() sampleID2 := uuid.New().String() + sampleID3 := uuid.New().String() + sampleID4 := uuid.New().String() customers := []customer.Customer{ { ID: sampleID1, @@ -201,9 +207,10 @@ func (s *BillingCustomerRepositoryTestSuite) TestList() { DeletedAt: nil, }, { - ID: sampleID2, - OrgID: s.orgIDs[1], - Name: "customer 2", + ID: sampleID2, + ProviderID: sampleID2, + OrgID: s.orgIDs[1], + Name: "customer 2", TaxData: []customer.Tax{ { Type: "t1", @@ -220,31 +227,76 @@ func (s *BillingCustomerRepositoryTestSuite) TestList() { UpdatedAt: time.Time{}, DeletedAt: nil, }, + { + ID: sampleID3, + OrgID: s.orgIDs[0], + Name: "customer 3", + Email: "email", + State: "active", + Metadata: metadata.Metadata{}, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + DeletedAt: nil, + }, + { + ID: sampleID4, + ProviderID: sampleID4, + OrgID: s.orgIDs[0], + Name: "customer 4", + Email: "email", + State: "active", + CreditMin: -200, + Metadata: metadata.Metadata{}, + CreatedAt: time.Time{}, + UpdatedAt: time.Time{}, + DeletedAt: nil, + }, } var testCases = []testCase{ { Description: "should create basic customer with provider successfully", Expected: []customer.Customer{ customers[0], + customers[3], + }, + filter: customer.Filter{ + OrgID: s.orgIDs[0], + State: customer.ActiveState, + Online: utils.Bool(true), + }, + }, + { + Description: "should list customers with overdraft limits", + Expected: []customer.Customer{ + customers[3], + }, + filter: customer.Filter{ + OrgID: s.orgIDs[0], + State: customer.ActiveState, + Online: utils.Bool(true), + AllowedOverdraft: utils.Bool(true), }, }, } + for _, c := range customers { + _, err := s.repository.Create(s.ctx, c) + assert.NoError(s.T(), err) + } for _, tc := range testCases { s.Run(tc.Description, func() { - for _, c := range customers { - _, err := s.repository.Create(s.ctx, c) - assert.NoError(s.T(), err) - } - got, err := s.repository.List(s.ctx, customer.Filter{ - OrgID: s.orgIDs[0], - State: customer.ActiveState, - }) + got, err := s.repository.List(s.ctx, tc.filter) if err != nil { if err.Error() != tc.ErrString { s.T().Fatalf("got error %s, expected was %s", err.Error(), tc.ErrString) } } + slices.SortFunc(got, func(i, j customer.Customer) int { + return strings.Compare(i.Name, j.Name) + }) + slices.SortFunc(tc.Expected, func(i, j customer.Customer) int { + return strings.Compare(i.Name, j.Name) + }) if diff := cmp.Diff(tc.Expected, got, cmpopts.IgnoreFields(customer.Customer{}, "ID", "CreatedAt", "UpdatedAt")); diff != "" { s.T().Fatalf("mismatch (-want +got):\n%s", diff) } diff --git a/internal/store/postgres/billing_invoice_repository.go b/internal/store/postgres/billing_invoice_repository.go index 332063283..5141aa2a8 100644 --- a/internal/store/postgres/billing_invoice_repository.go +++ b/internal/store/postgres/billing_invoice_repository.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -27,6 +28,7 @@ type Invoice struct { Amount int64 `db:"amount"` HostedURL string `db:"hosted_url"` + Items Items `db:"items"` Metadata types.NullJSONText `db:"metadata"` PeriodStartAt *time.Time `db:"period_start_at"` @@ -38,6 +40,26 @@ type Invoice struct { DeletedAt *time.Time `db:"deleted_at"` } +type Items struct { + Data []invoice.Item `json:"data"` +} + +func (t *Items) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return json.Unmarshal(src, t) + case string: + return json.Unmarshal([]byte(src), t) + case nil: + return nil + } + return fmt.Errorf("cannot convert %T to JsonB", src) +} + +func (t Items) Value() (driver.Value, error) { + return json.Marshal(t) +} + func (i Invoice) transform() (invoice.Invoice, error) { var unmarshalledMetadata map[string]any if i.Metadata.Valid { @@ -65,10 +87,11 @@ func (i Invoice) transform() (invoice.Invoice, error) { ID: i.ID, ProviderID: i.ProviderID, CustomerID: i.CustomerID, - State: i.State, + State: invoice.State(i.State), Currency: i.Currency, Amount: i.Amount, HostedURL: i.HostedURL, + Items: i.Items.Data, Metadata: unmarshalledMetadata, DueAt: dueAt, EffectiveAt: effectiveAt, @@ -102,15 +125,18 @@ func (r BillingInvoiceRepository) Create(ctx context.Context, toCreate invoice.I query, params, err := dialect.Insert(TABLE_BILLING_INVOICES).Rows( goqu.Record{ - "id": toCreate.ID, - "provider_id": toCreate.ProviderID, - "customer_id": toCreate.CustomerID, - "state": toCreate.State, - "currency": toCreate.Currency, - "amount": toCreate.Amount, - "hosted_url": toCreate.HostedURL, - "due_at": toCreate.DueAt, - "effective_at": toCreate.EffectiveAt, + "id": toCreate.ID, + "provider_id": toCreate.ProviderID, + "customer_id": toCreate.CustomerID, + "state": toCreate.State.String(), + "currency": toCreate.Currency, + "amount": toCreate.Amount, + "hosted_url": toCreate.HostedURL, + "due_at": toCreate.DueAt, + "effective_at": toCreate.EffectiveAt, + "items": Items{ + Data: toCreate.Items, + }, "metadata": marshaledMetadata, "period_start_at": toCreate.PeriodStartAt, "period_end_at": toCreate.PeriodEndAt, @@ -167,6 +193,11 @@ func (r BillingInvoiceRepository) List(ctx context.Context, flt invoice.Filter) "amount": goqu.Op{"gt": 0}, }) } + if flt.State != "" { + stmt = stmt.Where(goqu.Ex{ + "state": flt.State.String(), + }) + } if flt.Pagination != nil { offset := flt.Pagination.Offset() @@ -231,7 +262,7 @@ func (r BillingInvoiceRepository) UpdateByID(ctx context.Context, toUpdate invoi updateRecord["metadata"] = marshaledMetadata } if toUpdate.State != "" { - updateRecord["state"] = toUpdate.State + updateRecord["state"] = toUpdate.State.String() } if !toUpdate.EffectiveAt.IsZero() { updateRecord["effective_at"] = toUpdate.EffectiveAt diff --git a/internal/store/postgres/billing_transactions_repository.go b/internal/store/postgres/billing_transactions_repository.go index 2ca78bc95..7b6880c0c 100644 --- a/internal/store/postgres/billing_transactions_repository.go +++ b/internal/store/postgres/billing_transactions_repository.go @@ -138,7 +138,9 @@ func (r BillingTransactionRepository) CreateEntry(ctx context.Context, debitEntr } var creditReturnedEntry, debitReturnedEntry credit.Transaction - if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error { + if err := r.dbc.WithTxn(ctx, sql.TxOptions{ + Isolation: sql.LevelSerializable, + }, func(tx *sqlx.Tx) error { // check if balance is enough if it's a customer entry if customerAcc.ID != "" { currentBalance, err := r.getBalanceInTx(ctx, tx, customerAcc.ID) @@ -388,7 +390,9 @@ func (r BillingTransactionRepository) getBalanceInTx(ctx context.Context, tx *sq // in transaction table till now. func (r BillingTransactionRepository) GetBalance(ctx context.Context, accountID string) (int64, error) { var amount int64 - if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error { + if err := r.dbc.WithTxn(ctx, sql.TxOptions{ + Isolation: sql.LevelSerializable, + }, func(tx *sqlx.Tx) error { var err error amount, err = r.getBalanceInTx(ctx, tx, accountID) return err diff --git a/internal/store/postgres/migrations/20241015201506_billing_invoice_items.down.sql b/internal/store/postgres/migrations/20241015201506_billing_invoice_items.down.sql new file mode 100644 index 000000000..b5cbf1640 --- /dev/null +++ b/internal/store/postgres/migrations/20241015201506_billing_invoice_items.down.sql @@ -0,0 +1 @@ +ALTER TABLE billing_invoices DROP COLUMN IF EXISTS items; diff --git a/internal/store/postgres/migrations/20241015201506_billing_invoice_items.up.sql b/internal/store/postgres/migrations/20241015201506_billing_invoice_items.up.sql new file mode 100644 index 000000000..17ee6e3b2 --- /dev/null +++ b/internal/store/postgres/migrations/20241015201506_billing_invoice_items.up.sql @@ -0,0 +1 @@ +ALTER TABLE billing_invoices ADD COLUMN IF NOT EXISTS items jsonb DEFAULT '{}'; diff --git a/pkg/db/config.go b/pkg/db/config.go index 95a0754e5..59565a5b5 100644 --- a/pkg/db/config.go +++ b/pkg/db/config.go @@ -7,6 +7,6 @@ type Config struct { URL string `yaml:"url" mapstructure:"url"` MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" default:"10"` MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" default:"10"` - ConnMaxLifeTime time.Duration `yaml:"conn_max_life_time" mapstructure:"conn_max_life_time" default:"60s"` - MaxQueryTimeout time.Duration `yaml:"max_query_timeout" mapstructure:"max_query_timeout" default:"1s"` + ConnMaxLifeTime time.Duration `yaml:"conn_max_life_time" mapstructure:"conn_max_life_time" default:"15m"` + MaxQueryTimeout time.Duration `yaml:"max_query_timeout" mapstructure:"max_query_timeout" default:"5s"` } diff --git a/pkg/db/db.go b/pkg/db/db.go index a5e47af7c..b36d82291 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -3,15 +3,21 @@ package db import ( "context" "database/sql" + "errors" "fmt" "time" + "github.com/cespare/xxhash" + "github.com/raystack/frontier/internal/metrics" newrelic "github.com/newrelic/go-agent" "github.com/jmoiron/sqlx" - "github.com/pkg/errors" +) + +var ( + ErrLockBusy = errors.New("lock busy") ) type Client struct { @@ -73,7 +79,7 @@ func (c Client) WithTxn(ctx context.Context, txnOptions sql.TxOptions, txFunc fu case error: err = p default: - err = errors.Errorf("%s", p) + err = fmt.Errorf("%s", p) } err = txn.Rollback() panic(p) @@ -92,3 +98,63 @@ func (c Client) WithTxn(ctx context.Context, txnOptions sql.TxOptions, txFunc fu err = txFunc(txn) return err } + +type Lock struct { + ID uint64 + conn *sqlx.Conn +} + +// Unlock uses postgres advisory locks to release a lock on a given id +func (l Lock) Unlock(ctx context.Context) error { + var errs []error + _, err := l.conn.ExecContext(ctx, fmt.Sprintf("SELECT pg_advisory_unlock(%d)", l.ID)) + if err != nil { + errs = append(errs, err) + } + + err = l.conn.Close() + if err != nil { + errs = append(errs, err) + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} + +// TryLock uses postgres advisory locks to acquire a lock on a given id +// if acquired, it returns the Lock object, else fail with ErrLockBusy +// In worst case if not unlocked, it will be released after the session ends +// which is configured via SetConnMaxLifetime +func (c Client) TryLock(ctx context.Context, id string) (*Lock, error) { + newConn, err := c.Connx(ctx) + if err != nil { + return nil, fmt.Errorf("failed to acquire connection: %w", err) + } + + hash := xxhash.Sum64String(id) + query := fmt.Sprintf("SELECT pg_try_advisory_lock(%d)", hash) + var acquired bool + if err := c.GetContext(ctx, &acquired, query); err != nil { + var errs []error + errs = append(errs, err) + if connErr := newConn.Close(); connErr != nil { + errs = append(errs, fmt.Errorf("failed to close connection: %w", connErr)) + } + return nil, errors.Join(errs...) + } + + if !acquired { + if connErr := newConn.Close(); connErr != nil { + return nil, fmt.Errorf("failed to close connection: %w", connErr) + } + return nil, ErrLockBusy + } + + lock := &Lock{ + ID: hash, + conn: newConn, + } + return lock, nil +} diff --git a/pkg/utils/pointers.go b/pkg/utils/pointers.go new file mode 100644 index 000000000..f2da199fc --- /dev/null +++ b/pkg/utils/pointers.go @@ -0,0 +1,15 @@ +package utils + +// Bool returns a pointer to the bool value passed in. +func Bool(v bool) *bool { + return &v +} + +// BoolValue returns the value of the bool pointer passed in or +// false if the pointer is nil. +func BoolValue(v *bool) bool { + if v != nil { + return *v + } + return false +}