Skip to content

Commit

Permalink
fix: pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Oct 25, 2023
1 parent d632005 commit 9942400
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ public String getTenantConfigsTable() {
return addSchemaAndPrefixToTableName("tenant_configs");
}

public String getFirstFactorsTable() {
return addSchemaAndPrefixToTableName("first_factors");
}

public String getDefaultRequiredFactorIdsTable() {
return addSchemaAndPrefixToTableName("default_required_factor_ids");
}

public String getTenantThirdPartyProvidersTable() {
return addSchemaAndPrefixToTableName("tenant_thirdparty_providers");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,25 @@ public static void createTablesIfNotExists(Start start) throws SQLException, Sto
NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getFirstFactorsTable())) {
getInstance(start).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateFirstFactorsTable(start), NO_OP_SETTER);

// index
update(start, MultitenancyQueries.getQueryToCreateTenantIdIndexForFirstFactorsTable(start),
NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getDefaultRequiredFactorIdsTable())) {
getInstance(start).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateDefaultRequiredFactorIdsTable(start), NO_OP_SETTER);

// index
update(start,
MultitenancyQueries.getQueryToCreateTenantIdIndexForDefaultRequiredFactorIdsTable(start),
NO_OP_SETTER);
}

if (!doesTableExists(start, Config.getConfig(start).getTenantThirdPartyProviderClientsTable())) {
getInstance(start).addState(CREATING_NEW_TABLE, null);
update(start, MultitenancyQueries.getQueryToCreateTenantThirdPartyProviderClientsTable(start),
Expand Down Expand Up @@ -563,6 +582,8 @@ public static void deleteAllTables(Start start) throws SQLException, StorageQuer
+ getConfig(start).getUserIdMappingTable() + ","
+ getConfig(start).getUsersTable() + ","
+ getConfig(start).getAccessTokenSigningKeysTable() + ","
+ getConfig(start).getFirstFactorsTable() + ","
+ getConfig(start).getDefaultRequiredFactorIdsTable() + ","
+ getConfig(start).getTenantConfigsTable() + ","
+ getConfig(start).getTenantThirdPartyProvidersTable() + ","
+ getConfig(start).getTenantThirdPartyProviderClientsTable() + ","
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

package io.supertokens.storage.postgresql.queries;

import io.supertokens.pluginInterface.emailpassword.exceptions.UnknownUserIdException;
import io.supertokens.pluginInterface.exceptions.StorageQueryException;
import io.supertokens.pluginInterface.exceptions.StorageTransactionLogicException;
import io.supertokens.pluginInterface.multitenancy.*;
import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException;
import io.supertokens.storage.postgresql.Start;
import io.supertokens.storage.postgresql.config.Config;
import io.supertokens.storage.postgresql.queries.multitenancy.MfaSqlHelper;
import io.supertokens.storage.postgresql.queries.multitenancy.TenantConfigSQLHelper;
import io.supertokens.storage.postgresql.queries.multitenancy.ThirdPartyProviderClientSQLHelper;
import io.supertokens.storage.postgresql.queries.multitenancy.ThirdPartyProviderSQLHelper;
Expand Down Expand Up @@ -50,8 +50,8 @@ static String getQueryToCreateTenantConfigsTable(Start start) {
+ "passwordless_enabled BOOLEAN,"
+ "third_party_enabled BOOLEAN,"
+ "totp_enabled BOOLEAN,"
+ "first_factors TEXT DEFAULT 'null',"
+ "default_required_factors TEXT DEFAULT 'null',"
+ "has_first_factors BOOLEAN DEFAULT FALSE,"
+ "has_default_required_factor_ids BOOLEAN DEFAULT FALSE,"
+ "CONSTRAINT " + Utils.getConstraintName(schema, tenantConfigsTable, null, "pkey") + " PRIMARY KEY (connection_uri_domain, app_id, tenant_id)"
+ ");";
// @formatter:on
Expand Down Expand Up @@ -122,6 +122,52 @@ public static String getQueryToCreateThirdPartyIdIndexForTenantThirdPartyProvide
+ getConfig(start).getTenantThirdPartyProviderClientsTable() + " (connection_uri_domain, app_id, tenant_id, third_party_id);";
}

public static String getQueryToCreateFirstFactorsTable(Start start) {
String schema = Config.getConfig(start).getTableSchema();
String tableName = Config.getConfig(start).getFirstFactorsTable();
// @formatter:off
return "CREATE TABLE IF NOT EXISTS " + tableName + " ("
+ "connection_uri_domain VARCHAR(256) DEFAULT '',"
+ "app_id VARCHAR(64) DEFAULT 'public',"
+ "tenant_id VARCHAR(64) DEFAULT 'public',"
+ "factor_id VARCHAR(64),"
+ "CONSTRAINT " + Utils.getConstraintName(schema, tableName, null, "pkey")
+ " PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id),"
+ "CONSTRAINT " + Utils.getConstraintName(schema, tableName, "tenant_id", "fkey")
+ " FOREIGN KEY (connection_uri_domain, app_id, tenant_id)"
+ " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE"
+ ");";
// @formatter:on
}

public static String getQueryToCreateTenantIdIndexForFirstFactorsTable(Start start) {
return "CREATE INDEX IF NOT EXISTS tenant_first_factors_tenant_id_index ON "
+ getConfig(start).getFirstFactorsTable() + " (connection_uri_domain, app_id, tenant_id);";
}

public static String getQueryToCreateDefaultRequiredFactorIdsTable(Start start) {
String schema = Config.getConfig(start).getTableSchema();
String tableName = Config.getConfig(start).getDefaultRequiredFactorIdsTable();
// @formatter:off
return "CREATE TABLE IF NOT EXISTS " + tableName + " ("
+ "connection_uri_domain VARCHAR(256) DEFAULT '',"
+ "app_id VARCHAR(64) DEFAULT 'public',"
+ "tenant_id VARCHAR(64) DEFAULT 'public',"
+ "factor_id VARCHAR(64),"
+ "CONSTRAINT " + Utils.getConstraintName(schema, tableName, null, "pkey")
+ " PRIMARY KEY (connection_uri_domain, app_id, tenant_id, factor_id),"
+ "CONSTRAINT " + Utils.getConstraintName(schema, tableName, "tenant_id", "fkey")
+ " FOREIGN KEY (connection_uri_domain, app_id, tenant_id)"
+ " REFERENCES " + Config.getConfig(start).getTenantConfigsTable() + " (connection_uri_domain, app_id, tenant_id) ON DELETE CASCADE"
+ ");";
// @formatter:on
}

public static String getQueryToCreateTenantIdIndexForDefaultRequiredFactorIdsTable(Start start) {
return "CREATE INDEX IF NOT EXISTS tenant_default_required_factor_ids_tenant_id_index ON "
+ getConfig(start).getDefaultRequiredFactorIdsTable() + " (connection_uri_domain, app_id, tenant_id);";
}

private static void executeCreateTenantQueries(Start start, Connection sqlCon, TenantConfig tenantConfig)
throws SQLException, StorageQueryException {

Expand All @@ -134,6 +180,9 @@ private static void executeCreateTenantQueries(Start start, Connection sqlCon, T
ThirdPartyProviderClientSQLHelper.create(start, sqlCon, tenantConfig, provider, providerClient);
}
}

MfaSqlHelper.createFirstFactors(start, sqlCon, tenantConfig.tenantIdentifier, tenantConfig.firstFactors);
MfaSqlHelper.createDefaultRequiredFactorIds(start, sqlCon, tenantConfig.tenantIdentifier, tenantConfig.defaultRequiredFactorIds);
}

public static void createTenantConfig(Start start, TenantConfig tenantConfig) throws StorageQueryException, StorageTransactionLogicException {
Expand Down Expand Up @@ -212,7 +261,13 @@ public static TenantConfig[] getAllTenants(Start start) throws StorageQueryExcep
// Map (tenantIdentifier) -> thirdPartyId -> provider
HashMap<TenantIdentifier, HashMap<String, ThirdPartyConfig.Provider>> providerMap = ThirdPartyProviderSQLHelper.selectAll(start, providerClientsMap);

return TenantConfigSQLHelper.selectAll(start, providerMap);
// Map (tenantIdentifier) -> firstFactors
HashMap<TenantIdentifier, String[]> firstFactorsMap = MfaSqlHelper.selectAllFirstFactors(start);

// Map (tenantIdentifier) -> defaultRequiredFactorIds
HashMap<TenantIdentifier, String[]> defaultRequiredFactorIdsMap = MfaSqlHelper.selectAllDefaultRequiredFactorIds(start);

return TenantConfigSQLHelper.selectAll(start, providerMap, firstFactorsMap, defaultRequiredFactorIdsMap);
} catch (SQLException throwables) {
throw new StorageQueryException(throwables);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved.
*
* This software is licensed under the Apache License, Version 2.0 (the
* "License") as published by the Apache Software Foundation.
*
* You may not use this file except in compliance with the License. You may
* obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package io.supertokens.storage.postgresql.queries.multitenancy;

import io.supertokens.pluginInterface.exceptions.StorageQueryException;
import io.supertokens.pluginInterface.multitenancy.TenantConfig;
import io.supertokens.pluginInterface.multitenancy.TenantIdentifier;
import io.supertokens.pluginInterface.multitenancy.ThirdPartyConfig;
import io.supertokens.storage.postgresql.Start;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute;
import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update;
import static io.supertokens.storage.postgresql.config.Config.getConfig;

public class MfaSqlHelper {
public static HashMap<TenantIdentifier, String[]> selectAllFirstFactors(Start start)
throws SQLException, StorageQueryException {
String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id FROM "
+ getConfig(start).getFirstFactorsTable() + ";";
return execute(start, QUERY, pst -> {}, result -> {
HashMap<TenantIdentifier, List<String>> firstFactors = new HashMap<>();

while (result.next()) {
TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"), result.getString("app_id"), result.getString("tenant_id"));
if (!firstFactors.containsKey(tenantIdentifier)) {
firstFactors.put(tenantIdentifier, new ArrayList<>());
}

firstFactors.get(tenantIdentifier).add(result.getString("factor_id"));
}

HashMap<TenantIdentifier, String[]> finalResult = new HashMap<>();
for (TenantIdentifier tenantIdentifier : firstFactors.keySet()) {
finalResult.put(tenantIdentifier, firstFactors.get(tenantIdentifier).toArray(new String[0]));
}

return finalResult;
});
}

public static HashMap<TenantIdentifier, String[]> selectAllDefaultRequiredFactorIds(Start start)
throws SQLException, StorageQueryException {
String QUERY = "SELECT connection_uri_domain, app_id, tenant_id, factor_id FROM "
+ getConfig(start).getDefaultRequiredFactorIdsTable() + ";";
return execute(start, QUERY, pst -> {}, result -> {
HashMap<TenantIdentifier, List<String>> defaultRequiredFactors = new HashMap<>();

while (result.next()) {
TenantIdentifier tenantIdentifier = new TenantIdentifier(result.getString("connection_uri_domain"),
result.getString("app_id"), result.getString("tenant_id"));
if (!defaultRequiredFactors.containsKey(tenantIdentifier)) {
defaultRequiredFactors.put(tenantIdentifier, new ArrayList<>());
}

defaultRequiredFactors.get(tenantIdentifier).add(result.getString("factor_id"));
}

HashMap<TenantIdentifier, String[]> finalResult = new HashMap<>();
for (TenantIdentifier tenantIdentifier : defaultRequiredFactors.keySet()) {
finalResult.put(tenantIdentifier, defaultRequiredFactors.get(tenantIdentifier).toArray(new String[0]));
}

return finalResult;
});
}

public static void createFirstFactors(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] firstFactors)
throws SQLException, StorageQueryException {
if (firstFactors == null || firstFactors.length == 0) {
return;
}

String QUERY = "INSERT INTO " + getConfig(start).getFirstFactorsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id) VALUES (?, ?, ?, ?);";
for (String factorId : firstFactors) {
update(sqlCon, QUERY, pst -> {
pst.setString(1, tenantIdentifier.getConnectionUriDomain());
pst.setString(2, tenantIdentifier.getAppId());
pst.setString(3, tenantIdentifier.getTenantId());
pst.setString(4, factorId);
});
}
}

public static void createDefaultRequiredFactorIds(Start start, Connection sqlCon, TenantIdentifier tenantIdentifier, String[] defaultRequiredFactorIds)
throws SQLException, StorageQueryException {
if (defaultRequiredFactorIds == null || defaultRequiredFactorIds.length == 0) {
return;
}

String QUERY = "INSERT INTO " + getConfig(start).getDefaultRequiredFactorIdsTable() + "(connection_uri_domain, app_id, tenant_id, factor_id) VALUES (?, ?, ?, ?);";
for (String factorId : defaultRequiredFactorIds) {
update(sqlCon, QUERY, pst -> {
pst.setString(1, tenantIdentifier.getConnectionUriDomain());
pst.setString(2, tenantIdentifier.getAppId());
pst.setString(3, tenantIdentifier.getTenantId());
pst.setString(4, factorId);
});
}
}
}
Loading

0 comments on commit 9942400

Please sign in to comment.