Skip to content

Commit

Permalink
fix: mfa
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Nov 6, 2023
1 parent 90d05c1 commit 6c1283d
Show file tree
Hide file tree
Showing 13 changed files with 333 additions and 134 deletions.
141 changes: 20 additions & 121 deletions src/main/java/io/supertokens/storage/mysql/Start.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@
import io.supertokens.pluginInterface.jwt.JWTSigningKeyInfo;
import io.supertokens.pluginInterface.jwt.exceptions.DuplicateKeyIdException;
import io.supertokens.pluginInterface.jwt.sqlstorage.JWTRecipeSQLStorage;
import io.supertokens.pluginInterface.mfa.MfaStorage;
import io.supertokens.pluginInterface.mfa.sqlStorage.MfaSQLStorage;
import io.supertokens.pluginInterface.multitenancy.*;
import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateClientTypeException;
import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateTenantException;
Expand Down Expand Up @@ -108,7 +106,7 @@ public class Start
implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage,
JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage,
UserIdMappingSQLStorage, MultitenancyStorage, MultitenancySQLStorage, DashboardSQLStorage, TOTPSQLStorage,
ActiveUsersStorage, ActiveUsersSQLStorage, MfaStorage, MfaSQLStorage, AuthRecipeSQLStorage {
ActiveUsersStorage, ActiveUsersSQLStorage, AuthRecipeSQLStorage {

// these configs are protected from being modified / viewed by the dev using the SuperTokens
// SaaS. If the core is not running in SuperTokens SaaS, this array has no effect.
Expand Down Expand Up @@ -728,13 +726,6 @@ public boolean isUserIdBeingUsedInNonAuthRecipe(AppIdentifier appIdentifier, Str
return false;
} else if (className.equals(ActiveUsersStorage.class.getName())) {
return ActiveUsersQueries.getLastActiveByUserId(this, appIdentifier, userId) != null;
} else if (className.equals(MfaStorage.class.getName())) {
try {
MultitenancyQueries.getAllTenants(this);
return MfaQueries.listFactors(this, appIdentifier, userId).length > 0;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
} else {
throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage");
}
Expand Down Expand Up @@ -805,7 +796,7 @@ public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifi
}
} else if (className.equals(TOTPStorage.class.getName())) {
try {
TOTPDevice device = new TOTPDevice(userId, "testDevice", "secret", 0, 30, false);
TOTPDevice device = new TOTPDevice(userId, "testDevice", "secret", 0, 30, false, System.currentTimeMillis());
this.startTransaction(con -> {
try {
long now = System.currentTimeMillis();
Expand All @@ -830,12 +821,6 @@ public void addInfoToNonAuthRecipesBasedOnUserId(TenantIdentifier tenantIdentifi
} catch (SQLException e) {
throw new StorageQueryException(e);
}
} else if (className.equals(MfaStorage.class.getName())) {
try {
MfaQueries.enableFactor(this, tenantIdentifier, userId, "emailpassword");
} catch (SQLException e) {
throw new StorageQueryException(e);
}
} else {
throw new IllegalStateException("ClassName: " + className + " is not part of NonAuthRecipeStorage");
}
Expand Down Expand Up @@ -1295,44 +1280,6 @@ public int countUsersActiveSince(AppIdentifier appIdentifier, long time) throws
}
}

@Override
public int countUsersEnabledTotp(AppIdentifier appIdentifier) throws StorageQueryException {
try {
return ActiveUsersQueries.countUsersEnabledTotp(this, appIdentifier);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public int countUsersEnabledTotpAndActiveSince(AppIdentifier appIdentifier, long time)
throws StorageQueryException {
try {
return ActiveUsersQueries.countUsersEnabledTotpAndActiveSince(this, appIdentifier, time);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public int countUsersEnabledMfa(AppIdentifier appIdentifier) throws StorageQueryException {
try {
return ActiveUsersQueries.countUsersEnabledMfa(this, appIdentifier);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public int countUsersEnabledMfaAndActiveSince(AppIdentifier appIdentifier, long time)
throws StorageQueryException {
try {
return ActiveUsersQueries.countUsersEnabledMfaAndActiveSince(this, appIdentifier, time);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public void deleteUserActive_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId)
throws StorageQueryException {
Expand Down Expand Up @@ -2770,72 +2717,6 @@ public int removeExpiredCodes(TenantIdentifier tenantIdentifier, long expiredBef
}
}


// MFA recipe:
@Override
public boolean enableFactor(TenantIdentifier tenantIdentifier, String userId, String factor)
throws StorageQueryException {
try {
int insertedCount = MfaQueries.enableFactor(this, tenantIdentifier, userId, factor);
if (insertedCount == 0) {
return false;
}
return true;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public String[] listFactors(TenantIdentifier tenantIdentifier, String userId)
throws StorageQueryException {
try {
return MfaQueries.listFactors(this, tenantIdentifier, userId);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public boolean disableFactor(TenantIdentifier tenantIdentifier, String userId, String factor)
throws StorageQueryException {
try {
int deletedCount = MfaQueries.disableFactor(this, tenantIdentifier, userId, factor);
if (deletedCount == 0) {
return false;
}
return true;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public boolean deleteMfaInfoForUser_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException {
try {
int deletedCount = MfaQueries.deleteUser_Transaction(this, (Connection) con.getConnection(), appIdentifier, userId);
if (deletedCount == 0) {
return false;
}
return true;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public boolean deleteMfaInfoForUser(TenantIdentifier tenantIdentifier, String userId) throws StorageQueryException {
try {
int deletedCount = MfaQueries.deleteUser(this, tenantIdentifier, userId);
if (deletedCount == 0) {
return false;
}
return true;
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public Set<String> getValidFieldsInConfig() {
return MySQLConfig.getValidFields();
Expand Down Expand Up @@ -3053,6 +2934,24 @@ public UserIdMapping[] getUserIdMapping_Transaction(TransactionConnection con, A
}
}

@Override
public int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(AppIdentifier appIdentifier) throws StorageQueryException {
try {
return GeneralQueries.getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(this, appIdentifier);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

@Override
public int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(AppIdentifier appIdentifier, long sinceTime) throws StorageQueryException {
try {
return ActiveUsersQueries.countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(this, appIdentifier, sinceTime);
} catch (SQLException e) {
throw new StorageQueryException(e);
}
}

public static boolean isEnabledForDeadlockTesting() {
return enableForDeadlockTesting;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ public String getTenantConfigsTable() {
return addPrefixToTableName("tenant_configs");
}

public String getTenantFirstFactorsTable() {
return addPrefixToTableName("tenant_first_factors");
}

public String getTenantDefaultRequiredFactorIdsTable() {
return addPrefixToTableName("tenant_default_required_factor_ids");
}

public String getTenantThirdPartyProvidersTable() {
return addPrefixToTableName("tenant_thirdparty_providers");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,41 @@ public static void deleteUserActive_Transaction(Connection con, Start start, App
pst.setString(2, userId);
});
}

public static int countUsersThatHaveMoreThanOneLoginMethodOrTOTPEnabledAndActiveSince(Start start, AppIdentifier appIdentifier, long sinceTime)
throws SQLException, StorageQueryException {
// TODO: Active users are present only on public tenant and MFA users may be present on different storages
String QUERY =
"SELECT COUNT (DISTINCT user_id) as c FROM ("
+ " (" // users with more than one login method
+ " SELECT primary_or_recipe_user_id AS user_id FROM ("
+ " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id"
+ " FROM " + Config.getConfig(start).getAppIdToUserIdTable()
+ " WHERE app_id = ? AND primary_or_recipe_user_id IN ("
+ " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable()
+ " WHERE app_id = ? AND last_active_time >= ?"
+ " )"
+ " GROUP BY (app_id, primary_or_recipe_user_id)"
+ " ) AS nloginmethods"
+ " WHERE num_login_methods > 1"
+ " ) UNION (" // TOTP users
+ " SELECT user_id FROM " + Config.getConfig(start).getTotpUsersTable()
+ " WHERE app_id = ? AND user_id IN ("
+ " SELECT user_id FROM " + Config.getConfig(start).getUserLastActiveTable()
+ " WHERE app_id = ? AND last_active_time >= ?"
+ " )"
+ " )"
+ ") AS all_users";

return execute(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, appIdentifier.getAppId());
pst.setLong(3, sinceTime);
pst.setString(4, appIdentifier.getAppId());
pst.setString(5, appIdentifier.getAppId());
pst.setLong(6, sinceTime);
}, result -> {
return result.next() ? result.getInt("c") : 0;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,21 @@ public static void createTablesIfNotExists(Start start) throws SQLException, Sto
update(start, MultitenancyQueries.getQueryToCreateTenantConfigsTable(start), NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getTenantFirstFactorsTable())) {
getInstance(start).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateFirstFactorsTable(start), NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable())) {
getInstance(start).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateDefaultRequiredFactorIdsTable(start), NO_OP_SETTER);

// index
update(start,
MultitenancyQueries.getQueryToCreateOrderIndexForDefaultRequiredFactorIdsTable(start),
NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getTenantThirdPartyProvidersTable())) {
getInstance(start).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateTenantThirdPartyProvidersTable(start),
Expand Down Expand Up @@ -1551,6 +1566,32 @@ public static int getUsersCountWithMoreThanOneLoginMethod(Start start, AppIdenti
});
}

public static int getUsersCountWithMoreThanOneLoginMethodOrTOTPEnabled(Start start, AppIdentifier appIdentifier)
throws SQLException, StorageQueryException {
String QUERY =
"SELECT COUNT (DISTINCT user_id) as c FROM ("
+ " (" // Users with number of login methods > 1
+ " SELECT primary_or_recipe_user_id AS user_id FROM ("
+ " SELECT COUNT(user_id) as num_login_methods, app_id, primary_or_recipe_user_id"
+ " FROM " + Config.getConfig(start).getAppIdToUserIdTable()
+ " WHERE app_id = ? "
+ " GROUP BY (app_id, primary_or_recipe_user_id)"
+ " ) AS nloginmethods"
+ " WHERE num_login_methods > 1"
+ " ) UNION (" // TOTP users
+ " SELECT user_id FROM " + Config.getConfig(start).getTotpUsersTable()
+ " WHERE app_id = ?"
+ " )"
+ ") AS all_users";

return execute(start, QUERY, pst -> {
pst.setString(1, appIdentifier.getAppId());
pst.setString(2, appIdentifier.getAppId());
}, result -> {
return result.next() ? result.getInt("c") : 0;
});
}

public static boolean checkIfUsesAccountLinking(Start start, AppIdentifier appIdentifier)
throws SQLException, StorageQueryException {
String QUERY = "SELECT 1 FROM "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException;
import io.supertokens.storage.mysql.Start;
import io.supertokens.storage.mysql.config.Config;
import io.supertokens.storage.mysql.queries.multitenancy.MfaSqlHelper;
import io.supertokens.storage.mysql.queries.multitenancy.TenantConfigSQLHelper;
import io.supertokens.storage.mysql.queries.multitenancy.ThirdPartyProviderClientSQLHelper;
import io.supertokens.storage.mysql.queries.multitenancy.ThirdPartyProviderSQLHelper;
Expand All @@ -51,6 +52,9 @@ static String getQueryToCreateTenantConfigsTable(Start start) {
+ "email_password_enabled BOOLEAN,"
+ "passwordless_enabled BOOLEAN,"
+ "third_party_enabled BOOLEAN,"
+ "totp_enabled BOOLEAN,"
+ "has_first_factors BOOLEAN DEFAULT FALSE,"
+ "has_default_required_factor_ids BOOLEAN DEFAULT FALSE,"
+ "PRIMARY KEY (connection_uri_domain, app_id, tenant_id)"
+ ");";
// @formatter:on
Expand Down Expand Up @@ -107,6 +111,43 @@ static String getQueryToCreateTenantThirdPartyProviderClientsTable(Start start)
+ ");";
}

public static String getQueryToCreateFirstFactorsTable(Start start) {
String tableName = Config.getConfig(start).getTenantFirstFactorsTable();
// @formatter:off
return "CREATE TABLE IF NOT EXISTS " + tableName + " ("
+ "connection_uri_domain VARCHAR(256) DEFAULT '',"
+ "app_id VARCHAR(64) DEFAULT 'public',"
+ "tenant_id VARCHAR(64) DEFAULT 'public',"
+ "factor_id VARCHAR(128),"
+ "PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id),"
+ "FOREIGN KEY (connection_uri_domain, app_id, tenant_id)"
+ " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE"
+ ");";
// @formatter:on
}

public static String getQueryToCreateDefaultRequiredFactorIdsTable(Start start) {
String tableName = Config.getConfig(start).getTenantDefaultRequiredFactorIdsTable();
// @formatter:off
return "CREATE TABLE IF NOT EXISTS " + tableName + " ("
+ "connection_uri_domain VARCHAR(256) DEFAULT '',"
+ "app_id VARCHAR(64) DEFAULT 'public',"
+ "tenant_id VARCHAR(64) DEFAULT 'public',"
+ "factor_id VARCHAR(128),"
+ "order_idx INTEGER NOT NULL,"
+ "PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id),"
+ "FOREIGN KEY (connection_uri_domain, app_id, tenant_id)"
+ " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE,"
+ " UNIQUE (connection_uri_domain, app_id, tenant_id, order_idx)"
+ ");";
// @formatter:on
}

public static String getQueryToCreateOrderIndexForDefaultRequiredFactorIdsTable(Start start) {
return "CREATE INDEX IF NOT EXISTS tenant_default_required_factor_ids_tenant_id_index ON "
+ getConfig(start).getTenantDefaultRequiredFactorIdsTable() + " (order_idx ASC);";
}

private static void executeCreateTenantQueries(Start start, Connection sqlCon, TenantConfig tenantConfig)
throws SQLException, StorageTransactionLogicException {

Expand Down Expand Up @@ -221,7 +262,13 @@ public static TenantConfig[] getAllTenants(Start start) throws StorageQueryExcep
// Map (tenantIdentifier) -> thirdPartyId -> provider
HashMap<TenantIdentifier, HashMap<String, ThirdPartyConfig.Provider>> providerMap = ThirdPartyProviderSQLHelper.selectAll(start, providerClientsMap);

return TenantConfigSQLHelper.selectAll(start, providerMap);
// Map (tenantIdentifier) -> firstFactors
HashMap<TenantIdentifier, String[]> firstFactorsMap = MfaSqlHelper.selectAllFirstFactors(start);

// Map (tenantIdentifier) -> defaultRequiredFactorIds
HashMap<TenantIdentifier, String[]> defaultRequiredFactorIdsMap = MfaSqlHelper.selectAllDefaultRequiredFactorIds(start);

return TenantConfigSQLHelper.selectAll(start, providerMap, firstFactorsMap, defaultRequiredFactorIdsMap);
} catch (SQLException throwables) {
throw new StorageQueryException(throwables);
}
Expand Down
Loading

0 comments on commit 6c1283d

Please sign in to comment.