diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2917bd58..08549481 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,6 +11,8 @@ jobs: steps: - uses: actions/checkout@v2 + with: + submodules: true - name: Setup .NET uses: actions/setup-dotnet@v1 diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..88f37299 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "tests/EdgeDB.Tests.Unit/shared-client-testcases"] + path = tests/EdgeDB.Tests.Unit/shared-client-testcases + url = https://github.com/edgedb/shared-client-testcases.git diff --git a/src/EdgeDB.Net.Driver/EdgeDBConnection.cs b/src/EdgeDB.Net.Driver/EdgeDBConnection.cs index b576cf2b..fd70e2df 100644 --- a/src/EdgeDB.Net.Driver/EdgeDBConnection.cs +++ b/src/EdgeDB.Net.Driver/EdgeDBConnection.cs @@ -10,6 +10,40 @@ namespace EdgeDB; +/// +/// A json readable representation of an EdgeDBConnection. +/// When using Credentials to create an EdgeDBConnection, the data must conform to this type. +/// +internal class ConnectionCredentials +{ + [JsonProperty("host")] + public string? Host { get; init; } + + [JsonProperty("port")] + [JsonConverter(typeof(AsStringConverter))] + public string? Port { get; init; } + + [JsonProperty("database")] + public string? Database { get; init; } + + [JsonProperty("branch")] + public string? Branch { get; init; } + + [JsonProperty("user")] + public string? User { get; init; } + + [JsonProperty("password")] + public string? Password { get; init; } + + [JsonProperty("tls_ca")] + public string? TlsCA { get; init; } + + [JsonProperty("tls_security")] + [JsonConverter(typeof(TLSSecurityModeParser))] + public TLSSecurityMode? TlsSecurity { get; init; } +} + + /// /// Represents a class containing information on how to connect to a edgedb instance. /// @@ -18,16 +52,19 @@ public sealed class EdgeDBConnection private const string INSTANCE_ENV_NAME = "INSTANCE"; private const string DSN_ENV_NAME = "DSN"; private const string CREDENTIALS_FILE_ENV_NAME = "CREDENTIALS_FILE"; - private const string USER_ENV_NAME = "USER"; - private const string PASSWORD_ENV_NAME = "PASSWORD"; - private const string DATABASE_ENV_NAME = "DATABASE"; - private const string BRANCH_ENV_NAME = "BRANCH"; private const string HOST_ENV_NAME = "HOST"; private const string PORT_ENV_NAME = "PORT"; + private const string DATABASE_ENV_NAME = "DATABASE"; + private const string BRANCH_ENV_NAME = "BRANCH"; + private const string USER_ENV_NAME = "USER"; + private const string PASSWORD_ENV_NAME = "PASSWORD"; + private const string SECRET_KEY_ENV_NAME = "SECRET_KEY"; + private const string TLS_CA_ENV_NAME = "TLS_CA"; private const string CLIENT_SECURITY_ENV_NAME = "CLIENT_SECURITY"; private const string CLIENT_TLS_SECURITY_ENV_NAME = "CLIENT_TLS_SECURITY"; + private const string TLS_SERVER_NAME_ENV_NAME = "TLS_SERVER_NAME"; + private const string WAIT_UNTIL_AVAILABLE_ENV_NAME = "WAIT_UNTIL_AVAILABLE"; private const string CLOUD_PROFILE_ENV_NAME = "CLOUD_PROFILE"; - private const string SECRET_KEY_ENV_NAME = "SECRET_KEY"; private const int DOMAIN_NAME_MAX_LEN = 62; private EdgeDBConnection MergeInto(EdgeDBConnection other) @@ -295,31 +332,544 @@ public string CloudProfile #endregion - #region Construct methods + #region Create Function + + /// + /// Optional args which can be passed into . + /// + public class Options + { + // Primary args + // These can set host/port of the connection. + public string? Instance { get; set; } + public string? Dsn { get; set; } + public string? Host { get; set; } + public int? Port { get; set; } + + // Secondary args + public string? Database { get; set; } + public string? Branch { get; set; } + public string? User { get; set; } + public string? Password { get; set; } + public string? SecretKey { get; set; } + public string? Credentials { get; set; } + public string? CredentialsFile { get; set; } + public string? TLSCertificateAuthority { get; set; } + public string? TLSCertificateAuthorityFile { get; set; } + public TLSSecurityMode? TLSSecurity { get; set; } + public string? TLSServerName { get; set; } + public string? WaitUntilAvailable { get; set; } + public Dictionary? ServerSettings { get; set; } + + public bool IsEmpty => + Instance is null + && Dsn is null + && Host is null + && Port is null + && Database is null + && Branch is null + && User is null + && Password is null + && SecretKey is null + && Credentials is null + && CredentialsFile is null + && TLSCertificateAuthority is null + && TLSCertificateAuthorityFile is null + && TLSSecurity is null + && TLSServerName is null + && WaitUntilAvailable is null + && ServerSettings is null; + } + + /// + /// Parses the `gel.toml`, optional , and environment variables to build an + /// . + /// + /// This function will first search for the first valid primary args (which can set host/port) + /// in the following order: + /// - + /// - Environment variables + /// - `gel.toml` file + /// + /// It will then apply any secondary args from the environment variables and options. + /// + /// If any primary are present, then all environment variables are ignored. + /// + /// See the documentation + /// for more information. + /// + /// Options used to build the . + /// + /// A class that can be used to connect to a EdgeDB instance. + /// + /// + /// An error occured while parsing or configuring the . + /// + public static EdgeDBConnection Create(Options? options = null) + { + return _Create(options ?? new(), null); + } + + internal static EdgeDBConnection _Create(Options options, ISystemProvider? platform) + { + platform ??= ConfigUtils.DefaultPlatformProvider; + + ConfigUtils.ResolvedFields resolvedFields = new(); + + #region Primary Options + + // First, check primary options + // If any primary options are present, environment variables are ignored + + bool hasPrimaryOptions = false; + { + // These options can set host/port and should be resolved first + // More than one primary options should raise an error + + Exception primaryError = new ConfigurationException( + "Connection options cannot have more than one of the following " + + "values: \"Instance\", \"Dsn\", \"Credentials\", " + + "\"CredentialsFile\" or \"Host\"/\"Port\""); + // The primaryError has priority, so hold on to any other exception + // until all primary options are processed. + Exception? deferredPrimaryError = null; + + if (options.Instance is not null) + { + if (hasPrimaryOptions) { throw primaryError; } + if (options.Instance != "") + { + try + { + var fromDSN = _FromInstanceName(options.Instance, null, platform); + resolvedFields.MergeFrom(fromDSN); + } + catch (Exception e) + { + deferredPrimaryError = e; + } + } + else + { + deferredPrimaryError = new ConfigurationException( + $"Invalid instance name: \"{options.Instance}\""); + } + hasPrimaryOptions = true; + } + + if (options.Dsn is not null) + { + if (hasPrimaryOptions) { throw primaryError; } + var fromDSN = _FromDSN(options.Dsn, platform); + resolvedFields.MergeFrom(fromDSN); + hasPrimaryOptions = true; + } + + { + string? credentialsText = null; + if (options.Credentials is not null) + { + if (hasPrimaryOptions) { throw primaryError; } + credentialsText = options.Credentials; + hasPrimaryOptions = true; + } + if (options.CredentialsFile is not null) + { + if (hasPrimaryOptions) { throw primaryError; } + if (platform.FileExists(options.CredentialsFile)) + { + credentialsText = platform.FileReadAllText(options.CredentialsFile) ?? "{}"; + } + else + { + deferredPrimaryError = new ConfigurationException( + $"Invalid CredentialsFile: \"{options.CredentialsFile}\", could not find file"); + } + hasPrimaryOptions = true; + } + if (credentialsText is not null) + { + try + { + ConnectionCredentials? credentials = + new JsonSerializer().DeserializeObject(credentialsText); + if (credentials is not null) + { + resolvedFields.MergeFrom(ConfigUtils.ResolvedFields.FromCredentials(credentials)); + } + } + catch (JsonException) + { + deferredPrimaryError = new ConfigurationException("Invalid Credentials: could not parse json"); + } + } + } + + { + bool hasHostOrPort = false; + if (options.Host is not null) + { + if (hasPrimaryOptions) { throw primaryError; } + resolvedFields.Host = options.Host; + hasHostOrPort = true; + } + if (options.Port is not null) + { + if (hasPrimaryOptions) { throw primaryError; } + resolvedFields.Port = options.Port; + hasHostOrPort = true; + } + if (hasHostOrPort) + { + hasPrimaryOptions = true; + } + } + + if (deferredPrimaryError is not null) + { + throw deferredPrimaryError; + } + } + + #endregion + + #region Primary Env + + var envName = string.Empty; + var envVar = string.Empty; + + bool hasPrimaryEnv = false; + + if (!hasPrimaryOptions) + { + // These env vars can set host/port and should be resolved first + // More than one primary env var should raise an error + + Exception primaryError = new ConfigurationException( + "Cannot have more than one of the following connection " + + "environment variables: \"GEL_DSN\", \"GEL_INSTANCE\", " + + "\"GEL_CREDENTIALS_FILE\" or \"GEL_HOST\"/\"GEL_PORT\""); + // The primaryError has priority, so hold on to any other exception + // until all primary env vars are processed. + Exception? deferredPrimaryError = null; + + if (platform.GetGelEnvVariable(INSTANCE_ENV_NAME, out envName, out envVar)) + { + if (hasPrimaryEnv) { throw primaryError; } + try + { + var fromInst = _FromInstanceName(envVar, null, platform); + resolvedFields.MergeFrom(fromInst); + } + catch (Exception e) + { + deferredPrimaryError = e; + } + hasPrimaryEnv = true; + } + + if (platform.GetGelEnvVariable(DSN_ENV_NAME, out envName, out envVar)) + { + if (hasPrimaryEnv) { throw primaryError; } + var fromDSN = _FromDSN(envVar, platform); + resolvedFields.MergeFrom(fromDSN); + hasPrimaryEnv = true; + } + + if (platform.GetGelEnvVariable(CREDENTIALS_FILE_ENV_NAME, out envName, out envVar)) + { + if (hasPrimaryEnv) { throw primaryError; } + if (platform.FileExists(envVar)) + { + var credentials = + JsonConvert.DeserializeObject(platform.FileReadAllText(envVar))!; + resolvedFields.MergeFrom(ConfigUtils.ResolvedFields.FromCredentials(credentials)); + } + else + { + deferredPrimaryError = new FileNotFoundException( + $"Invalid credential file from {envName}: \"{envVar}\", could not find file"); + } + hasPrimaryEnv = true; + } + + { + bool hasHostOrPort = false; + if (platform.GetGelEnvVariable(HOST_ENV_NAME, out envName, out envVar)) + { + if (hasPrimaryEnv) { throw primaryError; } + resolvedFields.Host = envVar; + hasHostOrPort = true; + } + if (platform.GetGelEnvVariable(PORT_ENV_NAME, out envName, out envVar)) + { + ConfigUtils.ResolvedField? port = ConfigUtils.ParsePort(envVar); + if (port is not null) + { + if (hasPrimaryEnv) { throw primaryError; } + resolvedFields.Port = ConfigUtils.MergeField(resolvedFields.Port, port); + hasHostOrPort = true; + } + } + if (hasHostOrPort) + { + hasPrimaryEnv = true; + } + } + + if (deferredPrimaryError is not null) + { + throw deferredPrimaryError; + } + } + + #endregion + + #region Toml File + + if (!hasPrimaryOptions && !hasPrimaryEnv) + { + ConfigUtils.ResolvedFields? fromToml = _ResolveEdgeDBTOML(platform); + if (fromToml is not null) + { + resolvedFields.MergeFrom(fromToml); + } + } + + #endregion + + #region Secondary Env + + if (!hasPrimaryOptions) + { + if (platform.GetGelEnvVariable(DATABASE_ENV_NAME, out envName, out envVar)) + { + var altName = string.Empty; + var altVal = string.Empty; + if (platform.GetGelEnvVariable(BRANCH_ENV_NAME, out altName, out altVal)) + { + throw new ConfigurationException( + $"Environment variables {envName} and {altName} are mutually exclusive"); + } + + resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.DatabaseName(envVar); + } + + if (platform.GetGelEnvVariable(BRANCH_ENV_NAME, out envName, out envVar)) + { + resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.BranchName(envVar); + } + + if (platform.GetGelEnvVariable(USER_ENV_NAME, out envName, out envVar)) + { + resolvedFields.User = envVar; + } + + if (platform.GetGelEnvVariable(PASSWORD_ENV_NAME, out envName, out envVar)) + { + resolvedFields.Password = envVar; + } + + if (platform.GetGelEnvVariable(TLS_CA_ENV_NAME, out envName, out envVar)) + { + resolvedFields.TLSCertificateAuthority = envVar; + } + + { + string clientSecurityEnvName; + string clientTlsSecurityEnvName; + TLSSecurityMode? clientSecurity = null; + TLSSecurityMode? clientTlsSecurity = null; + bool hasDefault = false; + if (platform.GetGelEnvVariable(CLIENT_SECURITY_ENV_NAME, out clientSecurityEnvName, out envVar)) + { + if (TLSSecurityModeParser.TryParse(envVar, true, out clientSecurity)) + { + if (clientSecurity is not null) + { + resolvedFields.TLSSecurity = clientSecurity; + } + else + { + hasDefault = true; + } + } + else + { + resolvedFields.TLSSecurity = new ConfigurationException( + $"Invalid TLS Security from {clientSecurityEnvName}: \"{envVar}\""); + } + } + if (platform.GetGelEnvVariable(CLIENT_TLS_SECURITY_ENV_NAME, out clientTlsSecurityEnvName, out envVar)) + { + if (TLSSecurityModeParser.TryParse(envVar, true, out clientTlsSecurity)) + { + if (clientTlsSecurity is null) + { + hasDefault = true; + } + else if (clientSecurity is null) + { + // overwrite default value + resolvedFields.TLSSecurity = clientTlsSecurity.Value; + } + else if (clientSecurity == TLSSecurityMode.Strict + && clientTlsSecurity != TLSSecurityMode.Strict) + { + throw new ConfigurationException( + $"{clientSecurityEnvName}=strict but {clientTlsSecurityEnvName}={envVar}. " + + $"{clientTlsSecurityEnvName} must be strict when {clientSecurityEnvName} " + + $"is strict" + ); + } + else + { + // overwrite existing value + resolvedFields.TLSSecurity = clientTlsSecurity.Value; + } + } + else + { + resolvedFields.TLSSecurity = new ConfigurationException( + $"Invalid TLS Security from {clientTlsSecurityEnvName}: \"{envVar}\""); + } + } + if (hasDefault) + { + // finally, apply default value if no non-default value or error present + resolvedFields.TLSSecurity ??= TLSSecurityMode.Default; + } + } + + if (platform.GetGelEnvVariable(TLS_SERVER_NAME_ENV_NAME, out envName, out envVar)) + { + resolvedFields.TLSServerName = envVar; + } + + if (platform.GetGelEnvVariable(WAIT_UNTIL_AVAILABLE_ENV_NAME, out envName, out envVar)) + { + resolvedFields.WaitUntilAvailable = ConfigUtils.ParseWaitUntilAvailable(envVar); + } + } + + #endregion + + #region Secondary Options + + // Finally, check secondary options + // Secondary options should override environment variables + + if (options.Database is not null && options.Branch is not null) + { + throw new ConfigurationException("Invalid options: Database and Branch are mutually exclusive."); + } + else if (options.Database is not null) + { + resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.DatabaseName(options.Database); + } + else if (options.Branch is not null) + { + resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.BranchName(options.Branch); + } + + if (options.User is not null) { resolvedFields.User = options.User; } + if (options.Password is not null) { resolvedFields.Password = options.Password; } + if (options.SecretKey is not null) { resolvedFields.SecretKey = options.SecretKey; } + if (options.TLSCertificateAuthority is not null) + { + resolvedFields.TLSCertificateAuthority = options.TLSCertificateAuthority; + } + if (options.TLSCertificateAuthorityFile is not null) + { + if (platform.FileExists(options.TLSCertificateAuthorityFile)) + { + resolvedFields.TLSCertificateAuthority = + platform.FileReadAllText(options.TLSCertificateAuthorityFile); + } + else + { + throw new ConfigurationException( + $"Invalid TLSCertificateAuthorityFile: \"{options.TLSCertificateAuthorityFile}\", could not find file"); + } + } + if (options.TLSSecurity is not null) { resolvedFields.TLSSecurity = options.TLSSecurity; } + if (options.TLSServerName is not null) { resolvedFields.TLSServerName = options.TLSServerName; } + if (options.WaitUntilAvailable is not null) + { + resolvedFields.WaitUntilAvailable = ConfigUtils.MergeField( + resolvedFields.WaitUntilAvailable, + ConfigUtils.ParseWaitUntilAvailable(options.WaitUntilAvailable)); + } + if (options.ServerSettings is not null) + { + foreach (KeyValuePair entry in options.ServerSettings) + { + resolvedFields.ServerSettings = ConfigUtils.AddServerSettingField( + resolvedFields.ServerSettings, entry.Key, entry.Value); + } + } + + #endregion + + if (options.IsEmpty && resolvedFields.IsEmpty) + { + throw new ConfigurationException("No `gel.toml` found and no connection options specified."); + } + + return _FromResolvedFields(resolvedFields, platform); + } + + #endregion + + #region Create Helpers internal static EdgeDBConnection _FromResolvedFields(ConfigUtils.ResolvedFields resolvedFields, ISystemProvider? platform) { platform ??= ConfigUtils.DefaultPlatformProvider; - if (resolvedFields.Host?.Value is not null && resolvedFields.Host.Value.Contains(',')) + if (ConfigUtils.TryGetFieldValue(resolvedFields.Host, out string host)) { - throw new ConfigurationException( - $"Invalid host: \"{resolvedFields.Host.Value}\", DSN cannot contain more than one host"); + if (host.Contains(',')) + { + throw new ConfigurationException( + $"Invalid host: \"{host}\", DSN cannot contain more than one host"); + } + if (host == "") + { + throw new ConfigurationException($"Invalid host: \"{host}\""); + } + if (host.StartsWith("/")) + { + throw new ConfigurationException($"Invalid host: \"{host}\", unix socket paths not supported"); + } } - if (resolvedFields.DatabaseOrBranch?.Value?.Value == "") + if (ConfigUtils.TryGetFieldValue(resolvedFields.Port, out int port)) { - throw resolvedFields.DatabaseOrBranch.Value switch + if (port < 1 || 65535 < port) { - ConfigUtils.DatabaseOrBranch.DatabaseName name => new ConfigurationException( - $"Invalid database name: \"{name.Value}\""), - ConfigUtils.DatabaseOrBranch.BranchName name => new ConfigurationException( - $"Invalid branch name: \"{name.Value}\""), - _ => new ConfigurationException("Invalid database or branch name"), - }; + throw new ConfigurationException($"Invalid port: \"{port}\", must be between 1 and 65535"); + } } - if (resolvedFields.User == "") + if (ConfigUtils.TryGetFieldValue(resolvedFields.DatabaseOrBranch, out ConfigUtils.DatabaseOrBranch databaseOrBranch)) { - throw new ConfigurationException($"Invalid user: \"{resolvedFields.User.Value}\""); + if (databaseOrBranch.Value == "") + { + throw databaseOrBranch switch + { + ConfigUtils.DatabaseOrBranch.DatabaseName name => new ConfigurationException( + $"Invalid database name: \"{name.Value}\""), + ConfigUtils.DatabaseOrBranch.BranchName name => new ConfigurationException( + $"Invalid branch name: \"{name.Value}\""), + _ => new ConfigurationException("Invalid database or branch name"), + }; + } + } + if (ConfigUtils.TryGetFieldValue(resolvedFields.User, out string user)) + { + if (user == "") + { + throw new ConfigurationException($"Invalid user: \"{user}\""); + } } return new() @@ -525,10 +1075,9 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider? string key = arg.Key; ConfigUtils.ResolvedField value = arg.Value; - if (key.EndsWith("_env")) + if (key.EndsWith("_env") && ConfigUtils.TryGetFieldValue(value, out string envName)) { string oldKey = key; - string envName = value.Value!; key = key.Substring(0, key.Length - "_env".Length); string? envVar = platform.GetEnvVariable(envName); if (envVar is not null) @@ -542,10 +1091,9 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider? } } - if (key.EndsWith("_file") && value.Value is not null) + if (key.EndsWith("_file") && ConfigUtils.TryGetFieldValue(value, out string fileName)) { string oldKey = key; - string fileName = value.Value!; key = key.Substring(0, key.Length - "_file".Length); if (platform.FileExists(fileName)) { @@ -635,17 +1183,7 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider? }); break; case "wait_until_available": - resolvedFields.WaitUntilAvailable = value.Convert(v => - { - try - { - return ConfigUtils.ParseWaitUntilAvailable(v); - } - catch (Exception e) - { - return e; - } - }); + resolvedFields.WaitUntilAvailable = value.Convert(ConfigUtils.ParseWaitUntilAvailable); break; default: @@ -667,10 +1205,10 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider? /// The project directory doesn't exist for the supplied toml file. public static EdgeDBConnection FromProjectFile(string path) { - return _FromProjectFile(path, null); + return _FromResolvedFields(_FromProjectFile(path, null), null); } - internal static EdgeDBConnection _FromProjectFile(string path, ISystemProvider? platform) + internal static ConfigUtils.ResolvedFields _FromProjectFile(string path, ISystemProvider? platform) { platform ??= ConfigUtils.DefaultPlatformProvider; @@ -690,12 +1228,14 @@ internal static EdgeDBConnection _FromProjectFile(string path, ISystemProvider? if (!ConfigUtils.TryResolveInstanceCloudProfile(projectDir, out var profile, out var inst, platform) || inst is null) throw new FileNotFoundException($"Could not find instance name under project directory {projectDir}"); - var connection = _FromInstanceName(inst, profile, platform); + var resolvedFields = _FromInstanceName(inst, profile, platform); - if (ConfigUtils.TryResolveProjectDatabase(projectDir, out var database, platform)) - connection.Database = database; + if (ConfigUtils.TryResolveProjectDatabase(projectDir, out var database, platform) && database is not null) + { + resolvedFields.DatabaseOrBranch = new ConfigUtils.DatabaseOrBranch.DatabaseName(database); + } - return connection; + return resolvedFields; } /// @@ -714,10 +1254,10 @@ internal static EdgeDBConnection _FromProjectFile(string path, ISystemProvider? /// The configuration is invalid. public static EdgeDBConnection FromInstanceName(string name, string? cloudProfile = null) { - return _FromInstanceName(name, cloudProfile, null); + return _FromResolvedFields(_FromInstanceName(name, cloudProfile, null), null); } - internal static EdgeDBConnection _FromInstanceName(string name, string? cloudProfile, ISystemProvider? platform) + internal static ConfigUtils.ResolvedFields _FromInstanceName(string name, string? cloudProfile, ISystemProvider? platform) { platform ??= ConfigUtils.DefaultPlatformProvider; @@ -725,16 +1265,20 @@ internal static EdgeDBConnection _FromInstanceName(string name, string? cloudPro { var configPath = platform.CombinePaths(ConfigUtils.GetCredentialsDir(platform), $"{name}.json"); - return !platform.FileExists(configPath) - ? throw new FileNotFoundException($"Config file couldn't be found at {configPath}") - : JsonConvert.DeserializeObject(platform.FileReadAllText(configPath))!; + if (!platform.FileExists(configPath)) + { + throw new FileNotFoundException($"Config file couldn't be found at {configPath}"); + } + + ConnectionCredentials credentials = JsonConvert.DeserializeObject( + platform.FileReadAllText(configPath))!; + + return ConfigUtils.ResolvedFields.FromCredentials(credentials); } if (Regex.IsMatch(name, @"^([A-Za-z0-9](-?[A-Za-z0-9])*)\/([A-Za-z0-9](-?[A-Za-z0-9])*)$")) { - var conn = new EdgeDBConnection(); - conn.ParseCloudInstanceName(name, cloudProfile, platform); - return conn; + return ParseCloudInstanceName(name, null, cloudProfile, platform); } throw new ConfigurationException($"Invalid instance name '{name}'"); @@ -748,10 +1292,15 @@ internal static EdgeDBConnection _FromInstanceName(string name, string? cloudPro /// No 'edgedb.toml' file could be found. public static EdgeDBConnection ResolveEdgeDBTOML() { - return _ResolveEdgeDBTOML(null); + ConfigUtils.ResolvedFields? resolvedFields = _ResolveEdgeDBTOML(null); + if (resolvedFields is null) + { + throw new ConfigurationException("Couldn't resolve gel.toml file"); + } + return _FromResolvedFields(resolvedFields, null); } - internal static EdgeDBConnection _ResolveEdgeDBTOML(ISystemProvider? platform) + internal static ConfigUtils.ResolvedFields? _ResolveEdgeDBTOML(ISystemProvider? platform) { platform ??= ConfigUtils.DefaultPlatformProvider; @@ -759,30 +1308,32 @@ internal static EdgeDBConnection _ResolveEdgeDBTOML(ISystemProvider? platform) while (true) { + if (platform.FileExists(platform.CombinePaths(dir!, "gel.toml"))) + return _FromProjectFile(platform.CombinePaths(dir!, "gel.toml"), platform); + if (platform.FileExists(platform.CombinePaths(dir!, "edgedb.toml"))) return _FromProjectFile(platform.CombinePaths(dir!, "edgedb.toml"), platform); var parent = platform.DirectoryGetParent(dir!); if (parent is null || !parent.Exists) - throw new FileNotFoundException("Couldn't resolve edgedb.toml file"); + return null; dir = parent.FullName; } } - private void ParseCloudInstanceName(string name, string? cloudProfile, ISystemProvider? platform) + private static ConfigUtils.ResolvedFields ParseCloudInstanceName( + string name, string? secretKey, string? cloudProfile, ISystemProvider? platform) { if (name.Length > DOMAIN_NAME_MAX_LEN) { throw new ConfigurationException($"Cloud instance name must be {DOMAIN_NAME_MAX_LEN} characters or less"); } - var secretKey = SecretKey; - if (secretKey is null) { - var profile = ConfigUtils.ReadCloudProfile(cloudProfile ?? CloudProfile, platform); + var profile = ConfigUtils.ReadCloudProfile(cloudProfile ?? _defaultCloudProfile, platform); if (profile.SecretKey is null) { @@ -821,8 +1372,11 @@ private void ParseCloudInstanceName(string name, string? cloudProfile, ISystemPr spl = name.Split("/"); - Hostname = $"{spl[1]}--{spl[0]}.c-{dnsBucket}.i.{dnsZone}"; - SecretKey ??= secretKey; + return new() + { + Host = $"{spl[1]}--{spl[0]}.c-{dnsBucket}.i.{dnsZone}", + SecretKey = secretKey, + }; } /// @@ -859,17 +1413,14 @@ internal static EdgeDBConnection _Parse(string? instance, string? dsn, if (autoResolve && !((instance is not null && instance.Contains('/')) || (dsn is not null && !dsn.StartsWith("edgedb://") && !dsn.StartsWith("gel://")))) { - try - { - connection = _ResolveEdgeDBTOML(platform); - } - catch (FileNotFoundException) + ConfigUtils.ResolvedFields? resolvedFields = _ResolveEdgeDBTOML(platform); + if (resolvedFields is not null) { - // ignore + connection = _FromResolvedFields(resolvedFields, platform); } } - #region Env + #region Old Env var envName = string.Empty; var envVar = string.Empty; @@ -888,7 +1439,7 @@ internal static EdgeDBConnection _Parse(string? instance, string? dsn, if (platform.GetGelEnvVariable(INSTANCE_ENV_NAME, out envName, out envVar)) { - var fromInst = _FromInstanceName(envVar, null, platform); + var fromInst = _FromResolvedFields(_FromInstanceName(envVar, null, platform), platform); connection = connection?.MergeInto(fromInst) ?? fromInst; } @@ -1034,7 +1585,7 @@ internal static EdgeDBConnection _Parse(string? instance, string? dsn, if (instance is not null) { - var fromInst = _FromInstanceName(instance, null, platform); + var fromInst = _FromResolvedFields(_FromInstanceName(instance, null, platform), platform); connection = connection?.MergeInto(fromInst) ?? fromInst; } @@ -1043,8 +1594,8 @@ internal static EdgeDBConnection _Parse(string? instance, string? dsn, if (Regex.IsMatch(dsn, @"^([A-Za-z0-9](-?[A-Za-z0-9])*)\/([A-Za-z0-9](-?[A-Za-z0-9])*)$")) { // cloud - connection ??= new EdgeDBConnection(); - connection.ParseCloudInstanceName(dsn, null, platform); + var fromCloud = _FromResolvedFields(ParseCloudInstanceName(dsn, connection?.SecretKey, null, platform), platform); + connection = connection?.MergeInto(fromCloud) ?? fromCloud; } else { diff --git a/src/EdgeDB.Net.Driver/Utils/ConfigUtils.cs b/src/EdgeDB.Net.Driver/Utils/ConfigUtils.cs index 81045593..f65995d1 100644 --- a/src/EdgeDB.Net.Driver/Utils/ConfigUtils.cs +++ b/src/EdgeDB.Net.Driver/Utils/ConfigUtils.cs @@ -224,7 +224,21 @@ internal T CheckAndGetValue() } } - static ResolvedField? MergeField(ResolvedField? to, ResolvedField? from) + internal static bool TryGetFieldValue(ResolvedField? field, out T value) + { + if (field is ResolvedField.Valid) + { + value = field.Value!; + return true; + } + else + { + value = default(T)!; + return false; + } + } + + internal static ResolvedField? MergeField(ResolvedField? to, ResolvedField? from) { if (to is null) { @@ -312,6 +326,31 @@ Host is null && TLSServerName is null && WaitUntilAvailable is null && ServerSettings.Count == 0; + + internal static ResolvedFields FromCredentials(ConnectionCredentials credentials) + { + ResolvedFields result = new(); + + if (credentials.Host is not null) { result.Host = credentials.Host; } + if (credentials.Port is not null) + { + result.Port = MergeField(result.Port, ParsePort(credentials.Port)); + } + if (credentials.Database is not null) + { + result.DatabaseOrBranch = new DatabaseOrBranch.DatabaseName(credentials.Database); + } + if (credentials.Branch is not null) + { + result.DatabaseOrBranch = new DatabaseOrBranch.BranchName(credentials.Branch); + } + if (credentials.User is not null) { result.User = credentials.User; } + if (credentials.Password is not null) { result.Password = credentials.Password; } + if (credentials.TlsCA is not null) { result.TLSCertificateAuthority = credentials.TlsCA; } + if (credentials.TlsSecurity is not null) { result.TLSSecurity = credentials.TlsSecurity; } + + return result; + } } #endregion diff --git a/src/EdgeDB.Net.Driver/Utils/JsonUtils.cs b/src/EdgeDB.Net.Driver/Utils/JsonUtils.cs new file mode 100644 index 00000000..f06ffd9a --- /dev/null +++ b/src/EdgeDB.Net.Driver/Utils/JsonUtils.cs @@ -0,0 +1,37 @@ +using EdgeDB.DataTypes; +using Newtonsoft.Json; + +namespace EdgeDB.Utils; + +internal class AsStringConverter : JsonConverter +{ + // Always reads numbers as strings + + public override string? ReadJson( + JsonReader reader, + Type objectType, + string? existingValue, + bool hasExistingValue, + JsonSerializer serializer) + { + if (reader.TokenType == JsonToken.Integer) + { + return reader.Value!.ToString(); + } + else if (reader.TokenType == JsonToken.Float) + { + return reader.Value!.ToString(); + } + else if (reader.TokenType == JsonToken.String) + { + return (string)reader.Value!; + } + throw new JsonException("Expected Number or String."); + } + + public override void WriteJson( + JsonWriter writer, string? value, JsonSerializer serializer) + { + throw new NotImplementedException(); + } +} diff --git a/tests/EdgeDB.Tests.Unit/EdgeDB.Tests.Unit.csproj b/tests/EdgeDB.Tests.Unit/EdgeDB.Tests.Unit.csproj index 9e889258..1eedf3a7 100644 --- a/tests/EdgeDB.Tests.Unit/EdgeDB.Tests.Unit.csproj +++ b/tests/EdgeDB.Tests.Unit/EdgeDB.Tests.Unit.csproj @@ -20,4 +20,10 @@ + + + PreserveNewest + + + diff --git a/tests/EdgeDB.Tests.Unit/SharedClientTests.cs b/tests/EdgeDB.Tests.Unit/SharedClientTests.cs new file mode 100644 index 00000000..b643aa3c --- /dev/null +++ b/tests/EdgeDB.Tests.Unit/SharedClientTests.cs @@ -0,0 +1,729 @@ +using EdgeDB.Abstractions; +using EdgeDB.Utils; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using System.Text; +using System.Text.RegularExpressions; + +namespace EdgeDB.Tests.Unit; + +[TestClass] +public class SharedClientTests +{ + [TestMethod] + public void TestConnectParams() + { + StreamReader reader = new("shared-client-testcases/connection_testcases.json"); + List? testcases = JsonConvert.DeserializeObject>(reader.ReadToEnd()); + if (testcases is null) + { + throw new JsonException("Failed to read 'connection_testcases.json.\n" + + "Is the 'shared-client-testcases' submodule initialised? " + + "Try running 'git submodule update --init'."); + } + + foreach ((int textIndex, TestCase testCase) in testcases.Select((x, i) => (i, x))) + { + if (testCase.FileSystem is not null + && ( + !(testCase.Platform is null && RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) || + !(testCase.Platform == "windows" && !RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) || + !(testCase.Platform == "macos" && RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + )) + { + // skipping unsupported platform test + continue; + } + + if ((testCase.Result is null) == (testCase.Error is null)) + { + throw new Exception("invalid test case: either \"result\" or \"error\" key has to be specified"); + } + + TestResult result = ParseConnection(testCase); + + if (testCase.Result is not null) + { + AssertSameConnection(result, testCase.Result); + } + else if (testCase.Error is not null) + { + AssertSameException(result, testCase.Error); + } + } + } + + private class TestResult + { + public EdgeDBConnection? Connection { get; init; } + public Exception? Exception { get; init; } + + public static implicit operator TestResult(EdgeDBConnection c) => new() {Connection = c}; + public static implicit operator TestResult(Exception x) => new() {Exception = x}; + } + + private static TestResult ParseConnection(TestCase testCase) + { + try + { + ISystemProvider mockSystem = new MockSystemProvider(testCase); + + EdgeDBConnection.Options config = new() + { + Instance = testCase?.Options?.Instance, + Dsn = testCase?.Options?.Dsn, + Host = testCase?.Options?.Host, + Port = ( + testCase?.Options?.Port is null + ? null + : int.TryParse(testCase?.Options?.Port, out var parsedPort) + ? parsedPort + : throw new ConfigurationException( + $"Invalid port: {testCase?.Options?.Port}, not an integer") + ), + Database = testCase?.Options?.Database, + Branch = testCase?.Options?.Branch, + User = testCase?.Options?.User, + Password = testCase?.Options?.Password, + SecretKey = testCase?.Options?.SecretKey, + Credentials = testCase?.Options?.Credentials, + CredentialsFile = testCase?.Options?.CredentialsFile, + TLSCertificateAuthority = testCase?.Options?.TlsCA, + TLSCertificateAuthorityFile = testCase?.Options?.TlsCAFile, + TLSSecurity = ( + testCase?.Options?.TlsSecurity is null + ? null + : TLSSecurityModeParser.Parse(testCase.Options.TlsSecurity) + ), + TLSServerName = testCase?.Options?.TlsServerName, + WaitUntilAvailable = testCase?.Options?.WaitUntilAvailable, + ServerSettings = testCase?.Options?.ServerSettings, + }; + + EdgeDBConnection connection = EdgeDBConnection._Create(config, mockSystem); + + return connection; + } + catch (Exception x) + { + return x; + } + } + + #region Test Assertions + + private static void AssertSameConnection(TestResult result, TestCase.ExpectedResult expectedResult) + { + string expectedHostname = expectedResult.Address is not null + ? expectedResult.Address[0] + : "localhost"; + int expectedPort = expectedResult.Address is not null + ? int.Parse(expectedResult.Address[1]) + : 5656; + string expectedDatabase = expectedResult.Database; + string expectedBranch = expectedResult.Branch; + string expectedUsername = expectedResult.User; + string expectedPassword = expectedResult.Password ?? ""; + string? expectedSecretKey = expectedResult.SecretKey; + string? expectedTLSCertificateAuthority = expectedResult.TlsCAData; + TLSSecurityMode expectedTLSSecurity = expectedResult.TlsSecurity ?? TLSSecurityMode.Strict; + string? expectedTLSServerName = expectedResult.TlsServerName; + int expectedWaitUntilAvailable = + expectedResult.WaitUntilAvailable is not null + && ConfigUtils.TryGetFieldValue( + ConfigUtils.ParseWaitUntilAvailable(expectedResult.WaitUntilAvailable), + out int timeout) + ? timeout + : 30000; + + Assert.IsNull(result.Exception, $"\"{result.Exception?.Message}\"\n{result.Exception?.StackTrace}"); + Assert.IsNotNull(result.Connection); + EdgeDBConnection actual = result.Connection; + + Assert.AreEqual(expectedHostname, actual.Hostname); + Assert.AreEqual(expectedPort, actual.Port); + Assert.AreEqual(expectedDatabase, actual.Database); + Assert.AreEqual(expectedBranch, actual.Branch); + Assert.AreEqual(expectedUsername, actual.Username); + Assert.AreEqual(expectedPassword, actual.Password); + Assert.AreEqual(expectedSecretKey, actual.SecretKey); + Assert.AreEqual(expectedTLSCertificateAuthority, actual.TLSCertificateAuthority); + Assert.AreEqual(expectedTLSSecurity, actual.TLSSecurity); + Assert.AreEqual(expectedTLSServerName, actual.TLSServerName); + Assert.AreEqual(expectedWaitUntilAvailable, actual.WaitUntilAvailable); + CollectionAssert.AreEqual(expectedResult.ServerSettings, actual.ServerSettings); + } + + private static void AssertSameException(TestResult result, TestCase.ExpectedError expectedError) + { + (Type expectedType, Regex expectedRegex) = _errorMapping[expectedError.Type]; + + Assert.IsNull(result.Connection); + Assert.IsNotNull(result.Exception); + Exception actual = result.Exception; + + Assert.IsInstanceOfType( + actual, + expectedType, + $"Exception type {expectedType} expected but got {result.Exception.GetType()}\n" + + $"{result.Exception.Message}\n" + + $"{result.Exception?.StackTrace}"); + Assert.IsTrue( + expectedRegex.Match(actual.Message).Success, + $"Exception message \"{actual.Message}\" does not match pattern \"{expectedRegex}\"\n" + + $"{result.Exception?.StackTrace}"); + } + + private static readonly Dictionary _errorMapping = new() + { + { + "credentials_file_not_found", + ( + typeof(ConfigurationException), + new("cannot read credentials", RegexOptions.Compiled) + ) + }, + { + "project_not_initialised", + ( + typeof(ConfigurationException), + new("Found `\\w+.toml` but the project is not initialized", RegexOptions.Compiled) + ) + }, + { + "no_options_or_toml", + ( + typeof(ConfigurationException), + new("No `gel.toml` found and no connection options specified", RegexOptions.Compiled) + ) + }, + { + "invalid_credentials_file", + ( + typeof(ConfigurationException), + new("Invalid CredentialsFile", RegexOptions.Compiled) + ) + }, + { + "invalid_dsn_or_instance_name", + ( + typeof(ConfigurationException), + new("Invalid (?:DSN|instance name)", RegexOptions.Compiled) + ) + }, + { + "invalid_instance_name", + ( + typeof(ConfigurationException), + new("invalid instance name", RegexOptions.Compiled) + ) + }, + { + "invalid_dsn", + ( + typeof(ConfigurationException), + new("Invalid DSN", RegexOptions.Compiled) + ) + }, + { + "unix_socket_unsupported", + ( + typeof(ConfigurationException), + new("unix socket paths not supported", RegexOptions.Compiled) + ) + }, + { + "invalid_host", + ( + typeof(ConfigurationException), + new("Invalid host", RegexOptions.Compiled) + ) + }, + { + "invalid_port", + ( + typeof(ConfigurationException), + new("Invalid port", RegexOptions.Compiled) + ) + }, + { + "invalid_user", + ( + typeof(ConfigurationException), + new("Invalid user", RegexOptions.Compiled) + ) + }, + { + "invalid_database", + ( + typeof(ConfigurationException), + new("Invalid database", RegexOptions.Compiled) + ) + }, + { + "multiple_compound_env", + ( + typeof(ConfigurationException), + new("Cannot have more than one of the following connection environment variables", RegexOptions.Compiled) + ) + }, + { + "multiple_compound_opts", + ( + typeof(ConfigurationException), + new("Connection options cannot have more than one of the following values", RegexOptions.Compiled) + ) + }, + { + "exclusive_options", + ( + typeof(ConfigurationException), + new("are mutually exclusive", RegexOptions.Compiled) + ) + }, + { + "env_not_found", + ( + typeof(ConfigurationException), + new("environment variable \".*\" doesn\'t exist", RegexOptions.Compiled) + ) + }, + { + "file_not_found", + ( + typeof(ConfigurationException), + new("could not find file", RegexOptions.Compiled) + ) + }, + { + "invalid_tls_security", + ( + typeof(ConfigurationException), + new( + "Invalid TLS Security|\\w+ must be strict when \\w+ is strict", + RegexOptions.Compiled) + ) + }, + { + "invalid_secret_key", + ( + typeof(ConfigurationException), + new("Invalid secret key", RegexOptions.Compiled) + ) + }, + { + "secret_key_not_found", + ( + typeof(ConfigurationException), + new("Cannot connect to cloud instances without secret key", RegexOptions.Compiled) + ) + }, + { + "docker_tcp_port", + ( + typeof(ConfigurationException), + new("\\w+_PORT in \"tcp://host:port\" format, so will be ignored", RegexOptions.Compiled) + ) + }, + { + "gel_and_edgedb", + ( + typeof(ConfigurationException), + new("Both GEL_\\w+ and EDGEDB_\\w+ are set; EDGEDB_\\w+ will be ignored", RegexOptions.Compiled) + ) + }, + }; + + #endregion + + #region MockSystemProvider + + class MockSystemProvider : BaseDefaultSystemProvider + { + private readonly string? _homeDir; + private readonly string? _currentDir; + private readonly Dictionary _envVars; + private Dictionary _files; + + public List Warnings { get; } = new(); + + public MockSystemProvider(TestCase testCase) + { + _homeDir = testCase.FileSystem?.HomeDir; + _currentDir = testCase.FileSystem?.CurrentDir; + _envVars = testCase.EnvVars ?? new(); + _files = CacheFiles(testCase?.FileSystem?.Files); + } + + private Dictionary CacheFiles(Dictionary? files) + { + return files?.SelectMany( + x => { + string path = x.Key; + TestCase.File file = x.Value; + + if (file.Contents is not null) + { + return new List<(string, string)>(){(path, file.Contents)}; + } + else + { + if (file.Fields is null) + { + throw new Exception("File must be either string or json object of fields"); + } + if (!file.Fields.ContainsKey("project-path")) + { + throw new Exception("File as object must have \"project-path\" field"); + } + + List<(string,string)> subfiles = new(); + + string dir = path.Replace("${HASH}", ProjectPathHash(file.Fields["project-path"])); + + foreach (KeyValuePair field in file.Fields) + { + subfiles.Add((CombinePaths(new string[]{ dir, field.Key }), field.Value)); + } + + return subfiles; + } + }) + .ToDictionary(x => x.Item1, x => x.Item2) + ?? new(); + } + + private string ProjectPathHash(string path) + { + if (IsOSPlatform(OSPlatform.Windows) && !path.StartsWith("\\\\")) + { + path = "\\\\?\\" + path; + } + + return Convert.ToHexString(SHA1.HashData(Encoding.UTF8.GetBytes(path))); + } + + public override string GetHomeDir() => _homeDir ?? base.GetHomeDir(); + + public override string GetCurrentDirectory() => _currentDir ?? base.GetCurrentDirectory(); + + public override string? GetEnvVariable(string name) + => _envVars.TryGetValue(name, out var val) + ? val + : null; + + public override bool FileExists(string path) => _files.ContainsKey(path); + + public override string FileReadAllText(string path) => _files[path]; + + public override void WriteWarning(string message) + => Warnings.Add(message); + } + + #endregion + + #region TestCase + + class TestCase + { + [JsonProperty("name")] + public string Name { get; init; } = string.Empty; + + [JsonProperty("opts")] + public OptionsData? Options { get; init; } + + [JsonProperty("env")] + public Dictionary? EnvVars { get; init; } + + [JsonProperty("platform")] + public string? Platform { get; init; } + + [JsonProperty("fs")] + public FileSystemData? FileSystem { get; init; } + + [JsonProperty("warnings")] + public List? Warnings { get; init; } + + [JsonProperty("result")] + public ExpectedResult? Result { get; init; } + + [JsonProperty("error")] + public ExpectedError? Error { get; init; } + + public class OptionsData + { + [JsonProperty("instance")] + public string? Instance { get; init; } + + [JsonProperty("dsn")] + public string? Dsn { get; init; } + + [JsonProperty("host")] + public string? Host { get; init; } + + [JsonProperty("port")] + [JsonConverter(typeof(AsStringConverter))] + public string? Port { get; init; } + + [JsonProperty("database")] + public string? Database { get; init; } + + [JsonProperty("branch")] + public string? Branch { get; init; } + + [JsonProperty("user")] + public string? User { get; init; } + + [JsonProperty("password")] + public string? Password { get; init; } + + [JsonProperty("secretKey")] + public string? SecretKey { get; init; } + + [JsonProperty("credentials")] + public string? Credentials { get; init; } + + [JsonProperty("credentialsFile")] + public string? CredentialsFile { get; init; } + + [JsonProperty("tlsCA")] + public string? TlsCA { get; init; } + + [JsonProperty("tlsCAFile")] + public string? TlsCAFile { get; init; } + + [JsonProperty("tlsSecurity")] + public string? TlsSecurity { get; init; } + + [JsonProperty("tlsServerName")] + public string? TlsServerName { get; init; } + + [JsonProperty("waitUntilAvailable")] + public string? WaitUntilAvailable { get; init; } + + [JsonProperty("serverSettings")] + public Dictionary? ServerSettings { get; init; } + } + + public class Credentials + { + [JsonProperty("host")] + public string? Host { get; init; } + + [JsonProperty("port")] + [JsonConverter(typeof(AsStringConverter))] + public string? Port { get; init; } + + [JsonProperty("database")] + public string? Database { get; init; } + + [JsonProperty("branch")] + public string? Branch { get; init; } + + [JsonProperty("user")] + public string? User { get; init; } + + [JsonProperty("password")] + public string? Password { get; init; } + + [JsonProperty("tls_ca")] + public string? TlsCA { get; init; } + + [JsonProperty("tls_security")] + [JsonConverter(typeof(TLSSecurityModeParser))] + public TLSSecurityMode? TlsSecurity { get; init; } + } + + public class FileSystemData + { + [JsonProperty("cwd")] + public string? CurrentDir { get; init; } + + [JsonProperty("homedir")] + public string? HomeDir { get; init; } + + [JsonProperty("files")] + public Dictionary? Files { get; init; } + } + + [JsonConverter(typeof(FileJsonConverter))] + public class File + { + // Has either string contents or has explicitly defined instance information + + // string contents + public string? Contents { get; init; } + + // instance information + public Dictionary? Fields { get; init; } + } + + public class ExpectedResult + { + [JsonProperty("address")] + [JsonConverter(typeof(AsListStringConverter))] + public List Address { get; init; } = new(); + + [JsonProperty("database")] + public string Database { get; init; } = string.Empty; + + [JsonProperty("branch")] + public string Branch { get; init; } = string.Empty; + + [JsonProperty("user")] + public string User { get; init; } = string.Empty; + + [JsonProperty("password")] + public string? Password { get; init; } + + [JsonProperty("secretKey")] + public string? SecretKey { get; init; } + + [JsonProperty("tlsCAData")] + public string? TlsCAData { get; init; } + + [JsonProperty("tlsSecurity")] + [JsonConverter(typeof(TLSSecurityModeParser))] + public TLSSecurityMode? TlsSecurity { get; init; } + + [JsonProperty("tlsServerName")] + public string? TlsServerName { get; init; } + + [JsonProperty("waitUntilAvailable")] + public string? WaitUntilAvailable { get; init; } + + [JsonProperty("serverSettings")] + public Dictionary? ServerSettings { get; init; } + } + + public class ExpectedError + { + [JsonProperty("type")] + public string Type { get; init; } = string.Empty; + } + + private class AsListStringConverter : JsonConverter> + { + // Always reads numbers as strings + + public override List? ReadJson( + JsonReader reader, + Type objectType, + List? existingValue, + bool hasExistingValue, + JsonSerializer serializer) + { + List result = new(); + + // skip JsonToken.StartArray + while (reader.Read()) + { + if (reader.TokenType == JsonToken.EndArray) + { + break; + } + + if (reader.TokenType == JsonToken.Integer) + { + result.Add(reader.Value!.ToString()!); + } + else if (reader.TokenType == JsonToken.Float) + { + result.Add(reader.Value!.ToString()!); + } + else if (reader.TokenType == JsonToken.String) + { + result.Add((string)reader.Value!); + } + else + { + throw new JsonException( + $"Invalid {reader.TokenType} token: \"{reader.Value}\", expected Number or String."); + } + } + + return result; + } + + public override void WriteJson( + JsonWriter writer, List? value, JsonSerializer serializer) + { + throw new NotImplementedException(); + } + } + + private class FileJsonConverter : JsonConverter + { + public override File? ReadJson( + JsonReader reader, + Type objectType, + File? existingValue, + bool hasExistingValue, + JsonSerializer serializer) + { + if (reader.TokenType == JsonToken.String) + { + // string contents + return new() + { + Contents = (string)reader.Value!, + }; + } + else if (reader.TokenType == JsonToken.StartObject) + { + // instance information + Dictionary fields = new(); + while (reader.Read()) + { + if (reader.TokenType == JsonToken.EndObject) + { + break; + } + + if (reader.TokenType != JsonToken.PropertyName) + { + throw new JsonException( + $"Invalid {reader.TokenType} token: \"{reader.Value}\", expected PropertyName."); + } + + string propertyName = (string)reader.Value!; + + reader.Read(); + if (reader.TokenType != JsonToken.String) + { + throw new JsonException( + $"Invalid {reader.TokenType} token: \"{reader.Value}\", expected String."); + } + + string value = (string)reader.Value!; + + fields[propertyName] = value; + } + + return new() + { + Fields = fields, + }; + } + else + { + throw new JsonException("Could not read File object."); + } + } + + public override void WriteJson( + JsonWriter writer, File? value, JsonSerializer serializer) + { + throw new NotImplementedException(); + } + } + } + + #endregion +} diff --git a/tests/EdgeDB.Tests.Unit/shared-client-testcases b/tests/EdgeDB.Tests.Unit/shared-client-testcases new file mode 160000 index 00000000..d720f509 --- /dev/null +++ b/tests/EdgeDB.Tests.Unit/shared-client-testcases @@ -0,0 +1 @@ +Subproject commit d720f509b352f2fe587b5ea6d0c624553b95e04b