diff --git a/CHANGELOG.md b/CHANGELOG.md
index c8b22e1..68a4b3b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
 
+## [7.1.0] - 2024-04-10
+
+- Adds queries for Bulk Import
+
 ## [7.0.1] - 2024-04-17
 
 - Fixes issues with partial failures during tenant creation
diff --git a/build.gradle b/build.gradle
index b8b5788..3776b0c 100644
--- a/build.gradle
+++ b/build.gradle
@@ -2,7 +2,7 @@ plugins {
     id 'java-library'
 }
 
-version = "7.0.1"
+version = "7.1.0"
 
 repositories {
     mavenCentral()
diff --git a/src/main/java/io/supertokens/storage/mysql/BulkImportProxyConnection.java b/src/main/java/io/supertokens/storage/mysql/BulkImportProxyConnection.java
new file mode 100644
index 0000000..633066e
--- /dev/null
+++ b/src/main/java/io/supertokens/storage/mysql/BulkImportProxyConnection.java
@@ -0,0 +1,342 @@
+/*
+ * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved.
+ *
+ * This software is licensed under the Apache License, Version 2.0 (the
+ * "License") as published by the Apache Software Foundation.
+ *
+ * You may not use this file except in compliance with the License. You may
+ * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package io.supertokens.storage.mysql;
+
+import java.sql.Array;
+import java.sql.Blob;
+import java.sql.CallableStatement;
+import java.sql.Clob;
+import java.sql.Connection;
+import java.sql.DatabaseMetaData;
+import java.sql.NClob;
+import java.sql.PreparedStatement;
+import java.sql.SQLClientInfoException;
+import java.sql.SQLException;
+import java.sql.SQLWarning;
+import java.sql.SQLXML;
+import java.sql.Savepoint;
+import java.sql.Statement;
+import java.sql.Struct;
+import java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.Executor;
+
+/**
+* BulkImportProxyConnection is a class implementing the Connection interface, serving as the Connection instance in the bulk import user cronjob.
+* This cronjob extensively reuses existing queries to import users, all of which internally operate within transactions, and those queries sometimes
+* call the commit/rollback methods on the connection.
+*
+* For the purposes of the bulk import cronjob, we aim to employ a single connection for all queries and roll back all operations in case of query failures.
+* To achieve this, we use our own proxy Connection instance and override the commit/rollback/close methods to do nothing.
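+*
+* A minimal usage sketch, purely for illustration (the sequencing is driven by the bulk import cronjob, not by this class;
+* `start` below is assumed to be the owning BulkImportProxyStorage instance):
+*
+*   Connection rawCon = ConnectionPool.getConnectionForProxyStorage(start);
+*   BulkImportProxyConnection proxyCon = new BulkImportProxyConnection(rawCon);
+*   proxyCon.setAutoCommit(false);
+*   // existing import queries run against proxyCon; any commit()/rollback()/close() they call is silently ignored
+*   proxyCon.commitForBulkImportProxyStorage();   // commits on rawCon once the whole user import has succeeded
+*   // on failure: proxyCon.rollbackForBulkImportProxyStorage();
+*   proxyCon.closeForBulkImportProxyStorage();    // closes rawCon when the cronjob is done with it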
+*/ + +public class BulkImportProxyConnection implements Connection { + private Connection con = null; + + public BulkImportProxyConnection(Connection con) { + this.con = con; + } + + @Override + public void close() throws SQLException { + // We simply ignore when close is called BulkImportProxyConnection + } + + @Override + public void commit() throws SQLException { + // We simply ignore when commit is called BulkImportProxyConnection + } + + @Override + public void rollback() throws SQLException { + // We simply ignore when rollback is called BulkImportProxyConnection + } + + public void closeForBulkImportProxyStorage() throws SQLException { + this.con.close(); + } + + public void commitForBulkImportProxyStorage() throws SQLException { + this.con.commit(); + } + + public void rollbackForBulkImportProxyStorage() throws SQLException { + this.con.rollback(); + } + + /* Following methods are unchaged */ + + @Override + public Statement createStatement() throws SQLException { + return this.con.createStatement(); + } + + @Override + public PreparedStatement prepareStatement(String sql) throws SQLException { + return this.con.prepareStatement(sql); + } + + @Override + public CallableStatement prepareCall(String sql) throws SQLException { + return this.con.prepareCall(sql); + } + + @Override + public String nativeSQL(String sql) throws SQLException { + return this.con.nativeSQL(sql); + } + + @Override + public void setAutoCommit(boolean autoCommit) throws SQLException { + this.con.setAutoCommit(autoCommit); + } + + @Override + public boolean getAutoCommit() throws SQLException { + return this.con.getAutoCommit(); + } + + @Override + public boolean isClosed() throws SQLException { + return this.con.isClosed(); + } + + @Override + public DatabaseMetaData getMetaData() throws SQLException { + return this.con.getMetaData(); + } + + @Override + public void setReadOnly(boolean readOnly) throws SQLException { + this.con.setReadOnly(readOnly); + } + + @Override + public boolean isReadOnly() throws SQLException { + return this.con.isReadOnly(); + } + + @Override + public void setCatalog(String catalog) throws SQLException { + this.con.setCatalog(catalog); + } + + @Override + public String getCatalog() throws SQLException { + return this.con.getCatalog(); + } + + @Override + public void setTransactionIsolation(int level) throws SQLException { + this.con.setTransactionIsolation(level); + } + + @Override + public int getTransactionIsolation() throws SQLException { + return this.con.getTransactionIsolation(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return this.con.getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + this.con.clearWarnings(); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException { + return this.con.createStatement(resultSetType, resultSetConcurrency); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) + throws SQLException { + return this.con.prepareStatement(sql, resultSetType, resultSetConcurrency); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException { + return this.con.prepareCall(sql, resultSetType, resultSetConcurrency); + } + + @Override + public Map> getTypeMap() throws SQLException { + return this.con.getTypeMap(); + } + + @Override + public void setTypeMap(Map> map) throws SQLException { + 
this.con.setTypeMap(map); + } + + @Override + public void setHoldability(int holdability) throws SQLException { + this.con.setHoldability(holdability); + } + + @Override + public int getHoldability() throws SQLException { + return this.con.getHoldability(); + } + + @Override + public Savepoint setSavepoint() throws SQLException { + return this.con.setSavepoint(); + } + + @Override + public Savepoint setSavepoint(String name) throws SQLException { + return this.con.setSavepoint(name); + } + + @Override + public void rollback(Savepoint savepoint) throws SQLException { + this.con.rollback(savepoint); + } + + @Override + public void releaseSavepoint(Savepoint savepoint) throws SQLException { + this.con.releaseSavepoint(savepoint); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + return this.con.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + return this.con.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + return this.con.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { + return this.con.prepareStatement(sql, autoGeneratedKeys); + } + + @Override + public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { + return this.con.prepareStatement(sql, columnIndexes); + } + + @Override + public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { + return this.con.prepareStatement(sql, columnNames); + } + + @Override + public Clob createClob() throws SQLException { + return this.con.createClob(); + } + + @Override + public Blob createBlob() throws SQLException { + return this.con.createBlob(); + } + + @Override + public NClob createNClob() throws SQLException { + return this.con.createNClob(); + } + + @Override + public SQLXML createSQLXML() throws SQLException { + return this.con.createSQLXML(); + } + + @Override + public boolean isValid(int timeout) throws SQLException { + return this.con.isValid(timeout); + } + + @Override + public void setClientInfo(String name, String value) throws SQLClientInfoException { + this.con.setClientInfo(name, value); + } + + @Override + public void setClientInfo(Properties properties) throws SQLClientInfoException { + this.con.setClientInfo(properties); + } + + @Override + public String getClientInfo(String name) throws SQLException { + return this.con.getClientInfo(name); + } + + @Override + public Properties getClientInfo() throws SQLException { + return this.con.getClientInfo(); + } + + @Override + public Array createArrayOf(String typeName, Object[] elements) throws SQLException { + return this.con.createArrayOf(typeName, elements); + } + + @Override + public Struct createStruct(String typeName, Object[] attributes) throws SQLException { + return this.con.createStruct(typeName, attributes); + } + + @Override + public void setSchema(String schema) throws SQLException { + this.con.setSchema(schema); + } + + @Override + public String getSchema() throws 
SQLException { + return this.con.getSchema(); + } + + @Override + public void abort(Executor executor) throws SQLException { + this.con.abort(executor); + } + + @Override + public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { + this.con.setNetworkTimeout(executor, milliseconds); + } + + @Override + public int getNetworkTimeout() throws SQLException { + return this.con.getNetworkTimeout(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + return this.con.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return this.con.isWrapperFor(iface); + } +} diff --git a/src/main/java/io/supertokens/storage/mysql/BulkImportProxyStorage.java b/src/main/java/io/supertokens/storage/mysql/BulkImportProxyStorage.java new file mode 100644 index 0000000..dd43afd --- /dev/null +++ b/src/main/java/io/supertokens/storage/mysql/BulkImportProxyStorage.java @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.mysql; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; +import java.util.Set; + +import com.google.gson.JsonObject; + +import io.supertokens.pluginInterface.LOG_LEVEL; +import io.supertokens.pluginInterface.exceptions.DbInitException; +import io.supertokens.pluginInterface.exceptions.InvalidConfigException; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.pluginInterface.sqlStorage.TransactionConnection; + +/** + * BulkImportProxyStorage is a class extending Start, serving as a Storage instance in the bulk import user cronjob. + * This cronjob extensively utilizes existing queries to import users, all of which internally operate within transactions. + * + * For the purpose of bulkimport cronjob, we aim to employ a single connection for all queries and rollback any operations in case of query failures. + * To achieve this, we override the startTransactionHelper method to utilize the same connection and prevent automatic query commits even upon transaction + * success. + * Subsequently, the cronjob is responsible for committing the transaction after ensuring the successful execution of all queries. 
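+ *
+ * A rough sketch of the intended lifecycle, shown only for illustration (the actual driver is the bulk import cronjob in
+ * the core; `start` below is assumed to be the regular Start instance):
+ *
+ *   Storage proxyStorage = start.createBulkImportProxyStorageInstance();
+ *   // loadConfig(...) and initStorage(...) are then called on proxyStorage like on any other Storage instance
+ *   // ... all queries importing a user run through the single proxied connection via startTransaction(...) ...
+ *   proxyStorage.commitTransactionForBulkImportProxyStorage();    // only after every query has succeeded
+ *   // on any failure: proxyStorage.rollbackTransactionForBulkImportProxyStorage();
+ *   proxyStorage.closeConnectionForBulkImportProxyStorage();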
+ */
+
+public class BulkImportProxyStorage extends Start {
+    private BulkImportProxyConnection connection;
+
+    public synchronized Connection getTransactionConnection() throws SQLException, StorageQueryException {
+        if (this.connection == null) {
+            Connection con = ConnectionPool.getConnectionForProxyStorage(this);
+            this.connection = new BulkImportProxyConnection(con);
+            connection.setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE);
+            connection.setAutoCommit(false);
+        }
+        return this.connection;
+    }
+
+    @Override
+    protected <T> T startTransactionHelper(TransactionLogic<T> logic, TransactionIsolationLevel isolationLevel)
+            throws StorageQueryException, StorageTransactionLogicException, SQLException {
+        return logic.mainLogicAndCommit(new TransactionConnection(getTransactionConnection()));
+    }
+
+    @Override
+    public void commitTransaction(TransactionConnection con) throws StorageQueryException {
+        // We do not want to commit the queries when using the BulkImportProxyStorage, so that we can roll back everything
+        // if any query fails while importing the user
+    }
+
+    @Override
+    public void loadConfig(JsonObject configJson, Set<LOG_LEVEL> logLevels, TenantIdentifier tenantIdentifier)
+            throws InvalidConfigException {
+        // We are overriding the loadConfig method to set the connection pool size
+        // to 1 to avoid creating many connections for the bulk import cronjob
+        configJson.addProperty("mysql_connection_pool_size", 1);
+        super.loadConfig(configJson, logLevels, tenantIdentifier);
+    }
+
+    @Override
+    public void initStorage(boolean shouldWait, List<TenantIdentifier> tenantIdentifiers) throws DbInitException {
+        super.initStorage(shouldWait, tenantIdentifiers);
+
+        // `BulkImportProxyStorage` uses `BulkImportProxyConnection`, which overrides the `.commit()` method on the Connection object.
+        // The `initStorage()` method runs `select * from table_name limit 1` queries to check if the tables exist, but these queries
+        // don't get committed due to the overridden `.commit()`, so we need to manually commit the transaction to remove any locks on the tables.
+ + // Without this commit, a call to `select * from bulk_import_users limit 1` in `doesTableExist()` locks the `bulk_import_users` table, + try { + this.commitTransactionForBulkImportProxyStorage(); + } catch (StorageQueryException e) { + throw new DbInitException(e); + } + } + + @Override + public void closeConnectionForBulkImportProxyStorage() throws StorageQueryException { + try { + if (this.connection != null) { + this.connection.close(); + this.connection = null; + } + ConnectionPool.close(this); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void commitTransactionForBulkImportProxyStorage() throws StorageQueryException { + try { + if (this.connection != null) { + this.connection.commitForBulkImportProxyStorage(); + } + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void rollbackTransactionForBulkImportProxyStorage() throws StorageQueryException { + try { + this.connection.rollbackForBulkImportProxyStorage(); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } +} diff --git a/src/main/java/io/supertokens/storage/mysql/ConnectionPool.java b/src/main/java/io/supertokens/storage/mysql/ConnectionPool.java index 8f5bd10..8f96d0c 100644 --- a/src/main/java/io/supertokens/storage/mysql/ConnectionPool.java +++ b/src/main/java/io/supertokens/storage/mysql/ConnectionPool.java @@ -202,7 +202,7 @@ static void initPool(Start start, boolean shouldWait, PostConnectCallback postCo } } - public static Connection getConnection(Start start) throws SQLException, StorageQueryException { + private static Connection getNewConnection(Start start) throws SQLException, StorageQueryException { if (getInstance(start) == null) { throw new IllegalStateException("Please call initPool before getConnection"); } @@ -215,6 +215,17 @@ public static Connection getConnection(Start start) throws SQLException, Storage return getInstance(start).hikariDataSource.getConnection(); } + public static Connection getConnectionForProxyStorage(Start start) throws SQLException, StorageQueryException { + return getNewConnection(start); + } + + public static Connection getConnection(Start start) throws SQLException, StorageQueryException { + if (start instanceof BulkImportProxyStorage) { + return ((BulkImportProxyStorage) start).getTransactionConnection(); + } + return getNewConnection(start); + } + static void close(Start start) { if (getInstance(start) == null) { return; diff --git a/src/main/java/io/supertokens/storage/mysql/QueryExecutorTemplate.java b/src/main/java/io/supertokens/storage/mysql/QueryExecutorTemplate.java index 91d8760..4944c5a 100644 --- a/src/main/java/io/supertokens/storage/mysql/QueryExecutorTemplate.java +++ b/src/main/java/io/supertokens/storage/mysql/QueryExecutorTemplate.java @@ -59,4 +59,16 @@ static int update(Connection con, String QUERY, PreparedStatementValueSetter set } } + static T update(Start start, String QUERY, PreparedStatementValueSetter setter, ResultSetValueExtractor mapper) + throws SQLException, StorageQueryException { + try (Connection con = ConnectionPool.getConnection(start)) { + try (PreparedStatement pst = con.prepareStatement(QUERY)) { + setter.setValues(pst); + try (ResultSet result = pst.executeQuery()) { + return mapper.extract(result); + } + } + } + } + } diff --git a/src/main/java/io/supertokens/storage/mysql/Start.java b/src/main/java/io/supertokens/storage/mysql/Start.java index 2ac5ff5..93a6db2 100644 --- 
a/src/main/java/io/supertokens/storage/mysql/Start.java +++ b/src/main/java/io/supertokens/storage/mysql/Start.java @@ -25,6 +25,8 @@ import io.supertokens.pluginInterface.authRecipe.AuthRecipeUserInfo; import io.supertokens.pluginInterface.authRecipe.LoginMethod; import io.supertokens.pluginInterface.authRecipe.sqlStorage.AuthRecipeSQLStorage; +import io.supertokens.pluginInterface.bulkimport.BulkImportUser; +import io.supertokens.pluginInterface.bulkimport.sqlStorage.BulkImportSQLStorage; import io.supertokens.pluginInterface.dashboard.DashboardSearchTags; import io.supertokens.pluginInterface.dashboard.DashboardSessionInfo; import io.supertokens.pluginInterface.dashboard.DashboardUser; @@ -110,7 +112,7 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, UserIdMappingSQLStorage, MultitenancyStorage, MultitenancySQLStorage, DashboardSQLStorage, TOTPSQLStorage, - ActiveUsersStorage, ActiveUsersSQLStorage, AuthRecipeSQLStorage { + ActiveUsersStorage, ActiveUsersSQLStorage, AuthRecipeSQLStorage, BulkImportSQLStorage { // these configs are protected from being modified / viewed by the dev using the SuperTokens // SaaS. If the core is not running in SuperTokens SaaS, this array has no effect. @@ -154,6 +156,29 @@ public STORAGE_TYPE getType() { return STORAGE_TYPE.SQL; } + @Override + public Storage createBulkImportProxyStorageInstance() { + return new BulkImportProxyStorage(); + } + + @Override + public void closeConnectionForBulkImportProxyStorage() throws StorageQueryException { + throw new UnsupportedOperationException( + "closeConnectionForBulkImportProxyStorage should only be called from BulkImportProxyStorage"); + } + + @Override + public void commitTransactionForBulkImportProxyStorage() throws StorageQueryException { + throw new UnsupportedOperationException( + "commitTransactionForBulkImportProxyStorage should only be called from BulkImportProxyStorage"); + } + + @Override + public void rollbackTransactionForBulkImportProxyStorage() throws StorageQueryException { + throw new UnsupportedOperationException( + "rollbackTransactionForBulkImportProxyStorage should only be called from BulkImportProxyStorage"); + } + @Override public void loadConfig(JsonObject configJson, Set logLevels, TenantIdentifier tenantIdentifier) throws InvalidConfigException { Config.loadConfig(this, configJson, logLevels, tenantIdentifier); @@ -290,7 +315,7 @@ public T startTransaction(TransactionLogic logic, TransactionIsolationLev } } - private T startTransactionHelper(TransactionLogic logic, TransactionIsolationLevel isolationLevel) + protected T startTransactionHelper(TransactionLogic logic, TransactionIsolationLevel isolationLevel) throws StorageQueryException, StorageTransactionLogicException, SQLException { Connection con = null; Integer defaultTransactionIsolation = null; @@ -3019,4 +3044,83 @@ public int getDbActivityCount(String dbname) throws SQLException, StorageQueryEx return -1; }); } + + @Override + public void addBulkImportUsers(AppIdentifier appIdentifier, List users) + throws StorageQueryException, + TenantOrAppNotFoundException, + io.supertokens.pluginInterface.bulkimport.exceptions.DuplicateUserIdException { + try { + BulkImportQueries.insertBulkImportUsers(this, appIdentifier, users); + } catch (SQLException e) { + if (e instanceof SQLIntegrityConstraintViolationException) { + MySQLConfig config = 
Config.getConfig(this); + String errorMessage = e.getMessage(); + if (isPrimaryKeyError(errorMessage, config.getBulkImportUsersTable())) { + throw new io.supertokens.pluginInterface.bulkimport.exceptions.DuplicateUserIdException(); + } + if (isForeignKeyConstraintError(errorMessage, config.getBulkImportUsersTable(), "app_id")) { + throw new TenantOrAppNotFoundException(appIdentifier); + } + } + throw new StorageQueryException(e); + } + } + + @Override + public List getBulkImportUsers(AppIdentifier appIdentifier, @Nonnull Integer limit, @Nullable BULK_IMPORT_USER_STATUS status, + @Nullable String bulkImportUserId, @Nullable Long createdAt) throws StorageQueryException { + try { + return BulkImportQueries.getBulkImportUsers(this, appIdentifier, limit, status, bulkImportUserId, createdAt); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public void updateBulkImportUserStatus_Transaction(AppIdentifier appIdentifier, TransactionConnection con, @Nonnull String bulkImportUserId, @Nonnull BULK_IMPORT_USER_STATUS status, @Nullable String errorMessage) + throws StorageQueryException { + Connection sqlCon = (Connection) con.getConnection(); + try { + BulkImportQueries.updateBulkImportUserStatus_Transaction(this, sqlCon, appIdentifier, bulkImportUserId, status, errorMessage); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public List deleteBulkImportUsers(AppIdentifier appIdentifier, @Nonnull String[] bulkImportUserIds) throws StorageQueryException { + try { + return BulkImportQueries.deleteBulkImportUsers(this, appIdentifier, bulkImportUserIds); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public List getBulkImportUsersAndChangeStatusToProcessing(AppIdentifier appIdentifier, @Nonnull Integer limit) throws StorageQueryException { + try { + return BulkImportQueries.getBulkImportUsersAndChangeStatusToProcessing(this, appIdentifier, limit); + } catch (StorageTransactionLogicException e) { + throw new StorageQueryException(e.actualException); + } + } + + @Override + public void updateBulkImportUserPrimaryUserId(AppIdentifier appIdentifier, @Nonnull String bulkImportUserId, @Nonnull String primaryUserId) throws StorageQueryException { + try { + BulkImportQueries.updateBulkImportUserPrimaryUserId(this, appIdentifier, bulkImportUserId, primaryUserId); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public long getBulkImportUsersCount(AppIdentifier appIdentifier, @Nullable BULK_IMPORT_USER_STATUS status) throws StorageQueryException { + try { + return BulkImportQueries.getBulkImportUsersCount(this, appIdentifier, status); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } } diff --git a/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java b/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java index 52bc3fe..c8d13a2 100644 --- a/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java +++ b/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java @@ -317,6 +317,10 @@ public String getTotpUsedCodesTable() { return addPrefixToTableName("totp_used_codes"); } + public String getBulkImportUsersTable() { + return addPrefixToTableName("bulk_import_users"); + } + private String addPrefixToTableName(String tableName) { return mysql_table_names_prefix + tableName; } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/BulkImportQueries.java 
b/src/main/java/io/supertokens/storage/mysql/queries/BulkImportQueries.java new file mode 100644 index 0000000..42a61af --- /dev/null +++ b/src/main/java/io/supertokens/storage/mysql/queries/BulkImportQueries.java @@ -0,0 +1,335 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.mysql.queries; + +import static io.supertokens.storage.mysql.QueryExecutorTemplate.update; +import static io.supertokens.storage.mysql.QueryExecutorTemplate.execute; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import io.supertokens.pluginInterface.RowMapper; +import io.supertokens.pluginInterface.bulkimport.BulkImportStorage.BULK_IMPORT_USER_STATUS; +import io.supertokens.pluginInterface.bulkimport.BulkImportUser; +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException; +import io.supertokens.pluginInterface.multitenancy.AppIdentifier; +import io.supertokens.storage.mysql.Start; +import io.supertokens.storage.mysql.config.Config; +import io.supertokens.storage.mysql.utils.Utils; + +public class BulkImportQueries { + static String getQueryToCreateBulkImportUsersTable(Start start) { + String tableName = Config.getConfig(start).getBulkImportUsersTable(); + return "CREATE TABLE IF NOT EXISTS " + tableName + " (" + + "id CHAR(36)," + + "app_id VARCHAR(64) NOT NULL DEFAULT 'public'," + + "primary_user_id VARCHAR(36)," + + "raw_data TEXT NOT NULL," + + "status VARCHAR(128) DEFAULT 'NEW'," + + "error_msg TEXT," + + "created_at BIGINT NOT NULL, " + + "updated_at BIGINT NOT NULL," + + "PRIMARY KEY (app_id, id)," + + "FOREIGN KEY (app_id)" + + " REFERENCES " + Config.getConfig(start).getAppsTable() + " (app_id) ON DELETE CASCADE" + + " );"; + } + + public static String getQueryToCreateStatusUpdatedAtIndex(Start start) { + return "CREATE INDEX bulk_import_users_status_updated_at_index ON " + + Config.getConfig(start).getBulkImportUsersTable() + " (app_id, status, updated_at)"; + } + + public static String getQueryToCreatePaginationIndex1(Start start) { + return "CREATE INDEX bulk_import_users_pagination_index1 ON " + + Config.getConfig(start).getBulkImportUsersTable() + " (app_id, status, created_at DESC, id DESC)"; + } + + public static String getQueryToCreatePaginationIndex2(Start start) { + return "CREATE INDEX bulk_import_users_pagination_index2 ON " + + Config.getConfig(start).getBulkImportUsersTable() + " (app_id, created_at DESC, id DESC)"; + } + + public static void insertBulkImportUsers(Start start, AppIdentifier appIdentifier, List users) + throws SQLException, StorageQueryException { + StringBuilder queryBuilder = new StringBuilder( + "INSERT INTO " + 
Config.getConfig(start).getBulkImportUsersTable() + " (id, app_id, raw_data, created_at, updated_at) VALUES "); + + int userCount = users.size(); + + for (int i = 0; i < userCount; i++) { + queryBuilder.append(" (?, ?, ?, ?, ?)"); + + if (i < userCount - 1) { + queryBuilder.append(","); + } + } + + update(start, queryBuilder.toString(), pst -> { + int parameterIndex = 1; + for (BulkImportUser user : users) { + pst.setString(parameterIndex++, user.id); + pst.setString(parameterIndex++, appIdentifier.getAppId()); + pst.setString(parameterIndex++, user.toRawDataForDbStorage()); + pst.setLong(parameterIndex++, System.currentTimeMillis()); + pst.setLong(parameterIndex++, System.currentTimeMillis()); + } + }); + } + + public static void updateBulkImportUserStatus_Transaction(Start start, Connection con, AppIdentifier appIdentifier, + @Nonnull String bulkImportUserId, @Nonnull BULK_IMPORT_USER_STATUS status, @Nullable String errorMessage) + throws SQLException, StorageQueryException { + String query = "UPDATE " + Config.getConfig(start).getBulkImportUsersTable() + + " SET status = ?, error_msg = ?, updated_at = ? WHERE app_id = ? and id = ?"; + + List parameters = new ArrayList<>(); + + parameters.add(status.toString()); + parameters.add(errorMessage); + parameters.add(System.currentTimeMillis()); + parameters.add(appIdentifier.getAppId()); + parameters.add(bulkImportUserId); + + update(con, query, pst -> { + for (int i = 0; i < parameters.size(); i++) { + pst.setObject(i + 1, parameters.get(i)); + } + }); + } + + public static List getBulkImportUsersAndChangeStatusToProcessing(Start start, + AppIdentifier appIdentifier, + @Nonnull Integer limit) + throws StorageQueryException, StorageTransactionLogicException { + + return start.startTransaction(con -> { + Connection sqlCon = (Connection) con.getConnection(); + try { + // NOTE: On average, we take about 66 seconds to process 1000 users. If, for any reason, the bulk import users were marked as processing but couldn't be processed within 10 minutes, we'll attempt to process them again. + + // "FOR UPDATE" ensures that multiple cron jobs don't read the same rows simultaneously. + // If one process locks the first 1000 rows, others will wait for the lock to be released. + // "SKIP LOCKED" allows other processes to skip locked rows and select the next 1000 available rows. + String selectQuery = "SELECT * FROM " + Config.getConfig(start).getBulkImportUsersTable() + + " WHERE app_id = ?" + + " AND (status = 'NEW' OR (status = 'PROCESSING' AND updated_at < (UNIX_TIMESTAMP() * 1000) - 10 * 60 * 1000))" /* 10 mins */ + + " LIMIT ? FOR UPDATE SKIP LOCKED"; + + + List bulkImportUsers = new ArrayList<>(); + + execute(sqlCon, selectQuery, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setInt(2, limit); + }, result -> { + while (result.next()) { + bulkImportUsers.add(BulkImportUserRowMapper.getInstance().mapOrThrow(result)); + } + return null; + }); + + if (bulkImportUsers.isEmpty()) { + return new ArrayList<>(); + } + + String updateQuery = "UPDATE " + Config.getConfig(start).getBulkImportUsersTable() + + " SET status = ?, updated_at = ? WHERE app_id = ? 
AND id IN (" + Utils + .generateCommaSeperatedQuestionMarks(bulkImportUsers.size()) + ")"; + + update(sqlCon, updateQuery, pst -> { + int index = 1; + pst.setString(index++, BULK_IMPORT_USER_STATUS.PROCESSING.toString()); + pst.setLong(index++, System.currentTimeMillis()); + pst.setString(index++, appIdentifier.getAppId()); + for (BulkImportUser user : bulkImportUsers) { + pst.setObject(index++, user.id); + } + }); + return bulkImportUsers; + } catch (SQLException throwables) { + throw new StorageTransactionLogicException(throwables); + } + }); + } + + public static List getBulkImportUsers(Start start, AppIdentifier appIdentifier, + @Nonnull Integer limit, @Nullable BULK_IMPORT_USER_STATUS status, + @Nullable String bulkImportUserId, @Nullable Long createdAt) + throws SQLException, StorageQueryException { + + String baseQuery = "SELECT * FROM " + Config.getConfig(start).getBulkImportUsersTable(); + + StringBuilder queryBuilder = new StringBuilder(baseQuery); + List parameters = new ArrayList<>(); + + queryBuilder.append(" WHERE app_id = ?"); + parameters.add(appIdentifier.getAppId()); + + if (status != null) { + queryBuilder.append(" AND status = ?"); + parameters.add(status.toString()); + } + + if (bulkImportUserId != null && createdAt != null) { + queryBuilder + .append(" AND (created_at < ? OR (created_at = ? AND id <= ?))"); + parameters.add(createdAt); + parameters.add(createdAt); + parameters.add(bulkImportUserId); + } + + queryBuilder.append(" ORDER BY created_at DESC, id DESC LIMIT ?"); + parameters.add(limit); + + String query = queryBuilder.toString(); + + return execute(start, query, pst -> { + for (int i = 0; i < parameters.size(); i++) { + pst.setObject(i + 1, parameters.get(i)); + } + }, result -> { + List bulkImportUsers = new ArrayList<>(); + while (result.next()) { + bulkImportUsers.add(BulkImportUserRowMapper.getInstance().mapOrThrow(result)); + } + return bulkImportUsers; + }); + } + + public static List deleteBulkImportUsers(Start start, AppIdentifier appIdentifier, + @Nonnull String[] bulkImportUserIds) throws SQLException, StorageQueryException { + if (bulkImportUserIds.length == 0) { + return new ArrayList<>(); + } + + // This function needs to return the IDs of the deleted users. Since the DELETE query doesn't return the IDs of the deleted entries, + // we first perform a SELECT query to find all IDs that actually exist in the database. After deletion, we return these IDs. + String selectQuery = "SELECT id FROM " + Config.getConfig(start).getBulkImportUsersTable() + + " WHERE app_id = ? AND id IN (" + Utils + .generateCommaSeperatedQuestionMarks(bulkImportUserIds.length) + ")"; + + List deletedIds = new ArrayList<>(); + + execute(start, selectQuery, pst -> { + int index = 1; + pst.setString(index++, appIdentifier.getAppId()); + for (String id : bulkImportUserIds) { + pst.setObject(index++, id); + } + }, result -> { + while (result.next()) { + deletedIds.add(result.getString("id")); + } + return null; + }); + + if (deletedIds.isEmpty()) { + return new ArrayList<>(); + } + + String deleteQuery = "DELETE FROM " + Config.getConfig(start).getBulkImportUsersTable() + + " WHERE app_id = ? 
AND id IN (" + Utils.generateCommaSeperatedQuestionMarks(deletedIds.size()) + ")"; + + update(start, deleteQuery, pst -> { + int index = 1; + pst.setString(index++, appIdentifier.getAppId()); + for (String id : deletedIds) { + pst.setObject(index++, id); + } + }); + + return deletedIds; + } + + public static void deleteBulkImportUser_Transaction(Start start, Connection con, AppIdentifier appIdentifier, + @Nonnull String bulkImportUserId) throws SQLException, StorageQueryException { + String query = "DELETE FROM " + Config.getConfig(start).getBulkImportUsersTable() + + " WHERE app_id = ? AND id = ?"; + + update(con, query, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, bulkImportUserId); + }); + } + + public static void updateBulkImportUserPrimaryUserId(Start start, AppIdentifier appIdentifier, + @Nonnull String bulkImportUserId, + @Nonnull String primaryUserId) throws SQLException, StorageQueryException { + String query = "UPDATE " + Config.getConfig(start).getBulkImportUsersTable() + + " SET primary_user_id = ?, updated_at = ? WHERE app_id = ? and id = ?"; + + update(start, query, pst -> { + pst.setString(1, primaryUserId); + pst.setLong(2, System.currentTimeMillis()); + pst.setString(3, appIdentifier.getAppId()); + pst.setString(4, bulkImportUserId); + }); + } + + public static long getBulkImportUsersCount(Start start, AppIdentifier appIdentifier, @Nullable BULK_IMPORT_USER_STATUS status) throws SQLException, StorageQueryException { + String baseQuery = "SELECT COUNT(*) FROM " + Config.getConfig(start).getBulkImportUsersTable(); + StringBuilder queryBuilder = new StringBuilder(baseQuery); + + List parameters = new ArrayList<>(); + + queryBuilder.append(" WHERE app_id = ?"); + parameters.add(appIdentifier.getAppId()); + + if (status != null) { + queryBuilder.append(" AND status = ?"); + parameters.add(status.toString()); + } + + String query = queryBuilder.toString(); + + return execute(start, query, pst -> { + for (int i = 0; i < parameters.size(); i++) { + pst.setObject(i + 1, parameters.get(i)); + } + }, result -> { + result.next(); + return result.getLong(1); + }); + } + + private static class BulkImportUserRowMapper implements RowMapper { + private static final BulkImportUserRowMapper INSTANCE = new BulkImportUserRowMapper(); + + private BulkImportUserRowMapper() { + } + + private static BulkImportUserRowMapper getInstance() { + return INSTANCE; + } + + @Override + public BulkImportUser map(ResultSet result) throws Exception { + return BulkImportUser.fromRawDataFromDbStorage(result.getString("id"), result.getString("raw_data"), + BULK_IMPORT_USER_STATUS.valueOf(result.getString("status")), + result.getString("primary_user_id"), result.getString("error_msg"), result.getLong("created_at"), + result.getLong("updated_at")); + } + } +} \ No newline at end of file diff --git a/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java index 8df206c..2e9e2bf 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java @@ -413,6 +413,15 @@ public static void createTablesIfNotExists(Start start, Connection con) throws S // index: update(con, TOTPQueries.getQueryToCreateUsedCodesExpiryTimeIndex(start), NO_OP_SETTER); } + + if (!doesTableExists(start, con, Config.getConfig(start).getBulkImportUsersTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, 
BulkImportQueries.getQueryToCreateBulkImportUsersTable(start), NO_OP_SETTER); + // index: + update(start, BulkImportQueries.getQueryToCreateStatusUpdatedAtIndex(start), NO_OP_SETTER); + update(start, BulkImportQueries.getQueryToCreatePaginationIndex1(start), NO_OP_SETTER); + update(start, BulkImportQueries.getQueryToCreatePaginationIndex2(start), NO_OP_SETTER); + } } @TestOnly