diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 2917bd58..08549481 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -11,6 +11,8 @@ jobs:
steps:
- uses: actions/checkout@v2
+ with:
+ submodules: true
- name: Setup .NET
uses: actions/setup-dotnet@v1
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..88f37299
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "tests/EdgeDB.Tests.Unit/shared-client-testcases"]
+ path = tests/EdgeDB.Tests.Unit/shared-client-testcases
+ url = https://github.com/edgedb/shared-client-testcases.git
diff --git a/src/EdgeDB.Net.Driver/EdgeDBConnection.cs b/src/EdgeDB.Net.Driver/EdgeDBConnection.cs
index b576cf2b..fd70e2df 100644
--- a/src/EdgeDB.Net.Driver/EdgeDBConnection.cs
+++ b/src/EdgeDB.Net.Driver/EdgeDBConnection.cs
@@ -10,6 +10,40 @@
namespace EdgeDB;
+///
+/// A json readable representation of an EdgeDBConnection.
+/// When using Credentials to create an EdgeDBConnection, the data must conform to this type.
+///
+internal class ConnectionCredentials
+{
+ [JsonProperty("host")]
+ public string? Host { get; init; }
+
+ [JsonProperty("port")]
+ [JsonConverter(typeof(AsStringConverter))]
+ public string? Port { get; init; }
+
+ [JsonProperty("database")]
+ public string? Database { get; init; }
+
+ [JsonProperty("branch")]
+ public string? Branch { get; init; }
+
+ [JsonProperty("user")]
+ public string? User { get; init; }
+
+ [JsonProperty("password")]
+ public string? Password { get; init; }
+
+ [JsonProperty("tls_ca")]
+ public string? TlsCA { get; init; }
+
+ [JsonProperty("tls_security")]
+ [JsonConverter(typeof(TLSSecurityModeParser))]
+ public TLSSecurityMode? TlsSecurity { get; init; }
+}
+
+
///
/// Represents a class containing information on how to connect to a edgedb instance.
///
@@ -18,16 +52,19 @@ public sealed class EdgeDBConnection
private const string INSTANCE_ENV_NAME = "INSTANCE";
private const string DSN_ENV_NAME = "DSN";
private const string CREDENTIALS_FILE_ENV_NAME = "CREDENTIALS_FILE";
- private const string USER_ENV_NAME = "USER";
- private const string PASSWORD_ENV_NAME = "PASSWORD";
- private const string DATABASE_ENV_NAME = "DATABASE";
- private const string BRANCH_ENV_NAME = "BRANCH";
private const string HOST_ENV_NAME = "HOST";
private const string PORT_ENV_NAME = "PORT";
+ private const string DATABASE_ENV_NAME = "DATABASE";
+ private const string BRANCH_ENV_NAME = "BRANCH";
+ private const string USER_ENV_NAME = "USER";
+ private const string PASSWORD_ENV_NAME = "PASSWORD";
+ private const string SECRET_KEY_ENV_NAME = "SECRET_KEY";
+ private const string TLS_CA_ENV_NAME = "TLS_CA";
private const string CLIENT_SECURITY_ENV_NAME = "CLIENT_SECURITY";
private const string CLIENT_TLS_SECURITY_ENV_NAME = "CLIENT_TLS_SECURITY";
+ private const string TLS_SERVER_NAME_ENV_NAME = "TLS_SERVER_NAME";
+ private const string WAIT_UNTIL_AVAILABLE_ENV_NAME = "WAIT_UNTIL_AVAILABLE";
private const string CLOUD_PROFILE_ENV_NAME = "CLOUD_PROFILE";
- private const string SECRET_KEY_ENV_NAME = "SECRET_KEY";
private const int DOMAIN_NAME_MAX_LEN = 62;
private EdgeDBConnection MergeInto(EdgeDBConnection other)
@@ -295,31 +332,544 @@ public string CloudProfile
#endregion
- #region Construct methods
+ #region Create Function
+
+ ///
+ /// Optional args which can be passed into .
+ ///
+ public class Options
+ {
+ // Primary args
+ // These can set host/port of the connection.
+ public string? Instance { get; set; }
+ public string? Dsn { get; set; }
+ public string? Host { get; set; }
+ public int? Port { get; set; }
+
+ // Secondary args
+ public string? Database { get; set; }
+ public string? Branch { get; set; }
+ public string? User { get; set; }
+ public string? Password { get; set; }
+ public string? SecretKey { get; set; }
+ public string? Credentials { get; set; }
+ public string? CredentialsFile { get; set; }
+ public string? TLSCertificateAuthority { get; set; }
+ public string? TLSCertificateAuthorityFile { get; set; }
+ public TLSSecurityMode? TLSSecurity { get; set; }
+ public string? TLSServerName { get; set; }
+ public string? WaitUntilAvailable { get; set; }
+ public Dictionary? ServerSettings { get; set; }
+
+ public bool IsEmpty =>
+ Instance is null
+ && Dsn is null
+ && Host is null
+ && Port is null
+ && Database is null
+ && Branch is null
+ && User is null
+ && Password is null
+ && SecretKey is null
+ && Credentials is null
+ && CredentialsFile is null
+ && TLSCertificateAuthority is null
+ && TLSCertificateAuthorityFile is null
+ && TLSSecurity is null
+ && TLSServerName is null
+ && WaitUntilAvailable is null
+ && ServerSettings is null;
+ }
+
+ ///
+ /// Parses the `gel.toml`, optional , and environment variables to build an
+ /// .
+ ///
+ /// This function will first search for the first valid primary args (which can set host/port)
+ /// in the following order:
+ /// -
+ /// - Environment variables
+ /// - `gel.toml` file
+ ///
+ /// It will then apply any secondary args from the environment variables and options.
+ ///
+ /// If any primary are present, then all environment variables are ignored.
+ ///
+ /// See the documentation
+ /// for more information.
+ ///
+ /// Options used to build the .
+ ///
+ /// A class that can be used to connect to a EdgeDB instance.
+ ///
+ ///
+ /// An error occured while parsing or configuring the .
+ ///
+ public static EdgeDBConnection Create(Options? options = null)
+ {
+ return _Create(options ?? new(), null);
+ }
+
+ internal static EdgeDBConnection _Create(Options options, ISystemProvider? platform)
+ {
+ platform ??= ConfigUtils.DefaultPlatformProvider;
+
+ ConfigUtils.ResolvedFields resolvedFields = new();
+
+ #region Primary Options
+
+ // First, check primary options
+ // If any primary options are present, environment variables are ignored
+
+ bool hasPrimaryOptions = false;
+ {
+ // These options can set host/port and should be resolved first
+ // More than one primary options should raise an error
+
+ Exception primaryError = new ConfigurationException(
+ "Connection options cannot have more than one of the following "
+ + "values: \"Instance\", \"Dsn\", \"Credentials\", "
+ + "\"CredentialsFile\" or \"Host\"/\"Port\"");
+ // The primaryError has priority, so hold on to any other exception
+ // until all primary options are processed.
+ Exception? deferredPrimaryError = null;
+
+ if (options.Instance is not null)
+ {
+ if (hasPrimaryOptions) { throw primaryError; }
+ if (options.Instance != "")
+ {
+ try
+ {
+ var fromDSN = _FromInstanceName(options.Instance, null, platform);
+ resolvedFields.MergeFrom(fromDSN);
+ }
+ catch (Exception e)
+ {
+ deferredPrimaryError = e;
+ }
+ }
+ else
+ {
+ deferredPrimaryError = new ConfigurationException(
+ $"Invalid instance name: \"{options.Instance}\"");
+ }
+ hasPrimaryOptions = true;
+ }
+
+ if (options.Dsn is not null)
+ {
+ if (hasPrimaryOptions) { throw primaryError; }
+ var fromDSN = _FromDSN(options.Dsn, platform);
+ resolvedFields.MergeFrom(fromDSN);
+ hasPrimaryOptions = true;
+ }
+
+ {
+ string? credentialsText = null;
+ if (options.Credentials is not null)
+ {
+ if (hasPrimaryOptions) { throw primaryError; }
+ credentialsText = options.Credentials;
+ hasPrimaryOptions = true;
+ }
+ if (options.CredentialsFile is not null)
+ {
+ if (hasPrimaryOptions) { throw primaryError; }
+ if (platform.FileExists(options.CredentialsFile))
+ {
+ credentialsText = platform.FileReadAllText(options.CredentialsFile) ?? "{}";
+ }
+ else
+ {
+ deferredPrimaryError = new ConfigurationException(
+ $"Invalid CredentialsFile: \"{options.CredentialsFile}\", could not find file");
+ }
+ hasPrimaryOptions = true;
+ }
+ if (credentialsText is not null)
+ {
+ try
+ {
+ ConnectionCredentials? credentials =
+ new JsonSerializer().DeserializeObject(credentialsText);
+ if (credentials is not null)
+ {
+ resolvedFields.MergeFrom(ConfigUtils.ResolvedFields.FromCredentials(credentials));
+ }
+ }
+ catch (JsonException)
+ {
+ deferredPrimaryError = new ConfigurationException("Invalid Credentials: could not parse json");
+ }
+ }
+ }
+
+ {
+ bool hasHostOrPort = false;
+ if (options.Host is not null)
+ {
+ if (hasPrimaryOptions) { throw primaryError; }
+ resolvedFields.Host = options.Host;
+ hasHostOrPort = true;
+ }
+ if (options.Port is not null)
+ {
+ if (hasPrimaryOptions) { throw primaryError; }
+ resolvedFields.Port = options.Port;
+ hasHostOrPort = true;
+ }
+ if (hasHostOrPort)
+ {
+ hasPrimaryOptions = true;
+ }
+ }
+
+ if (deferredPrimaryError is not null)
+ {
+ throw deferredPrimaryError;
+ }
+ }
+
+ #endregion
+
+ #region Primary Env
+
+ var envName = string.Empty;
+ var envVar = string.Empty;
+
+ bool hasPrimaryEnv = false;
+
+ if (!hasPrimaryOptions)
+ {
+ // These env vars can set host/port and should be resolved first
+ // More than one primary env var should raise an error
+
+ Exception primaryError = new ConfigurationException(
+ "Cannot have more than one of the following connection "
+ + "environment variables: \"GEL_DSN\", \"GEL_INSTANCE\", "
+ + "\"GEL_CREDENTIALS_FILE\" or \"GEL_HOST\"/\"GEL_PORT\"");
+ // The primaryError has priority, so hold on to any other exception
+ // until all primary env vars are processed.
+ Exception? deferredPrimaryError = null;
+
+ if (platform.GetGelEnvVariable(INSTANCE_ENV_NAME, out envName, out envVar))
+ {
+ if (hasPrimaryEnv) { throw primaryError; }
+ try
+ {
+ var fromInst = _FromInstanceName(envVar, null, platform);
+ resolvedFields.MergeFrom(fromInst);
+ }
+ catch (Exception e)
+ {
+ deferredPrimaryError = e;
+ }
+ hasPrimaryEnv = true;
+ }
+
+ if (platform.GetGelEnvVariable(DSN_ENV_NAME, out envName, out envVar))
+ {
+ if (hasPrimaryEnv) { throw primaryError; }
+ var fromDSN = _FromDSN(envVar, platform);
+ resolvedFields.MergeFrom(fromDSN);
+ hasPrimaryEnv = true;
+ }
+
+ if (platform.GetGelEnvVariable(CREDENTIALS_FILE_ENV_NAME, out envName, out envVar))
+ {
+ if (hasPrimaryEnv) { throw primaryError; }
+ if (platform.FileExists(envVar))
+ {
+ var credentials =
+ JsonConvert.DeserializeObject(platform.FileReadAllText(envVar))!;
+ resolvedFields.MergeFrom(ConfigUtils.ResolvedFields.FromCredentials(credentials));
+ }
+ else
+ {
+ deferredPrimaryError = new FileNotFoundException(
+ $"Invalid credential file from {envName}: \"{envVar}\", could not find file");
+ }
+ hasPrimaryEnv = true;
+ }
+
+ {
+ bool hasHostOrPort = false;
+ if (platform.GetGelEnvVariable(HOST_ENV_NAME, out envName, out envVar))
+ {
+ if (hasPrimaryEnv) { throw primaryError; }
+ resolvedFields.Host = envVar;
+ hasHostOrPort = true;
+ }
+ if (platform.GetGelEnvVariable(PORT_ENV_NAME, out envName, out envVar))
+ {
+ ConfigUtils.ResolvedField? port = ConfigUtils.ParsePort(envVar);
+ if (port is not null)
+ {
+ if (hasPrimaryEnv) { throw primaryError; }
+ resolvedFields.Port = ConfigUtils.MergeField(resolvedFields.Port, port);
+ hasHostOrPort = true;
+ }
+ }
+ if (hasHostOrPort)
+ {
+ hasPrimaryEnv = true;
+ }
+ }
+
+ if (deferredPrimaryError is not null)
+ {
+ throw deferredPrimaryError;
+ }
+ }
+
+ #endregion
+
+ #region Toml File
+
+ if (!hasPrimaryOptions && !hasPrimaryEnv)
+ {
+ ConfigUtils.ResolvedFields? fromToml = _ResolveEdgeDBTOML(platform);
+ if (fromToml is not null)
+ {
+ resolvedFields.MergeFrom(fromToml);
+ }
+ }
+
+ #endregion
+
+ #region Secondary Env
+
+ if (!hasPrimaryOptions)
+ {
+ if (platform.GetGelEnvVariable(DATABASE_ENV_NAME, out envName, out envVar))
+ {
+ var altName = string.Empty;
+ var altVal = string.Empty;
+ if (platform.GetGelEnvVariable(BRANCH_ENV_NAME, out altName, out altVal))
+ {
+ throw new ConfigurationException(
+ $"Environment variables {envName} and {altName} are mutually exclusive");
+ }
+
+ resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.DatabaseName(envVar);
+ }
+
+ if (platform.GetGelEnvVariable(BRANCH_ENV_NAME, out envName, out envVar))
+ {
+ resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.BranchName(envVar);
+ }
+
+ if (platform.GetGelEnvVariable(USER_ENV_NAME, out envName, out envVar))
+ {
+ resolvedFields.User = envVar;
+ }
+
+ if (platform.GetGelEnvVariable(PASSWORD_ENV_NAME, out envName, out envVar))
+ {
+ resolvedFields.Password = envVar;
+ }
+
+ if (platform.GetGelEnvVariable(TLS_CA_ENV_NAME, out envName, out envVar))
+ {
+ resolvedFields.TLSCertificateAuthority = envVar;
+ }
+
+ {
+ string clientSecurityEnvName;
+ string clientTlsSecurityEnvName;
+ TLSSecurityMode? clientSecurity = null;
+ TLSSecurityMode? clientTlsSecurity = null;
+ bool hasDefault = false;
+ if (platform.GetGelEnvVariable(CLIENT_SECURITY_ENV_NAME, out clientSecurityEnvName, out envVar))
+ {
+ if (TLSSecurityModeParser.TryParse(envVar, true, out clientSecurity))
+ {
+ if (clientSecurity is not null)
+ {
+ resolvedFields.TLSSecurity = clientSecurity;
+ }
+ else
+ {
+ hasDefault = true;
+ }
+ }
+ else
+ {
+ resolvedFields.TLSSecurity = new ConfigurationException(
+ $"Invalid TLS Security from {clientSecurityEnvName}: \"{envVar}\"");
+ }
+ }
+ if (platform.GetGelEnvVariable(CLIENT_TLS_SECURITY_ENV_NAME, out clientTlsSecurityEnvName, out envVar))
+ {
+ if (TLSSecurityModeParser.TryParse(envVar, true, out clientTlsSecurity))
+ {
+ if (clientTlsSecurity is null)
+ {
+ hasDefault = true;
+ }
+ else if (clientSecurity is null)
+ {
+ // overwrite default value
+ resolvedFields.TLSSecurity = clientTlsSecurity.Value;
+ }
+ else if (clientSecurity == TLSSecurityMode.Strict
+ && clientTlsSecurity != TLSSecurityMode.Strict)
+ {
+ throw new ConfigurationException(
+ $"{clientSecurityEnvName}=strict but {clientTlsSecurityEnvName}={envVar}. "
+ + $"{clientTlsSecurityEnvName} must be strict when {clientSecurityEnvName} "
+ + $"is strict"
+ );
+ }
+ else
+ {
+ // overwrite existing value
+ resolvedFields.TLSSecurity = clientTlsSecurity.Value;
+ }
+ }
+ else
+ {
+ resolvedFields.TLSSecurity = new ConfigurationException(
+ $"Invalid TLS Security from {clientTlsSecurityEnvName}: \"{envVar}\"");
+ }
+ }
+ if (hasDefault)
+ {
+ // finally, apply default value if no non-default value or error present
+ resolvedFields.TLSSecurity ??= TLSSecurityMode.Default;
+ }
+ }
+
+ if (platform.GetGelEnvVariable(TLS_SERVER_NAME_ENV_NAME, out envName, out envVar))
+ {
+ resolvedFields.TLSServerName = envVar;
+ }
+
+ if (platform.GetGelEnvVariable(WAIT_UNTIL_AVAILABLE_ENV_NAME, out envName, out envVar))
+ {
+ resolvedFields.WaitUntilAvailable = ConfigUtils.ParseWaitUntilAvailable(envVar);
+ }
+ }
+
+ #endregion
+
+ #region Secondary Options
+
+ // Finally, check secondary options
+ // Secondary options should override environment variables
+
+ if (options.Database is not null && options.Branch is not null)
+ {
+ throw new ConfigurationException("Invalid options: Database and Branch are mutually exclusive.");
+ }
+ else if (options.Database is not null)
+ {
+ resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.DatabaseName(options.Database);
+ }
+ else if (options.Branch is not null)
+ {
+ resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.BranchName(options.Branch);
+ }
+
+ if (options.User is not null) { resolvedFields.User = options.User; }
+ if (options.Password is not null) { resolvedFields.Password = options.Password; }
+ if (options.SecretKey is not null) { resolvedFields.SecretKey = options.SecretKey; }
+ if (options.TLSCertificateAuthority is not null)
+ {
+ resolvedFields.TLSCertificateAuthority = options.TLSCertificateAuthority;
+ }
+ if (options.TLSCertificateAuthorityFile is not null)
+ {
+ if (platform.FileExists(options.TLSCertificateAuthorityFile))
+ {
+ resolvedFields.TLSCertificateAuthority =
+ platform.FileReadAllText(options.TLSCertificateAuthorityFile);
+ }
+ else
+ {
+ throw new ConfigurationException(
+ $"Invalid TLSCertificateAuthorityFile: \"{options.TLSCertificateAuthorityFile}\", could not find file");
+ }
+ }
+ if (options.TLSSecurity is not null) { resolvedFields.TLSSecurity = options.TLSSecurity; }
+ if (options.TLSServerName is not null) { resolvedFields.TLSServerName = options.TLSServerName; }
+ if (options.WaitUntilAvailable is not null)
+ {
+ resolvedFields.WaitUntilAvailable = ConfigUtils.MergeField(
+ resolvedFields.WaitUntilAvailable,
+ ConfigUtils.ParseWaitUntilAvailable(options.WaitUntilAvailable));
+ }
+ if (options.ServerSettings is not null)
+ {
+ foreach (KeyValuePair entry in options.ServerSettings)
+ {
+ resolvedFields.ServerSettings = ConfigUtils.AddServerSettingField(
+ resolvedFields.ServerSettings, entry.Key, entry.Value);
+ }
+ }
+
+ #endregion
+
+ if (options.IsEmpty && resolvedFields.IsEmpty)
+ {
+ throw new ConfigurationException("No `gel.toml` found and no connection options specified.");
+ }
+
+ return _FromResolvedFields(resolvedFields, platform);
+ }
+
+ #endregion
+
+ #region Create Helpers
internal static EdgeDBConnection _FromResolvedFields(ConfigUtils.ResolvedFields resolvedFields, ISystemProvider? platform)
{
platform ??= ConfigUtils.DefaultPlatformProvider;
- if (resolvedFields.Host?.Value is not null && resolvedFields.Host.Value.Contains(','))
+ if (ConfigUtils.TryGetFieldValue(resolvedFields.Host, out string host))
{
- throw new ConfigurationException(
- $"Invalid host: \"{resolvedFields.Host.Value}\", DSN cannot contain more than one host");
+ if (host.Contains(','))
+ {
+ throw new ConfigurationException(
+ $"Invalid host: \"{host}\", DSN cannot contain more than one host");
+ }
+ if (host == "")
+ {
+ throw new ConfigurationException($"Invalid host: \"{host}\"");
+ }
+ if (host.StartsWith("/"))
+ {
+ throw new ConfigurationException($"Invalid host: \"{host}\", unix socket paths not supported");
+ }
}
- if (resolvedFields.DatabaseOrBranch?.Value?.Value == "")
+ if (ConfigUtils.TryGetFieldValue(resolvedFields.Port, out int port))
{
- throw resolvedFields.DatabaseOrBranch.Value switch
+ if (port < 1 || 65535 < port)
{
- ConfigUtils.DatabaseOrBranch.DatabaseName name => new ConfigurationException(
- $"Invalid database name: \"{name.Value}\""),
- ConfigUtils.DatabaseOrBranch.BranchName name => new ConfigurationException(
- $"Invalid branch name: \"{name.Value}\""),
- _ => new ConfigurationException("Invalid database or branch name"),
- };
+ throw new ConfigurationException($"Invalid port: \"{port}\", must be between 1 and 65535");
+ }
}
- if (resolvedFields.User == "")
+ if (ConfigUtils.TryGetFieldValue(resolvedFields.DatabaseOrBranch, out ConfigUtils.DatabaseOrBranch databaseOrBranch))
{
- throw new ConfigurationException($"Invalid user: \"{resolvedFields.User.Value}\"");
+ if (databaseOrBranch.Value == "")
+ {
+ throw databaseOrBranch switch
+ {
+ ConfigUtils.DatabaseOrBranch.DatabaseName name => new ConfigurationException(
+ $"Invalid database name: \"{name.Value}\""),
+ ConfigUtils.DatabaseOrBranch.BranchName name => new ConfigurationException(
+ $"Invalid branch name: \"{name.Value}\""),
+ _ => new ConfigurationException("Invalid database or branch name"),
+ };
+ }
+ }
+ if (ConfigUtils.TryGetFieldValue(resolvedFields.User, out string user))
+ {
+ if (user == "")
+ {
+ throw new ConfigurationException($"Invalid user: \"{user}\"");
+ }
}
return new()
@@ -525,10 +1075,9 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider?
string key = arg.Key;
ConfigUtils.ResolvedField value = arg.Value;
- if (key.EndsWith("_env"))
+ if (key.EndsWith("_env") && ConfigUtils.TryGetFieldValue(value, out string envName))
{
string oldKey = key;
- string envName = value.Value!;
key = key.Substring(0, key.Length - "_env".Length);
string? envVar = platform.GetEnvVariable(envName);
if (envVar is not null)
@@ -542,10 +1091,9 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider?
}
}
- if (key.EndsWith("_file") && value.Value is not null)
+ if (key.EndsWith("_file") && ConfigUtils.TryGetFieldValue(value, out string fileName))
{
string oldKey = key;
- string fileName = value.Value!;
key = key.Substring(0, key.Length - "_file".Length);
if (platform.FileExists(fileName))
{
@@ -635,17 +1183,7 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider?
});
break;
case "wait_until_available":
- resolvedFields.WaitUntilAvailable = value.Convert(v =>
- {
- try
- {
- return ConfigUtils.ParseWaitUntilAvailable(v);
- }
- catch (Exception e)
- {
- return e;
- }
- });
+ resolvedFields.WaitUntilAvailable = value.Convert(ConfigUtils.ParseWaitUntilAvailable);
break;
default:
@@ -667,10 +1205,10 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider?
/// The project directory doesn't exist for the supplied toml file.
public static EdgeDBConnection FromProjectFile(string path)
{
- return _FromProjectFile(path, null);
+ return _FromResolvedFields(_FromProjectFile(path, null), null);
}
- internal static EdgeDBConnection _FromProjectFile(string path, ISystemProvider? platform)
+ internal static ConfigUtils.ResolvedFields _FromProjectFile(string path, ISystemProvider? platform)
{
platform ??= ConfigUtils.DefaultPlatformProvider;
@@ -690,12 +1228,14 @@ internal static EdgeDBConnection _FromProjectFile(string path, ISystemProvider?
if (!ConfigUtils.TryResolveInstanceCloudProfile(projectDir, out var profile, out var inst, platform) || inst is null)
throw new FileNotFoundException($"Could not find instance name under project directory {projectDir}");
- var connection = _FromInstanceName(inst, profile, platform);
+ var resolvedFields = _FromInstanceName(inst, profile, platform);
- if (ConfigUtils.TryResolveProjectDatabase(projectDir, out var database, platform))
- connection.Database = database;
+ if (ConfigUtils.TryResolveProjectDatabase(projectDir, out var database, platform) && database is not null)
+ {
+ resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.DatabaseName(database);
+ }
- return connection;
+ return resolvedFields;
}
///
@@ -714,10 +1254,10 @@ internal static EdgeDBConnection _FromProjectFile(string path, ISystemProvider?
/// The configuration is invalid.
public static EdgeDBConnection FromInstanceName(string name, string? cloudProfile = null)
{
- return _FromInstanceName(name, cloudProfile, null);
+ return _FromResolvedFields(_FromInstanceName(name, cloudProfile, null), null);
}
- internal static EdgeDBConnection _FromInstanceName(string name, string? cloudProfile, ISystemProvider? platform)
+ internal static ConfigUtils.ResolvedFields _FromInstanceName(string name, string? cloudProfile, ISystemProvider? platform)
{
platform ??= ConfigUtils.DefaultPlatformProvider;
@@ -725,16 +1265,20 @@ internal static EdgeDBConnection _FromInstanceName(string name, string? cloudPro
{
var configPath = platform.CombinePaths(ConfigUtils.GetCredentialsDir(platform), $"{name}.json");
- return !platform.FileExists(configPath)
- ? throw new FileNotFoundException($"Config file couldn't be found at {configPath}")
- : JsonConvert.DeserializeObject(platform.FileReadAllText(configPath))!;
+ if (!platform.FileExists(configPath))
+ {
+ throw new FileNotFoundException($"Config file couldn't be found at {configPath}");
+ }
+
+ ConnectionCredentials credentials = JsonConvert.DeserializeObject(
+ platform.FileReadAllText(configPath))!;
+
+ return ConfigUtils.ResolvedFields.FromCredentials(credentials);
}
if (Regex.IsMatch(name, @"^([A-Za-z0-9](-?[A-Za-z0-9])*)\/([A-Za-z0-9](-?[A-Za-z0-9])*)$"))
{
- var conn = new EdgeDBConnection();
- conn.ParseCloudInstanceName(name, cloudProfile, platform);
- return conn;
+ return ParseCloudInstanceName(name, null, cloudProfile, platform);
}
throw new ConfigurationException($"Invalid instance name '{name}'");
@@ -748,10 +1292,15 @@ internal static EdgeDBConnection _FromInstanceName(string name, string? cloudPro
/// No 'edgedb.toml' file could be found.
public static EdgeDBConnection ResolveEdgeDBTOML()
{
- return _ResolveEdgeDBTOML(null);
+ ConfigUtils.ResolvedFields? resolvedFields = _ResolveEdgeDBTOML(null);
+ if (resolvedFields is null)
+ {
+ throw new ConfigurationException("Couldn't resolve gel.toml file");
+ }
+ return _FromResolvedFields(resolvedFields, null);
}
- internal static EdgeDBConnection _ResolveEdgeDBTOML(ISystemProvider? platform)
+ internal static ConfigUtils.ResolvedFields? _ResolveEdgeDBTOML(ISystemProvider? platform)
{
platform ??= ConfigUtils.DefaultPlatformProvider;
@@ -759,30 +1308,32 @@ internal static EdgeDBConnection _ResolveEdgeDBTOML(ISystemProvider? platform)
while (true)
{
+ if (platform.FileExists(platform.CombinePaths(dir!, "gel.toml")))
+ return _FromProjectFile(platform.CombinePaths(dir!, "gel.toml"), platform);
+
if (platform.FileExists(platform.CombinePaths(dir!, "edgedb.toml")))
return _FromProjectFile(platform.CombinePaths(dir!, "edgedb.toml"), platform);
var parent = platform.DirectoryGetParent(dir!);
if (parent is null || !parent.Exists)
- throw new FileNotFoundException("Couldn't resolve edgedb.toml file");
+ return null;
dir = parent.FullName;
}
}
- private void ParseCloudInstanceName(string name, string? cloudProfile, ISystemProvider? platform)
+ private static ConfigUtils.ResolvedFields ParseCloudInstanceName(
+ string name, string? secretKey, string? cloudProfile, ISystemProvider? platform)
{
if (name.Length > DOMAIN_NAME_MAX_LEN)
{
throw new ConfigurationException($"Cloud instance name must be {DOMAIN_NAME_MAX_LEN} characters or less");
}
- var secretKey = SecretKey;
-
if (secretKey is null)
{
- var profile = ConfigUtils.ReadCloudProfile(cloudProfile ?? CloudProfile, platform);
+ var profile = ConfigUtils.ReadCloudProfile(cloudProfile ?? _defaultCloudProfile, platform);
if (profile.SecretKey is null)
{
@@ -821,8 +1372,11 @@ private void ParseCloudInstanceName(string name, string? cloudProfile, ISystemPr
spl = name.Split("/");
- Hostname = $"{spl[1]}--{spl[0]}.c-{dnsBucket}.i.{dnsZone}";
- SecretKey ??= secretKey;
+ return new()
+ {
+ Host = $"{spl[1]}--{spl[0]}.c-{dnsBucket}.i.{dnsZone}",
+ SecretKey = secretKey,
+ };
}
///
@@ -859,17 +1413,14 @@ internal static EdgeDBConnection _Parse(string? instance, string? dsn,
if (autoResolve && !((instance is not null && instance.Contains('/')) ||
(dsn is not null && !dsn.StartsWith("edgedb://") && !dsn.StartsWith("gel://"))))
{
- try
- {
- connection = _ResolveEdgeDBTOML(platform);
- }
- catch (FileNotFoundException)
+ ConfigUtils.ResolvedFields? resolvedFields = _ResolveEdgeDBTOML(platform);
+ if (resolvedFields is not null)
{
- // ignore
+ connection = _FromResolvedFields(resolvedFields, platform);
}
}
- #region Env
+ #region Old Env
var envName = string.Empty;
var envVar = string.Empty;
@@ -888,7 +1439,7 @@ internal static EdgeDBConnection _Parse(string? instance, string? dsn,
if (platform.GetGelEnvVariable(INSTANCE_ENV_NAME, out envName, out envVar))
{
- var fromInst = _FromInstanceName(envVar, null, platform);
+ var fromInst = _FromResolvedFields(_FromInstanceName(envVar, null, platform), platform);
connection = connection?.MergeInto(fromInst) ?? fromInst;
}
@@ -1034,7 +1585,7 @@ internal static EdgeDBConnection _Parse(string? instance, string? dsn,
if (instance is not null)
{
- var fromInst = _FromInstanceName(instance, null, platform);
+ var fromInst = _FromResolvedFields(_FromInstanceName(instance, null, platform), platform);
connection = connection?.MergeInto(fromInst) ?? fromInst;
}
@@ -1043,8 +1594,8 @@ internal static EdgeDBConnection _Parse(string? instance, string? dsn,
if (Regex.IsMatch(dsn, @"^([A-Za-z0-9](-?[A-Za-z0-9])*)\/([A-Za-z0-9](-?[A-Za-z0-9])*)$"))
{
// cloud
- connection ??= new EdgeDBConnection();
- connection.ParseCloudInstanceName(dsn, null, platform);
+ var fromCloud = _FromResolvedFields(ParseCloudInstanceName(dsn, connection?.SecretKey, null, platform), platform);
+ connection = connection?.MergeInto(fromCloud) ?? fromCloud;
}
else
{
diff --git a/src/EdgeDB.Net.Driver/Utils/ConfigUtils.cs b/src/EdgeDB.Net.Driver/Utils/ConfigUtils.cs
index 81045593..f65995d1 100644
--- a/src/EdgeDB.Net.Driver/Utils/ConfigUtils.cs
+++ b/src/EdgeDB.Net.Driver/Utils/ConfigUtils.cs
@@ -224,7 +224,21 @@ internal T CheckAndGetValue()
}
}
- static ResolvedField? MergeField(ResolvedField? to, ResolvedField? from)
+ internal static bool TryGetFieldValue(ResolvedField? field, out T value)
+ {
+ if (field is ResolvedField.Valid)
+ {
+ value = field.Value!;
+ return true;
+ }
+ else
+ {
+ value = default(T)!;
+ return false;
+ }
+ }
+
+ internal static ResolvedField? MergeField(ResolvedField? to, ResolvedField? from)
{
if (to is null)
{
@@ -312,6 +326,31 @@ Host is null
&& TLSServerName is null
&& WaitUntilAvailable is null
&& ServerSettings.Count == 0;
+
+ internal static ResolvedFields FromCredentials(ConnectionCredentials credentials)
+ {
+ ResolvedFields result = new();
+
+ if (credentials.Host is not null) { result.Host = credentials.Host; }
+ if (credentials.Port is not null)
+ {
+ result.Port = MergeField(result.Port, ParsePort(credentials.Port));
+ }
+ if (credentials.Database is not null)
+ {
+ result.DatabaseOrBranch = new DatabaseOrBranch.DatabaseName(credentials.Database);
+ }
+ if (credentials.Branch is not null)
+ {
+ result.DatabaseOrBranch = new DatabaseOrBranch.BranchName(credentials.Branch);
+ }
+ if (credentials.User is not null) { result.User = credentials.User; }
+ if (credentials.Password is not null) { result.Password = credentials.Password; }
+ if (credentials.TlsCA is not null) { result.TLSCertificateAuthority = credentials.TlsCA; }
+ if (credentials.TlsSecurity is not null) { result.TLSSecurity = credentials.TlsSecurity; }
+
+ return result;
+ }
}
#endregion
diff --git a/src/EdgeDB.Net.Driver/Utils/JsonUtils.cs b/src/EdgeDB.Net.Driver/Utils/JsonUtils.cs
new file mode 100644
index 00000000..f06ffd9a
--- /dev/null
+++ b/src/EdgeDB.Net.Driver/Utils/JsonUtils.cs
@@ -0,0 +1,37 @@
+using EdgeDB.DataTypes;
+using Newtonsoft.Json;
+
+namespace EdgeDB.Utils;
+
+internal class AsStringConverter : JsonConverter
+{
+ // Always reads numbers as strings
+
+ public override string? ReadJson(
+ JsonReader reader,
+ Type objectType,
+ string? existingValue,
+ bool hasExistingValue,
+ JsonSerializer serializer)
+ {
+ if (reader.TokenType == JsonToken.Integer)
+ {
+ return reader.Value!.ToString();
+ }
+ else if (reader.TokenType == JsonToken.Float)
+ {
+ return reader.Value!.ToString();
+ }
+ else if (reader.TokenType == JsonToken.String)
+ {
+ return (string)reader.Value!;
+ }
+ throw new JsonException("Expected Number or String.");
+ }
+
+ public override void WriteJson(
+ JsonWriter writer, string? value, JsonSerializer serializer)
+ {
+ throw new NotImplementedException();
+ }
+}
diff --git a/tests/EdgeDB.Tests.Unit/EdgeDB.Tests.Unit.csproj b/tests/EdgeDB.Tests.Unit/EdgeDB.Tests.Unit.csproj
index 9e889258..1eedf3a7 100644
--- a/tests/EdgeDB.Tests.Unit/EdgeDB.Tests.Unit.csproj
+++ b/tests/EdgeDB.Tests.Unit/EdgeDB.Tests.Unit.csproj
@@ -20,4 +20,10 @@
+
+
+ PreserveNewest
+
+
+
diff --git a/tests/EdgeDB.Tests.Unit/SharedClientTests.cs b/tests/EdgeDB.Tests.Unit/SharedClientTests.cs
new file mode 100644
index 00000000..b643aa3c
--- /dev/null
+++ b/tests/EdgeDB.Tests.Unit/SharedClientTests.cs
@@ -0,0 +1,729 @@
+using EdgeDB.Abstractions;
+using EdgeDB.Utils;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security.Cryptography;
+using System.Text;
+using System.Text.RegularExpressions;
+
+namespace EdgeDB.Tests.Unit;
+
+[TestClass]
+public class SharedClientTests
+{
+ [TestMethod]
+ public void TestConnectParams()
+ {
+ StreamReader reader = new("shared-client-testcases/connection_testcases.json");
+ List? testcases = JsonConvert.DeserializeObject>(reader.ReadToEnd());
+ if (testcases is null)
+ {
+ throw new JsonException("Failed to read 'connection_testcases.json.\n"
+ + "Is the 'shared-client-testcases' submodule initialised? "
+ + "Try running 'git submodule update --init'.");
+ }
+
+ foreach ((int textIndex, TestCase testCase) in testcases.Select((x, i) => (i, x)))
+ {
+ if (testCase.FileSystem is not null
+ && (
+ !(testCase.Platform is null && RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) ||
+ !(testCase.Platform == "windows" && !RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) ||
+ !(testCase.Platform == "macos" && RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
+ ))
+ {
+ // skipping unsupported platform test
+ continue;
+ }
+
+ if ((testCase.Result is null) == (testCase.Error is null))
+ {
+ throw new Exception("invalid test case: either \"result\" or \"error\" key has to be specified");
+ }
+
+ TestResult result = ParseConnection(testCase);
+
+ if (testCase.Result is not null)
+ {
+ AssertSameConnection(result, testCase.Result);
+ }
+ else if (testCase.Error is not null)
+ {
+ AssertSameException(result, testCase.Error);
+ }
+ }
+ }
+
+ private class TestResult
+ {
+ public EdgeDBConnection? Connection { get; init; }
+ public Exception? Exception { get; init; }
+
+ public static implicit operator TestResult(EdgeDBConnection c) => new() {Connection = c};
+ public static implicit operator TestResult(Exception x) => new() {Exception = x};
+ }
+
+ private static TestResult ParseConnection(TestCase testCase)
+ {
+ try
+ {
+ ISystemProvider mockSystem = new MockSystemProvider(testCase);
+
+ EdgeDBConnection.Options config = new()
+ {
+ Instance = testCase?.Options?.Instance,
+ Dsn = testCase?.Options?.Dsn,
+ Host = testCase?.Options?.Host,
+ Port = (
+ testCase?.Options?.Port is null
+ ? null
+ : int.TryParse(testCase?.Options?.Port, out var parsedPort)
+ ? parsedPort
+ : throw new ConfigurationException(
+ $"Invalid port: {testCase?.Options?.Port}, not an integer")
+ ),
+ Database = testCase?.Options?.Database,
+ Branch = testCase?.Options?.Branch,
+ User = testCase?.Options?.User,
+ Password = testCase?.Options?.Password,
+ SecretKey = testCase?.Options?.SecretKey,
+ Credentials = testCase?.Options?.Credentials,
+ CredentialsFile = testCase?.Options?.CredentialsFile,
+ TLSCertificateAuthority = testCase?.Options?.TlsCA,
+ TLSCertificateAuthorityFile = testCase?.Options?.TlsCAFile,
+ TLSSecurity = (
+ testCase?.Options?.TlsSecurity is null
+ ? null
+ : TLSSecurityModeParser.Parse(testCase.Options.TlsSecurity)
+ ),
+ TLSServerName = testCase?.Options?.TlsServerName,
+ WaitUntilAvailable = testCase?.Options?.WaitUntilAvailable,
+ ServerSettings = testCase?.Options?.ServerSettings,
+ };
+
+ EdgeDBConnection connection = EdgeDBConnection._Create(config, mockSystem);
+
+ return connection;
+ }
+ catch (Exception x)
+ {
+ return x;
+ }
+ }
+
+ #region Test Assertions
+
+ private static void AssertSameConnection(TestResult result, TestCase.ExpectedResult expectedResult)
+ {
+ string expectedHostname = expectedResult.Address is not null
+ ? expectedResult.Address[0]
+ : "localhost";
+ int expectedPort = expectedResult.Address is not null
+ ? int.Parse(expectedResult.Address[1])
+ : 5656;
+ string expectedDatabase = expectedResult.Database;
+ string expectedBranch = expectedResult.Branch;
+ string expectedUsername = expectedResult.User;
+ string expectedPassword = expectedResult.Password ?? "";
+ string? expectedSecretKey = expectedResult.SecretKey;
+ string? expectedTLSCertificateAuthority = expectedResult.TlsCAData;
+ TLSSecurityMode expectedTLSSecurity = expectedResult.TlsSecurity ?? TLSSecurityMode.Strict;
+ string? expectedTLSServerName = expectedResult.TlsServerName;
+ int expectedWaitUntilAvailable =
+ expectedResult.WaitUntilAvailable is not null
+ && ConfigUtils.TryGetFieldValue(
+ ConfigUtils.ParseWaitUntilAvailable(expectedResult.WaitUntilAvailable),
+ out int timeout)
+ ? timeout
+ : 30000;
+
+ Assert.IsNull(result.Exception, $"\"{result.Exception?.Message}\"\n{result.Exception?.StackTrace}");
+ Assert.IsNotNull(result.Connection);
+ EdgeDBConnection actual = result.Connection;
+
+ Assert.AreEqual(expectedHostname, actual.Hostname);
+ Assert.AreEqual(expectedPort, actual.Port);
+ Assert.AreEqual(expectedDatabase, actual.Database);
+ Assert.AreEqual(expectedBranch, actual.Branch);
+ Assert.AreEqual(expectedUsername, actual.Username);
+ Assert.AreEqual(expectedPassword, actual.Password);
+ Assert.AreEqual(expectedSecretKey, actual.SecretKey);
+ Assert.AreEqual(expectedTLSCertificateAuthority, actual.TLSCertificateAuthority);
+ Assert.AreEqual(expectedTLSSecurity, actual.TLSSecurity);
+ Assert.AreEqual(expectedTLSServerName, actual.TLSServerName);
+ Assert.AreEqual(expectedWaitUntilAvailable, actual.WaitUntilAvailable);
+ CollectionAssert.AreEqual(expectedResult.ServerSettings, actual.ServerSettings);
+ }
+
+ private static void AssertSameException(TestResult result, TestCase.ExpectedError expectedError)
+ {
+ (Type expectedType, Regex expectedRegex) = _errorMapping[expectedError.Type];
+
+ Assert.IsNull(result.Connection);
+ Assert.IsNotNull(result.Exception);
+ Exception actual = result.Exception;
+
+ Assert.IsInstanceOfType(
+ actual,
+ expectedType,
+ $"Exception type {expectedType} expected but got {result.Exception.GetType()}\n"
+ + $"{result.Exception.Message}\n"
+ + $"{result.Exception?.StackTrace}");
+ Assert.IsTrue(
+ expectedRegex.Match(actual.Message).Success,
+ $"Exception message \"{actual.Message}\" does not match pattern \"{expectedRegex}\"\n"
+ + $"{result.Exception?.StackTrace}");
+ }
+
+ private static readonly Dictionary _errorMapping = new()
+ {
+ {
+ "credentials_file_not_found",
+ (
+ typeof(ConfigurationException),
+ new("cannot read credentials", RegexOptions.Compiled)
+ )
+ },
+ {
+ "project_not_initialised",
+ (
+ typeof(ConfigurationException),
+ new("Found `\\w+.toml` but the project is not initialized", RegexOptions.Compiled)
+ )
+ },
+ {
+ "no_options_or_toml",
+ (
+ typeof(ConfigurationException),
+ new("No `gel.toml` found and no connection options specified", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_credentials_file",
+ (
+ typeof(ConfigurationException),
+ new("Invalid CredentialsFile", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_dsn_or_instance_name",
+ (
+ typeof(ConfigurationException),
+ new("Invalid (?:DSN|instance name)", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_instance_name",
+ (
+ typeof(ConfigurationException),
+ new("invalid instance name", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_dsn",
+ (
+ typeof(ConfigurationException),
+ new("Invalid DSN", RegexOptions.Compiled)
+ )
+ },
+ {
+ "unix_socket_unsupported",
+ (
+ typeof(ConfigurationException),
+ new("unix socket paths not supported", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_host",
+ (
+ typeof(ConfigurationException),
+ new("Invalid host", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_port",
+ (
+ typeof(ConfigurationException),
+ new("Invalid port", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_user",
+ (
+ typeof(ConfigurationException),
+ new("Invalid user", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_database",
+ (
+ typeof(ConfigurationException),
+ new("Invalid database", RegexOptions.Compiled)
+ )
+ },
+ {
+ "multiple_compound_env",
+ (
+ typeof(ConfigurationException),
+ new("Cannot have more than one of the following connection environment variables", RegexOptions.Compiled)
+ )
+ },
+ {
+ "multiple_compound_opts",
+ (
+ typeof(ConfigurationException),
+ new("Connection options cannot have more than one of the following values", RegexOptions.Compiled)
+ )
+ },
+ {
+ "exclusive_options",
+ (
+ typeof(ConfigurationException),
+ new("are mutually exclusive", RegexOptions.Compiled)
+ )
+ },
+ {
+ "env_not_found",
+ (
+ typeof(ConfigurationException),
+ new("environment variable \".*\" doesn\'t exist", RegexOptions.Compiled)
+ )
+ },
+ {
+ "file_not_found",
+ (
+ typeof(ConfigurationException),
+ new("could not find file", RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_tls_security",
+ (
+ typeof(ConfigurationException),
+ new(
+ "Invalid TLS Security|\\w+ must be strict when \\w+ is strict",
+ RegexOptions.Compiled)
+ )
+ },
+ {
+ "invalid_secret_key",
+ (
+ typeof(ConfigurationException),
+ new("Invalid secret key", RegexOptions.Compiled)
+ )
+ },
+ {
+ "secret_key_not_found",
+ (
+ typeof(ConfigurationException),
+ new("Cannot connect to cloud instances without secret key", RegexOptions.Compiled)
+ )
+ },
+ {
+ "docker_tcp_port",
+ (
+ typeof(ConfigurationException),
+ new("\\w+_PORT in \"tcp://host:port\" format, so will be ignored", RegexOptions.Compiled)
+ )
+ },
+ {
+ "gel_and_edgedb",
+ (
+ typeof(ConfigurationException),
+ new("Both GEL_\\w+ and EDGEDB_\\w+ are set; EDGEDB_\\w+ will be ignored", RegexOptions.Compiled)
+ )
+ },
+ };
+
+ #endregion
+
+ #region MockSystemProvider
+
+ class MockSystemProvider : BaseDefaultSystemProvider
+ {
+ private readonly string? _homeDir;
+ private readonly string? _currentDir;
+ private readonly Dictionary _envVars;
+ private Dictionary _files;
+
+ public List Warnings { get; } = new();
+
+ public MockSystemProvider(TestCase testCase)
+ {
+ _homeDir = testCase.FileSystem?.HomeDir;
+ _currentDir = testCase.FileSystem?.CurrentDir;
+ _envVars = testCase.EnvVars ?? new();
+ _files = CacheFiles(testCase?.FileSystem?.Files);
+ }
+
+ private Dictionary CacheFiles(Dictionary? files)
+ {
+ return files?.SelectMany(
+ x => {
+ string path = x.Key;
+ TestCase.File file = x.Value;
+
+ if (file.Contents is not null)
+ {
+ return new List<(string, string)>(){(path, file.Contents)};
+ }
+ else
+ {
+ if (file.Fields is null)
+ {
+ throw new Exception("File must be either string or json object of fields");
+ }
+ if (!file.Fields.ContainsKey("project-path"))
+ {
+ throw new Exception("File as object must have \"project-path\" field");
+ }
+
+ List<(string,string)> subfiles = new();
+
+ string dir = path.Replace("${HASH}", ProjectPathHash(file.Fields["project-path"]));
+
+ foreach (KeyValuePair field in file.Fields)
+ {
+ subfiles.Add((CombinePaths(new string[]{ dir, field.Key }), field.Value));
+ }
+
+ return subfiles;
+ }
+ })
+ .ToDictionary(x => x.Item1, x => x.Item2)
+ ?? new();
+ }
+
+ private string ProjectPathHash(string path)
+ {
+ if (IsOSPlatform(OSPlatform.Windows) && !path.StartsWith("\\\\"))
+ {
+ path = "\\\\?\\" + path;
+ }
+
+ return Convert.ToHexString(SHA1.HashData(Encoding.UTF8.GetBytes(path)));
+ }
+
+ public override string GetHomeDir() => _homeDir ?? base.GetHomeDir();
+
+ public override string GetCurrentDirectory() => _currentDir ?? base.GetCurrentDirectory();
+
+ public override string? GetEnvVariable(string name)
+ => _envVars.TryGetValue(name, out var val)
+ ? val
+ : null;
+
+ public override bool FileExists(string path) => _files.ContainsKey(path);
+
+ public override string FileReadAllText(string path) => _files[path];
+
+ public override void WriteWarning(string message)
+ => Warnings.Add(message);
+ }
+
+ #endregion
+
+ #region TestCase
+
+ class TestCase
+ {
+ [JsonProperty("name")]
+ public string Name { get; init; } = string.Empty;
+
+ [JsonProperty("opts")]
+ public OptionsData? Options { get; init; }
+
+ [JsonProperty("env")]
+ public Dictionary? EnvVars { get; init; }
+
+ [JsonProperty("platform")]
+ public string? Platform { get; init; }
+
+ [JsonProperty("fs")]
+ public FileSystemData? FileSystem { get; init; }
+
+ [JsonProperty("warnings")]
+ public List? Warnings { get; init; }
+
+ [JsonProperty("result")]
+ public ExpectedResult? Result { get; init; }
+
+ [JsonProperty("error")]
+ public ExpectedError? Error { get; init; }
+
+ public class OptionsData
+ {
+ [JsonProperty("instance")]
+ public string? Instance { get; init; }
+
+ [JsonProperty("dsn")]
+ public string? Dsn { get; init; }
+
+ [JsonProperty("host")]
+ public string? Host { get; init; }
+
+ [JsonProperty("port")]
+ [JsonConverter(typeof(AsStringConverter))]
+ public string? Port { get; init; }
+
+ [JsonProperty("database")]
+ public string? Database { get; init; }
+
+ [JsonProperty("branch")]
+ public string? Branch { get; init; }
+
+ [JsonProperty("user")]
+ public string? User { get; init; }
+
+ [JsonProperty("password")]
+ public string? Password { get; init; }
+
+ [JsonProperty("secretKey")]
+ public string? SecretKey { get; init; }
+
+ [JsonProperty("credentials")]
+ public string? Credentials { get; init; }
+
+ [JsonProperty("credentialsFile")]
+ public string? CredentialsFile { get; init; }
+
+ [JsonProperty("tlsCA")]
+ public string? TlsCA { get; init; }
+
+ [JsonProperty("tlsCAFile")]
+ public string? TlsCAFile { get; init; }
+
+ [JsonProperty("tlsSecurity")]
+ public string? TlsSecurity { get; init; }
+
+ [JsonProperty("tlsServerName")]
+ public string? TlsServerName { get; init; }
+
+ [JsonProperty("waitUntilAvailable")]
+ public string? WaitUntilAvailable { get; init; }
+
+ [JsonProperty("serverSettings")]
+ public Dictionary? ServerSettings { get; init; }
+ }
+
+ public class Credentials
+ {
+ [JsonProperty("host")]
+ public string? Host { get; init; }
+
+ [JsonProperty("port")]
+ [JsonConverter(typeof(AsStringConverter))]
+ public string? Port { get; init; }
+
+ [JsonProperty("database")]
+ public string? Database { get; init; }
+
+ [JsonProperty("branch")]
+ public string? Branch { get; init; }
+
+ [JsonProperty("user")]
+ public string? User { get; init; }
+
+ [JsonProperty("password")]
+ public string? Password { get; init; }
+
+ [JsonProperty("tls_ca")]
+ public string? TlsCA { get; init; }
+
+ [JsonProperty("tls_security")]
+ [JsonConverter(typeof(TLSSecurityModeParser))]
+ public TLSSecurityMode? TlsSecurity { get; init; }
+ }
+
+ public class FileSystemData
+ {
+ [JsonProperty("cwd")]
+ public string? CurrentDir { get; init; }
+
+ [JsonProperty("homedir")]
+ public string? HomeDir { get; init; }
+
+ [JsonProperty("files")]
+ public Dictionary? Files { get; init; }
+ }
+
+ [JsonConverter(typeof(FileJsonConverter))]
+ public class File
+ {
+ // Has either string contents or has explicitly defined instance information
+
+ // string contents
+ public string? Contents { get; init; }
+
+ // instance information
+ public Dictionary? Fields { get; init; }
+ }
+
+ public class ExpectedResult
+ {
+ [JsonProperty("address")]
+ [JsonConverter(typeof(AsListStringConverter))]
+ public List Address { get; init; } = new();
+
+ [JsonProperty("database")]
+ public string Database { get; init; } = string.Empty;
+
+ [JsonProperty("branch")]
+ public string Branch { get; init; } = string.Empty;
+
+ [JsonProperty("user")]
+ public string User { get; init; } = string.Empty;
+
+ [JsonProperty("password")]
+ public string? Password { get; init; }
+
+ [JsonProperty("secretKey")]
+ public string? SecretKey { get; init; }
+
+ [JsonProperty("tlsCAData")]
+ public string? TlsCAData { get; init; }
+
+ [JsonProperty("tlsSecurity")]
+ [JsonConverter(typeof(TLSSecurityModeParser))]
+ public TLSSecurityMode? TlsSecurity { get; init; }
+
+ [JsonProperty("tlsServerName")]
+ public string? TlsServerName { get; init; }
+
+ [JsonProperty("waitUntilAvailable")]
+ public string? WaitUntilAvailable { get; init; }
+
+ [JsonProperty("serverSettings")]
+ public Dictionary? ServerSettings { get; init; }
+ }
+
+ public class ExpectedError
+ {
+ [JsonProperty("type")]
+ public string Type { get; init; } = string.Empty;
+ }
+
+ private class AsListStringConverter : JsonConverter>
+ {
+ // Always reads numbers as strings
+
+ public override List? ReadJson(
+ JsonReader reader,
+ Type objectType,
+ List? existingValue,
+ bool hasExistingValue,
+ JsonSerializer serializer)
+ {
+ List result = new();
+
+ // skip JsonToken.StartArray
+ while (reader.Read())
+ {
+ if (reader.TokenType == JsonToken.EndArray)
+ {
+ break;
+ }
+
+ if (reader.TokenType == JsonToken.Integer)
+ {
+ result.Add(reader.Value!.ToString()!);
+ }
+ else if (reader.TokenType == JsonToken.Float)
+ {
+ result.Add(reader.Value!.ToString()!);
+ }
+ else if (reader.TokenType == JsonToken.String)
+ {
+ result.Add((string)reader.Value!);
+ }
+ else
+ {
+ throw new JsonException(
+ $"Invalid {reader.TokenType} token: \"{reader.Value}\", expected Number or String.");
+ }
+ }
+
+ return result;
+ }
+
+ public override void WriteJson(
+ JsonWriter writer, List? value, JsonSerializer serializer)
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ private class FileJsonConverter : JsonConverter
+ {
+ public override File? ReadJson(
+ JsonReader reader,
+ Type objectType,
+ File? existingValue,
+ bool hasExistingValue,
+ JsonSerializer serializer)
+ {
+ if (reader.TokenType == JsonToken.String)
+ {
+ // string contents
+ return new()
+ {
+ Contents = (string)reader.Value!,
+ };
+ }
+ else if (reader.TokenType == JsonToken.StartObject)
+ {
+ // instance information
+ Dictionary fields = new();
+ while (reader.Read())
+ {
+ if (reader.TokenType == JsonToken.EndObject)
+ {
+ break;
+ }
+
+ if (reader.TokenType != JsonToken.PropertyName)
+ {
+ throw new JsonException(
+ $"Invalid {reader.TokenType} token: \"{reader.Value}\", expected PropertyName.");
+ }
+
+ string propertyName = (string)reader.Value!;
+
+ reader.Read();
+ if (reader.TokenType != JsonToken.String)
+ {
+ throw new JsonException(
+ $"Invalid {reader.TokenType} token: \"{reader.Value}\", expected String.");
+ }
+
+ string value = (string)reader.Value!;
+
+ fields[propertyName] = value;
+ }
+
+ return new()
+ {
+ Fields = fields,
+ };
+ }
+ else
+ {
+ throw new JsonException("Could not read File object.");
+ }
+ }
+
+ public override void WriteJson(
+ JsonWriter writer, File? value, JsonSerializer serializer)
+ {
+ throw new NotImplementedException();
+ }
+ }
+ }
+
+ #endregion
+}
diff --git a/tests/EdgeDB.Tests.Unit/shared-client-testcases b/tests/EdgeDB.Tests.Unit/shared-client-testcases
new file mode 160000
index 00000000..d720f509
--- /dev/null
+++ b/tests/EdgeDB.Tests.Unit/shared-client-testcases
@@ -0,0 +1 @@
+Subproject commit d720f509b352f2fe587b5ea6d0c624553b95e04b