diff --git a/server/graphql/v2/input/TransactionsImportReferenceInput.ts b/server/graphql/v2/input/TransactionsImportReferenceInput.ts new file mode 100644 index 000000000000..a09bc14fd2de --- /dev/null +++ b/server/graphql/v2/input/TransactionsImportReferenceInput.ts @@ -0,0 +1,36 @@ +import { GraphQLInputObjectType, GraphQLNonNull } from 'graphql'; +import { GraphQLNonEmptyString } from 'graphql-scalars'; + +import { TransactionsImport } from '../../../models'; +import { idDecode } from '../identifiers'; + +export type GraphQLTransactionsImportReferenceInputFields = { + id: string; +}; + +export const GraphQLTransactionsImportReferenceInput = new GraphQLInputObjectType({ + name: 'TransactionsImportReferenceInput', + fields: () => ({ + id: { + type: new GraphQLNonNull(GraphQLNonEmptyString), + description: 'The id of the row', + }, + }), +}); + +export const fetchTransactionsImportWithReference = async ( + input: { id: string }, + { throwIfMissing = false, ...sequelizeOpts } = {}, +): Promise => { + let row; + if (input.id) { + const decodedId = idDecode(input.id, 'transactions-import-row'); + row = await TransactionsImport.findByPk(decodedId, sequelizeOpts); + } + + if (!row && throwIfMissing) { + throw new Error(`TransactionsImport not found`); + } + + return row; +}; diff --git a/server/graphql/v2/mutation/PlaidMutations.ts b/server/graphql/v2/mutation/PlaidMutations.ts index 9e90c000d383..92892a5c0c1b 100644 --- a/server/graphql/v2/mutation/PlaidMutations.ts +++ b/server/graphql/v2/mutation/PlaidMutations.ts @@ -12,6 +12,10 @@ import { fetchConnectedAccountWithReference, GraphQLConnectedAccountReferenceInput, } from '../input/ConnectedAccountReferenceInput'; +import { + fetchTransactionsImportWithReference, + GraphQLTransactionsImportReferenceInput, +} from '../input/TransactionsImportReferenceInput'; import { GraphQLConnectedAccount } from '../object/ConnectedAccount'; import { GraphQLTransactionsImport } from '../object/TransactionsImport'; @@ -106,6 +110,10 @@ export const plaidMutations = { type: GraphQLString, description: 'The name of the bank account', }, + transactionImport: { + type: GraphQLTransactionsImportReferenceInput, + description: 'The transaction import to use. If not provided, a new one will be created.', + }, }, resolve: async (_, args, req) => { checkRemoteUserCanUseTransactions(req); @@ -121,7 +129,13 @@ export const plaidMutations = { ); } - const accountInfo = pick(args, ['sourceName', 'name']); + const accountInfo: Parameters[3] = pick(args, ['sourceName', 'name']); + if (args.transactionImport) { + accountInfo.transactionsImport = await fetchTransactionsImportWithReference(args.transactionImport, { + throwIfMissing: true, + }); + } + return connectPlaidAccount(req.remoteUser, host, args.publicToken, accountInfo); }, }, diff --git a/server/lib/plaid/connect.ts b/server/lib/plaid/connect.ts index b5eeb6e69fcd..308cf9ca36df 100644 --- a/server/lib/plaid/connect.ts +++ b/server/lib/plaid/connect.ts @@ -1,4 +1,4 @@ -import { truncate } from 'lodash'; +import { omit, truncate } from 'lodash'; import { CountryCode, ItemPublicTokenExchangeResponse, Products } from 'plaid'; import { Service } from '../../constants/connected-account'; @@ -39,13 +39,25 @@ export const connectPlaidAccount = async ( remoteUser: User, host: Collective, publicToken: string, - { sourceName, name }: { sourceName: string; name: string }, + { + sourceName, + name, + transactionsImport, + }: { sourceName: string; name: string; transactionsImport?: TransactionsImport }, ) => { // Permissions check if (!remoteUser.isAdminOfCollective(host)) { throw new Error('You must be an admin of the host to connect a Plaid account to it'); } else if (!host.isHostAccount) { throw new Error('You can only connect a Plaid account to a host account'); + } else if (transactionsImport) { + if (transactionsImport.CollectiveId !== host.id) { + throw new Error('The transaction import must belong to the host'); + } else if (transactionsImport.type !== 'PLAID') { + throw new Error('The transaction import must be of type PLAID'); + } else if (transactionsImport.ConnectedAccountId) { + throw new Error('This transaction import is already connected to Plaid'); + } } // Exchange Plaid public token @@ -71,6 +83,9 @@ export const connectPlaidAccount = async ( } } + // TODO: If reconnecting, make sure the connected account matches the previous one + // ... + // Create connected account return sequelize.transaction(async transaction => { const connectedAccount = await ConnectedAccount.create( @@ -80,25 +95,38 @@ export const connectPlaidAccount = async ( clientId: exchangeTokenResponse['item_id'], token: exchangeTokenResponse['access_token'], CreatedByUserId: remoteUser.id, + data: omit(exchangeTokenResponse, ['item_id', 'access_token']), }, { transaction }, ); - const transactionsImport = await TransactionsImport.createWithActivity( - remoteUser, - host, + if (transactionsImport) { + await transactionsImport.update({ ConnectedAccountId: connectedAccount.id }, { transaction }); + } else { + transactionsImport = await TransactionsImport.createWithActivity( + remoteUser, + host, + { + type: 'PLAID', + source: truncate(sourceName, { length: 255 }) || 'Bank', + name: truncate(name, { length: 255 }) || 'Bank Account', + ConnectedAccountId: connectedAccount.id, + }, + { transaction }, + ); + } + + // Record the transactions import ID in the connected account for audit purposes + await connectedAccount.update( { - type: 'PLAID', - source: truncate(sourceName, { length: 255 }) || 'Bank', - name: truncate(name, { length: 255 }) || 'Bank Account', - ConnectedAccountId: connectedAccount.id, + data: { + ...connectedAccount.data, + transactionsImportId: transactionsImport.id, + }, }, { transaction }, ); - // Record the transactions import ID in the connected account for audit purposes - await connectedAccount.update({ data: { transactionsImportId: transactionsImport.id } }, { transaction }); - return { connectedAccount, transactionsImport }; }); };