From 6c1283dac7898c8d8d70ae97389d9e0d95a3c9b7 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Mon, 6 Nov 2023 10:00:50 +0530 Subject: [PATCH] fix: mfa --- .../io/supertokens/storage/mysql/Start.java | 141 +++--------------- .../storage/mysql/config/MySQLConfig.java | 8 + .../mysql/queries/ActiveUsersQueries.java | 37 +++++ .../storage/mysql/queries/GeneralQueries.java | 41 +++++ .../mysql/queries/MultitenancyQueries.java | 49 +++++- .../storage/mysql/queries/TOTPQueries.java | 7 +- .../queries/multitenancy/MfaSqlHelper.java | 120 +++++++++++++++ .../multitenancy/TenantConfigSQLHelper.java | 33 +++- .../storage/mysql/test/DeadlockTest.java | 4 +- .../storage/mysql/test/LoggingTest.java | 1 + .../storage/mysql/test/StorageLayerTest.java | 2 +- .../test/multitenancy/StorageLayerTest.java | 20 +++ .../TestUserPoolIdChangeBehaviour.java | 4 + 13 files changed, 333 insertions(+), 134 deletions(-) create mode 100644 src/main/java/io/supertokens/storage/mysql/queries/multitenancy/MfaSqlHelper.java diff --git a/src/main/java/io/supertokens/storage/mysql/Start.java b/src/main/java/io/supertokens/storage/mysql/Start.java index 5a753da..18a8831 100644 --- a/src/main/java/io/supertokens/storage/mysql/Start.java +++ b/src/main/java/io/supertokens/storage/mysql/Start.java @@ -48,8 +48,6 @@ import io.supertokens.pluginInterface.jwt.JWTSigningKeyInfo; import io.supertokens.pluginInterface.jwt.exceptions.DuplicateKeyIdException; import io.supertokens.pluginInterface.jwt.sqlstorage.JWTRecipeSQLStorage; -import io.supertokens.pluginInterface.mfa.MfaStorage; -import io.supertokens.pluginInterface.mfa.sqlStorage.MfaSQLStorage; import io.supertokens.pluginInterface.multitenancy.*; import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateClientTypeException; import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateTenantException; @@ -108,7 +106,7 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, UserIdMappingSQLStorage, MultitenancyStorage, MultitenancySQLStorage, DashboardSQLStorage, TOTPSQLStorage, - ActiveUsersStorage, ActiveUsersSQLStorage, MfaStorage, MfaSQLStorage, AuthRecipeSQLStorage { + ActiveUsersStorage, ActiveUsersSQLStorage, AuthRecipeSQLStorage { // these configs are protected from being modified / viewed by the dev using the SuperTokens // SaaS. If the core is not running in SuperTokens SaaS, this array has no effect. @@ -728,13 +726,6 @@ public boolean isUserIdBeingUsedInNonAuthRecipe(AppIdentifier appIdentifier, Str return false; } else if (className.equals(ActiveUsersStorage.class.getName())) { return ActiveUsersQueries.getLastActiveByUserId(this, appIdentifier, userId) != null; - } else if (className.equals(MfaStorage.class.getName())) { - try { - MultitenancyQueries.getAllTenants(this); - return MfaQueries.listFactors(this, appIdentifier, userId).length > 0; - } catch (SQLException e) { - throw new StorageQueryException(e); - } } else { throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage"); } @@ -805,7 +796,7 @@ public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifi } } else if (className.equals(TOTPStorage.class.getName())) { try { - TOTPDevice device = new TOTPDevice(userId, "testDevice", "secret", 0, 30, false); + TOTPDevice device = new TOTPDevice(userId, "testDevice", "secret", 0, 30, false, System.currentTimeMillis()); this.startTransaction(con -> { try { long now = System.currentTimeMillis(); @@ -830,12 +821,6 @@ public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifi } catch (SQLException e) { throw new StorageQueryException(e); } - } else if (className.equals(MfaStorage.class.getName())) { - try { - MfaQueries.enableFactor(this, tenantIdentifier, userId, "emailpassword"); - } catch (SQLException e) { - throw new StorageQueryException(e); - } } else { throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage"); } @@ -1295,44 +1280,6 @@ public int countUsersActiveSince(AppIdentifier appIdentifier, long time) throws } } - @Override - public int countUsersEnabledTotp(AppIdentifier appIdentifier) throws StorageQueryException { - try { - return ActiveUsersQueries.countUsersEnabledTotp(this, appIdentifier); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public int countUsersEnabledTotpAndActiveSince(AppIdentifier appIdentifier, long time) - throws StorageQueryException { - try { - return ActiveUsersQueries.countUsersEnabledTotpAndActiveSince(this, appIdentifier, time); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public int countUsersEnabledMfa(AppIdentifier appIdentifier) throws StorageQueryException { - try { - return ActiveUsersQueries.countUsersEnabledMfa(this, appIdentifier); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public int countUsersEnabledMfaAndActiveSince(AppIdentifier appIdentifier, long time) - throws StorageQueryException { - try { - return ActiveUsersQueries.countUsersEnabledMfaAndActiveSince(this, appIdentifier, time); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - @Override public void deleteUserActive_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException { @@ -2770,72 +2717,6 @@ public int removeExpiredCodes(TenantIdentifier tenantIdentifier, long expiredBef } } - - // MFA recipe: - @Override - public boolean enableFactor(TenantIdentifier tenantIdentifier, String userId, String factor) - throws StorageQueryException { - try { - int insertedCount = MfaQueries.enableFactor(this, tenantIdentifier, userId, factor); - if (insertedCount == 0) { - return false; - } - return true; - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public String[] listFactors(TenantIdentifier tenantIdentifier, String userId) - throws StorageQueryException { - try { - return MfaQueries.listFactors(this, tenantIdentifier, userId); - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public boolean disableFactor(TenantIdentifier tenantIdentifier, String userId, String factor) - throws StorageQueryException { - try { - int deletedCount = MfaQueries.disableFactor(this, tenantIdentifier, userId, factor); - if (deletedCount == 0) { - return false; - } - return true; - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public boolean deleteMfaInfoForUser_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException { - try { - int deletedCount = MfaQueries.deleteUser_Transaction(this, (Connection) con.getConnection(), appIdentifier, userId); - if (deletedCount == 0) { - return false; - } - return true; - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - - @Override - public boolean deleteMfaInfoForUser(TenantIdentifier tenantIdentifier, String userId) throws StorageQueryException { - try { - int deletedCount = MfaQueries.deleteUser(this, tenantIdentifier, userId); - if (deletedCount == 0) { - return false; - } - return true; - } catch (SQLException e) { - throw new StorageQueryException(e); - } - } - @Override public Set getValidFieldsInConfig() { return MySQLConfig.getValidFields(); @@ -3053,6 +2934,24 @@ public UserIdMapping[] getUserIdMapping_Transaction(TransactionConnection con, A } } + @Override + public int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(AppIdentifier appIdentifier) throws StorageQueryException { + try { + return GeneralQueries.getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(this, appIdentifier); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + + @Override + public int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(AppIdentifier appIdentifier, long sinceTime) throws StorageQueryException { + try { + return ActiveUsersQueries.countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(this, appIdentifier, sinceTime); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } + public static boolean isEnabledForDeadlockTesting() { return enableForDeadlockTesting; } diff --git a/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java b/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java index 2c294ef..e9277f1 100644 --- a/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java +++ b/src/main/java/io/supertokens/storage/mysql/config/MySQLConfig.java @@ -176,6 +176,14 @@ public String getTenantConfigsTable() { return addPrefixToTableName("tenant_configs"); } + public String getTenantFirstFactorsTable() { + return addPrefixToTableName("tenant_first_factors"); + } + + public String getTenantDefaultRequiredFactorIdsTable() { + return addPrefixToTableName("tenant_default_required_factor_ids"); + } + public String getTenantThirdPartyProvidersTable() { return addPrefixToTableName("tenant_thirdparty_providers"); } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/ActiveUsersQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/ActiveUsersQueries.java index 6922254..3d26a12 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/ActiveUsersQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/ActiveUsersQueries.java @@ -169,4 +169,41 @@ public static void deleteUserActive_Transaction(Connection con, Start start, App pst.setString(2, userId); }); } + + public static int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(Start start, AppIdentifier appIdentifier, long sinceTime) + throws SQLException, StorageQueryException { + // TODO: Active users are present only on public tenant and MFA users may be present on different storages + String QUERY = + "SELECT COUNT (DISTINCT user_id) as c FROM (" + + " (" // users with more than one login method + + " SELECT primary_or_recipe_user_id AS user_id FROM (" + + " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id" + + " FROM " + Config.getConfig(start).getAppIdToUserIdTable() + + " WHERE app_id = ? AND primary_or_recipe_user_id IN (" + + " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable() + + " WHERE app_id = ? AND last_active_time >= ?" + + " )" + + " GROUP BY (app_id, primary_or_recipe_user_id)" + + " ) AS nloginmethods" + + " WHERE num_login_methods > 1" + + " ) UNION (" // TOTP users + + " SELECT user_id FROM " + Config.getConfig(start).getTotpUsersTable() + + " WHERE app_id = ? AND user_id IN (" + + " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable() + + " WHERE app_id = ? AND last_active_time >= ?" + + " )" + + " )" + + ") AS all_users"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, appIdentifier.getAppId()); + pst.setLong(3, sinceTime); + pst.setString(4, appIdentifier.getAppId()); + pst.setString(5, appIdentifier.getAppId()); + pst.setLong(6, sinceTime); + }, result -> { + return result.next() ? result.getInt("c") : 0; + }); + } } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java index 0ca6737..075a263 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/GeneralQueries.java @@ -248,6 +248,21 @@ public static void createTablesIfNotExists(Start start) throws SQLException, Sto update(start, MultitenancyQueries.getQueryToCreateTenantConfigsTable(start), NO_OP_SETTER); } + if (!doesTableExists(start, Config.getConfig(start).getTenantFirstFactorsTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, MultitenancyQueries.getQueryToCreateFirstFactorsTable(start), NO_OP_SETTER); + } + + if (!doesTableExists(start, Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable())) { + getInstance(start).addState(CREATING_NEW_TABLE, null); + update(start, MultitenancyQueries.getQueryToCreateDefaultRequiredFactorIdsTable(start), NO_OP_SETTER); + + // index + update(start, + MultitenancyQueries.getQueryToCreateOrderIndexForDefaultRequiredFactorIdsTable(start), + NO_OP_SETTER); + } + if (!doesTableExists(start, Config.getConfig(start).getTenantThirdPartyProvidersTable())) { getInstance(start).addState(CREATING_NEW_TABLE, null); update(start, MultitenancyQueries.getQueryToCreateTenantThirdPartyProvidersTable(start), @@ -1551,6 +1566,32 @@ public static int getUsersCountWithMoreThanOneLoginMethod(Start start, AppIdenti }); } + public static int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(Start start, AppIdentifier appIdentifier) + throws SQLException, StorageQueryException { + String QUERY = + "SELECT COUNT (DISTINCT user_id) as c FROM (" + + " (" // Users with number of login methods > 1 + + " SELECT primary_or_recipe_user_id AS user_id FROM (" + + " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id" + + " FROM " + Config.getConfig(start).getAppIdToUserIdTable() + + " WHERE app_id = ? " + + " GROUP BY (app_id, primary_or_recipe_user_id)" + + " ) AS nloginmethods" + + " WHERE num_login_methods > 1" + + " ) UNION (" // TOTP users + + " SELECT user_id FROM " + Config.getConfig(start).getTotpUsersTable() + + " WHERE app_id = ?" + + " )" + + ") AS all_users"; + + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setString(2, appIdentifier.getAppId()); + }, result -> { + return result.next() ? result.getInt("c") : 0; + }); + } + public static boolean checkIfUsesAccountLinking(Start start, AppIdentifier appIdentifier) throws SQLException, StorageQueryException { String QUERY = "SELECT 1 FROM " diff --git a/src/main/java/io/supertokens/storage/mysql/queries/MultitenancyQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/MultitenancyQueries.java index edacff4..9231eca 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/MultitenancyQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/MultitenancyQueries.java @@ -27,6 +27,7 @@ import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.storage.mysql.Start; import io.supertokens.storage.mysql.config.Config; +import io.supertokens.storage.mysql.queries.multitenancy.MfaSqlHelper; import io.supertokens.storage.mysql.queries.multitenancy.TenantConfigSQLHelper; import io.supertokens.storage.mysql.queries.multitenancy.ThirdPartyProviderClientSQLHelper; import io.supertokens.storage.mysql.queries.multitenancy.ThirdPartyProviderSQLHelper; @@ -51,6 +52,9 @@ static String getQueryToCreateTenantConfigsTable(Start start) { + "email_password_enabled BOOLEAN," + "passwordless_enabled BOOLEAN," + "third_party_enabled BOOLEAN," + + "totp_enabled BOOLEAN," + + "has_first_factors BOOLEAN DEFAULT FALSE," + + "has_default_required_factor_ids BOOLEAN DEFAULT FALSE," + "PRIMARY KEY (connection_uri_domain, app_id, tenant_id)" + ");"; // @formatter:on @@ -107,6 +111,43 @@ static String getQueryToCreateTenantThirdPartyProviderClientsTable(Start start) + ");"; } + public static String getQueryToCreateFirstFactorsTable(Start start) { + String tableName = Config.getConfig(start).getTenantFirstFactorsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + tableName + " (" + + "connection_uri_domain VARCHAR(256) DEFAULT ''," + + "app_id VARCHAR(64) DEFAULT 'public'," + + "tenant_id VARCHAR(64) DEFAULT 'public'," + + "factor_id VARCHAR(128)," + + "PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id)," + + "FOREIGN KEY (connection_uri_domain, app_id, tenant_id)" + + " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateDefaultRequiredFactorIdsTable(Start start) { + String tableName = Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable(); + // @formatter:off + return "CREATE TABLE IF NOT EXISTS " + tableName + " (" + + "connection_uri_domain VARCHAR(256) DEFAULT ''," + + "app_id VARCHAR(64) DEFAULT 'public'," + + "tenant_id VARCHAR(64) DEFAULT 'public'," + + "factor_id VARCHAR(128)," + + "order_idx INTEGER NOT NULL," + + "PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id)," + + "FOREIGN KEY (connection_uri_domain, app_id, tenant_id)" + + " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE," + + " UNIQUE (connection_uri_domain, app_id, tenant_id, order_idx)" + + ");"; + // @formatter:on + } + + public static String getQueryToCreateOrderIndexForDefaultRequiredFactorIdsTable(Start start) { + return "CREATE INDEX IF NOT EXISTS tenant_default_required_factor_ids_tenant_id_index ON " + + getConfig(start).getTenantDefaultRequiredFactorIdsTable() + " (order_idx ASC);"; + } + private static void executeCreateTenantQueries(Start start, Connection sqlCon, TenantConfig tenantConfig) throws SQLException, StorageTransactionLogicException { @@ -221,7 +262,13 @@ public static TenantConfig[] getAllTenants(Start start) throws StorageQueryExcep // Map (tenantIdentifier) -> thirdPartyId -> provider HashMap> providerMap = ThirdPartyProviderSQLHelper.selectAll(start, providerClientsMap); - return TenantConfigSQLHelper.selectAll(start, providerMap); + // Map (tenantIdentifier) -> firstFactors + HashMap firstFactorsMap = MfaSqlHelper.selectAllFirstFactors(start); + + // Map (tenantIdentifier) -> defaultRequiredFactorIds + HashMap defaultRequiredFactorIdsMap = MfaSqlHelper.selectAllDefaultRequiredFactorIds(start); + + return TenantConfigSQLHelper.selectAll(start, providerMap, firstFactorsMap, defaultRequiredFactorIdsMap); } catch (SQLException throwables) { throw new StorageQueryException(throwables); } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/TOTPQueries.java b/src/main/java/io/supertokens/storage/mysql/queries/TOTPQueries.java index 443e0c4..b3536dd 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/TOTPQueries.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/TOTPQueries.java @@ -39,6 +39,7 @@ public static String getQueryToCreateUserDevicesTable(Start start) { + "period INTEGER NOT NULL," + "skew INTEGER NOT NULL," + "verified BOOLEAN NOT NULL," + + "created_at BIGINT UNSIGNED NOT NULL," + "PRIMARY KEY (app_id, user_id, device_name)," + "FOREIGN KEY (app_id, user_id)" + " REFERENCES " + Config.getConfig(start).getTotpUsersTable() + "(app_id, user_id) ON DELETE CASCADE" @@ -88,7 +89,7 @@ private static int insertUser_Transaction(Start start, Connection con, AppIdenti private static int insertDevice_Transaction(Start start, Connection con, AppIdentifier appIdentifier, TOTPDevice device) throws SQLException, StorageQueryException { String QUERY = "INSERT INTO " + Config.getConfig(start).getTotpUserDevicesTable() - + " (app_id, user_id, device_name, secret_key, period, skew, verified) VALUES (?, ?, ?, ?, ?, ?, ?)"; + + " (app_id, user_id, device_name, secret_key, period, skew, verified, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"; return update(con, QUERY, pst -> { pst.setString(1, appIdentifier.getAppId()); @@ -98,6 +99,7 @@ private static int insertDevice_Transaction(Start start, Connection con, AppIden pst.setInt(5, device.period); pst.setInt(6, device.skew); pst.setBoolean(7, device.verified); + pst.setLong(8, device.createdAt); }); } @@ -293,7 +295,8 @@ public TOTPDevice map(ResultSet result) throws SQLException { result.getString("secret_key"), result.getInt("period"), result.getInt("skew"), - result.getBoolean("verified")); + result.getBoolean("verified"), + result.getLong("created_at")); } } diff --git a/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/MfaSqlHelper.java b/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/MfaSqlHelper.java new file mode 100644 index 0000000..c101e09 --- /dev/null +++ b/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/MfaSqlHelper.java @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.supertokens.storage.mysql.queries.multitenancy; + +import io.supertokens.pluginInterface.exceptions.StorageQueryException; +import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import io.supertokens.storage.mysql.Start; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.*; + +import static io.supertokens.storage.mysql.QueryExecutorTemplate.execute; +import static io.supertokens.storage.mysql.QueryExecutorTemplate.update; +import static io.supertokens.storage.mysql.config.Config.getConfig; + +public class MfaSqlHelper { + public static HashMap selectAllFirstFactors(Start start) + throws SQLException, StorageQueryException { + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id FROM " + + getConfig(start).getTenantFirstFactorsTable() + ";"; + return execute(start, QUERY, pst -> {}, result -> { + HashMap> firstFactors = new HashMap<>(); + + while (result.next()) { + TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"), result.getString("app_id"), result.getString("tenant_id")); + if (!firstFactors.containsKey(tenantIdentifier)) { + firstFactors.put(tenantIdentifier, new ArrayList<>()); + } + + firstFactors.get(tenantIdentifier).add(result.getString("factor_id")); + } + + HashMap finalResult = new HashMap<>(); + for (TenantIdentifier tenantIdentifier : firstFactors.keySet()) { + finalResult.put(tenantIdentifier, firstFactors.get(tenantIdentifier).toArray(new String[0])); + } + + return finalResult; + }); + } + + public static HashMap selectAllDefaultRequiredFactorIds(Start start) + throws SQLException, StorageQueryException { + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id, order_idx FROM " + + getConfig(start).getTenantDefaultRequiredFactorIdsTable() + " ORDER BY order_idx ASC;"; + return execute(start, QUERY, pst -> {}, result -> { + HashMap> defaultRequiredFactors = new HashMap<>(); + + while (result.next()) { + TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"), + result.getString("app_id"), result.getString("tenant_id")); + if (!defaultRequiredFactors.containsKey(tenantIdentifier)) { + defaultRequiredFactors.put(tenantIdentifier, new ArrayList<>()); + } + + defaultRequiredFactors.get(tenantIdentifier).add(result.getString("factor_id")); + } + + HashMap finalResult = new HashMap<>(); + for (TenantIdentifier tenantIdentifier : defaultRequiredFactors.keySet()) { + finalResult.put(tenantIdentifier, defaultRequiredFactors.get(tenantIdentifier).toArray(new String[0])); + } + + return finalResult; + }); + } + + public static void createFirstFactors(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] firstFactors) + throws SQLException, StorageQueryException { + if (firstFactors == null || firstFactors.length == 0) { + return; + } + + String QUERY = "INSERT INTO " + getConfig(start).getTenantFirstFactorsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id) VALUES (?, ?, ?, ?);"; + for (String factorId : new HashSet<>(Arrays.asList(firstFactors))) { + update(sqlCon, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getConnectionUriDomain()); + pst.setString(2, tenantIdentifier.getAppId()); + pst.setString(3, tenantIdentifier.getTenantId()); + pst.setString(4, factorId); + }); + } + } + + public static void createDefaultRequiredFactorIds(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] defaultRequiredFactorIds) + throws SQLException, StorageQueryException { + if (defaultRequiredFactorIds == null || defaultRequiredFactorIds.length == 0) { + return; + } + + String QUERY = "INSERT INTO " + getConfig(start).getTenantDefaultRequiredFactorIdsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id, order_idx) VALUES (?, ?, ?, ?, ?);"; + int orderIdx = 0; + for (String factorId : defaultRequiredFactorIds) { + int finalOrderIdx = orderIdx; + update(sqlCon, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getConnectionUriDomain()); + pst.setString(2, tenantIdentifier.getAppId()); + pst.setString(3, tenantIdentifier.getTenantId()); + pst.setString(4, factorId); + pst.setInt(5, finalOrderIdx); + }); + orderIdx++; + } + } +} diff --git a/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/TenantConfigSQLHelper.java b/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/TenantConfigSQLHelper.java index 7c82f71..d10adb2 100644 --- a/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/TenantConfigSQLHelper.java +++ b/src/main/java/io/supertokens/storage/mysql/queries/multitenancy/TenantConfigSQLHelper.java @@ -37,13 +37,17 @@ public class TenantConfigSQLHelper { public static class TenantConfigRowMapper implements RowMapper { ThirdPartyConfig.Provider[] providers; + String[] firstFactors; + String[] defaultRequiredFactorIds; - private TenantConfigRowMapper(ThirdPartyConfig.Provider[] providers) { + private TenantConfigRowMapper(ThirdPartyConfig.Provider[] providers, String[] firstFactors, String[] defaultRequiredFactorIds) { this.providers = providers; + this.firstFactors = firstFactors; + this.defaultRequiredFactorIds = defaultRequiredFactorIds; } - public static TenantConfigRowMapper getInstance(ThirdPartyConfig.Provider[] providers) { - return new TenantConfigRowMapper(providers); + public static TenantConfigRowMapper getInstance(ThirdPartyConfig.Provider[] providers, String[] firstFactors, String[] defaultRequiredFactorIds) { + return new TenantConfigRowMapper(providers, firstFactors, defaultRequiredFactorIds); } @Override @@ -54,6 +58,9 @@ public TenantConfig map(ResultSet result) throws StorageQueryException { new EmailPasswordConfig(result.getBoolean("email_password_enabled")), new ThirdPartyConfig(result.getBoolean("third_party_enabled"), this.providers), new PasswordlessConfig(result.getBoolean("passwordless_enabled")), + new TotpConfig(result.getBoolean("totp_enabled")), + result.getBoolean("has_first_factors") ? firstFactors : null, + result.getBoolean("has_default_required_factor_ids") ? defaultRequiredFactorIds : null, JsonUtils.stringToJsonObject(result.getString("core_config")) ); } catch (Exception e) { @@ -62,9 +69,11 @@ public TenantConfig map(ResultSet result) throws StorageQueryException { } } - public static TenantConfig[] selectAll(Start start, HashMap> providerMap) + public static TenantConfig[] selectAll(Start start, HashMap> providerMap, HashMap firstFactorsMap, HashMap defaultRequiredFactorIdsMap) throws SQLException, StorageQueryException { - String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, core_config, email_password_enabled, passwordless_enabled, third_party_enabled FROM " + String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, core_config," + + " email_password_enabled, passwordless_enabled, third_party_enabled," + + " totp_enabled, has_first_factors, has_default_required_factor_ids FROM " + getConfig(start).getTenantConfigsTable() + ";"; TenantConfig[] tenantConfigs = execute(start, QUERY, pst -> {}, result -> { @@ -75,7 +84,11 @@ public static TenantConfig[] selectAll(Start start, HashMap { @@ -100,6 +116,9 @@ public static void create(Start start, Connection sqlCon, TenantConfig tenantCon pst.setBoolean(5, tenantConfig.emailPasswordConfig.enabled); pst.setBoolean(6, tenantConfig.passwordlessConfig.enabled); pst.setBoolean(7, tenantConfig.thirdPartyConfig.enabled); + pst.setBoolean(8, tenantConfig.totpConfig.enabled); + pst.setBoolean(9, tenantConfig.firstFactors != null); + pst.setBoolean(10, tenantConfig.defaultRequiredFactorIds != null); }); } catch (StorageQueryException e) { throw new StorageTransactionLogicException(e); diff --git a/src/test/java/io/supertokens/storage/mysql/test/DeadlockTest.java b/src/test/java/io/supertokens/storage/mysql/test/DeadlockTest.java index 4839358..db3633d 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/DeadlockTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/DeadlockTest.java @@ -280,7 +280,7 @@ public void testConcurrentDeleteAndInsert() throws Exception { // Create a device as well as a user: TOTPSQLStorage totpStorage = (TOTPSQLStorage) StorageLayer.getStorage(process.getProcess()); - TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); totpStorage.createDevice(TenantIdentifier.BASE_TENANT.toAppIdentifier(), device); long now = System.currentTimeMillis(); @@ -445,7 +445,7 @@ public void testConcurrentDeleteAndUpdate() throws Exception { // Create a device as well as a user: TOTPSQLStorage totpStorage = (TOTPSQLStorage) StorageLayer.getStorage(process.getProcess()); - TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice device = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); totpStorage.createDevice(TenantIdentifier.BASE_TENANT.toAppIdentifier(), device); long now = System.currentTimeMillis(); diff --git a/src/test/java/io/supertokens/storage/mysql/test/LoggingTest.java b/src/test/java/io/supertokens/storage/mysql/test/LoggingTest.java index 59a7372..294e982 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/LoggingTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/LoggingTest.java @@ -283,6 +283,7 @@ public void confirmHikariLoggerClosedOnlyWhenProcessEnds() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + new TotpConfig(false), null, null, config ), false); diff --git a/src/test/java/io/supertokens/storage/mysql/test/StorageLayerTest.java b/src/test/java/io/supertokens/storage/mysql/test/StorageLayerTest.java index e6414ea..02d6a74 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/StorageLayerTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/StorageLayerTest.java @@ -82,7 +82,7 @@ public void totpCodeLengthTest() throws Exception { long now = System.currentTimeMillis(); long nextDay = now + 1000 * 60 * 60 * 24; // 1 day from now - TOTPDevice d1 = new TOTPDevice("user", "d1", "secret", 30, 1, false); + TOTPDevice d1 = new TOTPDevice("user", "d1", "secret", 30, 1, false, System.currentTimeMillis()); storage.createDevice(TenantIdentifier.BASE_TENANT.toAppIdentifier(), d1); // Try code with length > 8 diff --git a/src/test/java/io/supertokens/storage/mysql/test/multitenancy/StorageLayerTest.java b/src/test/java/io/supertokens/storage/mysql/test/multitenancy/StorageLayerTest.java index 2ed0dbe..128e43c 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/multitenancy/StorageLayerTest.java +++ b/src/test/java/io/supertokens/storage/mysql/test/multitenancy/StorageLayerTest.java @@ -112,6 +112,7 @@ public void mergingTenantWithBaseConfigWorks() new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -156,6 +157,7 @@ public void storageInstanceIsReusedAcrossTenants() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -203,14 +205,17 @@ public void storageInstanceIsReusedAcrossTenantsComplex() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig), new TenantConfig(new TenantIdentifier(null, "abc", "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig1), new TenantConfig(new TenantIdentifier(null, null, "t2"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig1)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -275,6 +280,7 @@ public void mergingTenantWithBaseConfigWithInvalidConfigThrowsErrorWorks() new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -308,6 +314,7 @@ public void mergingTenantWithBaseConfigWithConflictingConfigsThrowsError() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -342,6 +349,7 @@ public void mergingDifferentConnectionPoolIdTenantWithBaseConfigWithConflictingC new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -377,6 +385,7 @@ public void mergingDifferentUserPoolIdTenantWithBaseConfigWithConflictingConfigs new TenantConfig(new TenantIdentifier("abc", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -421,6 +430,7 @@ public void newStorageIsNotCreatedWhenSameTenantIsAdded() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -468,6 +478,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[0] = new TenantConfig(new TenantIdentifier("c1", null, null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig); } @@ -479,6 +490,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[1] = new TenantConfig(new TenantIdentifier("c1", null, "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig); } @@ -488,6 +500,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[2] = new TenantConfig(new TenantIdentifier(null, null, "t2"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig); } @@ -497,6 +510,7 @@ public void testDifferentWaysToGetConfigBasedOnConnectionURIAndTenantId() tenants[3] = new TenantConfig(new TenantIdentifier(null, null, "t1"), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig); } @@ -558,6 +572,7 @@ public void multipleTenantsSameUserPoolAndConnectionPoolShouldWork() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -593,6 +608,7 @@ public void multipleTenantsSameUserPoolAndDifferentConnectionPoolShouldWork() new TenantConfig(new TenantIdentifier(null, "abc", null), new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfig)}; Config.loadAllTenantConfig(process.getProcess(), tenants); @@ -632,6 +648,7 @@ public void testCreating50StorageLayersUsage() new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, config); try { Multitenancy.addNewOrUpdateAppOrTenant(process.getProcess(), new TenantIdentifier(null, null, null), @@ -685,6 +702,7 @@ public void testCantCreateTenantWithUnknownDb() new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfigJson); try { @@ -723,6 +741,7 @@ public void testTenantCreationAndThenDbDownDbThrowsErrorInRecipesAndDoesntAffect new EmailPasswordConfig(true), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfigJson); StorageLayer.getMultitenancyStorage(process.getProcess()).createTenant(tenantConfig); @@ -794,6 +813,7 @@ public void testBadPortWithNewTenantShouldNotCauseItToWaitInput() throws Excepti new EmailPasswordConfig(false), new ThirdPartyConfig(false, new ThirdPartyConfig.Provider[0]), new PasswordlessConfig(false), + new TotpConfig(false), null, null, tenantConfigJson); try { diff --git a/src/test/java/io/supertokens/storage/mysql/test/multitenancy/TestUserPoolIdChangeBehaviour.java b/src/test/java/io/supertokens/storage/mysql/test/multitenancy/TestUserPoolIdChangeBehaviour.java index e012f70..162280f 100644 --- a/src/test/java/io/supertokens/storage/mysql/test/multitenancy/TestUserPoolIdChangeBehaviour.java +++ b/src/test/java/io/supertokens/storage/mysql/test/multitenancy/TestUserPoolIdChangeBehaviour.java @@ -82,6 +82,7 @@ public void testUsersWorkAfterUserPoolIdChanges() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + new TotpConfig(false), null, null, coreConfig ), false); @@ -99,6 +100,7 @@ public void testUsersWorkAfterUserPoolIdChanges() throws Exception { new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + new TotpConfig(false), null, null, coreConfig ), false); @@ -124,6 +126,7 @@ public void testUsersWorkAfterUserPoolIdChangesAndServerRestart() throws Excepti new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + new TotpConfig(false), null, null, coreConfig ), false); @@ -141,6 +144,7 @@ public void testUsersWorkAfterUserPoolIdChangesAndServerRestart() throws Excepti new EmailPasswordConfig(true), new ThirdPartyConfig(true, null), new PasswordlessConfig(true), + new TotpConfig(false), null, null, coreConfig ), false);