diff --git a/src/SqlClient.DesignTime/DesignTime.fs b/src/SqlClient.DesignTime/DesignTime.fs index ed6f506d..e5d7dd18 100644 --- a/src/SqlClient.DesignTime/DesignTime.fs +++ b/src/SqlClient.DesignTime/DesignTime.fs @@ -639,7 +639,7 @@ type DesignTime private() = ] let body1 (args: _ list) = - Expr.NewObject(ctorImpl, designTimeConfig :: <@@ Connection.Choice1Of3 %%args.Head @@> :: args.Tail) + Expr.NewObject(ctorImpl, designTimeConfig :: <@@ Connection.ConnectionString %%args.Head @@> :: args.Tail) yield ProvidedConstructor(parameters1, invokeCode = body1) :> MemberInfo @@ -663,10 +663,12 @@ type DesignTime private() = let connArg = <@@ if box (%%args.[1]: SqlTransaction) <> null - then Connection.Choice3Of3 %%args.[1] + then Connection.SystemDataSqlClientTransaction %%args.[1] elif box (%%args.[0]: SqlConnection) <> null - then Connection.Choice2Of3 %%args.Head - else Connection.Choice1Of3( %%connectionStringExpr) + then Connection.SystemDataSqlClientConnection %%args.Head + elif box (%%connectionStringExpr: string) <> null + then Connection.ConnectionString( %%connectionStringExpr) + else failwithf "design time doesn't support IDb* constructors" @@> Expr.NewObject(ctorImpl, [ designTimeConfig ; connArg; args.[2] ]) diff --git a/src/SqlClient/DataTable.fs b/src/SqlClient/DataTable.fs index dda3f49c..26e8f0a9 100644 --- a/src/SqlClient/DataTable.fs +++ b/src/SqlClient/DataTable.fs @@ -8,7 +8,7 @@ open FSharp.Data.SqlClient.Internals [] [] -type DataTable<'T when 'T :> DataRow>(selectCommand: SqlCommand, ?connectionString: Lazy) = +type DataTable<'T when 'T :> DataRow>(selectCommand: IDbCommand, ?connectionString: Lazy) = inherit DataTable() let rows = base.Rows @@ -42,10 +42,12 @@ type DataTable<'T when 'T :> DataRow>(selectCommand: SqlCommand, ?connectionStri member this.Update(?connection, ?transaction, ?batchSize, ?continueUpdateOnError, ?timeout: TimeSpan) = // not supported on all DataTable instances - match selectCommand with - | null -> failwith "This command wasn't constructed from SqlProgrammabilityProvider, call to Update is not supported." - | _ -> () - + let selectCommand = + match selectCommand with + | null -> failwith "This command wasn't constructed from SqlProgrammabilityProvider, call to Update is not supported." + | :? SqlCommand as selectCommand -> selectCommand + | _ -> failwithf "This command has type %s, this is only supported for commands instanciated with System.Data.SqlClient db types." (selectCommand.GetType().FullName) + connection |> Option.iter selectCommand.set_Connection transaction |> Option.iter selectCommand.set_Transaction @@ -94,9 +96,12 @@ type DataTable<'T when 'T :> DataRow>(selectCommand: SqlCommand, ?connectionStri | _, Some(t: SqlTransaction) -> t.Connection, t | Some c, None -> c, null | None, None -> - match selectCommand with - | null -> failwith "To issue BulkCopy on this table, you need to provide your own connection or transaction" - | _ -> () + let selectCommand = + match selectCommand with + | null -> failwith "To issue BulkCopy on this table, you need to provide your own connection or transaction" + | :? SqlCommand as selectCommand -> selectCommand + | _ -> failwithf "This command has type %s, this is only supported for commands instanciated with System.Data.SqlClient db types." (selectCommand.GetType().FullName) + if this.IsDirectTable then assert(connectionString.IsSome) diff --git a/src/SqlClient/Extensions.fs b/src/SqlClient/Extensions.fs index 02b95566..9d342835 100644 --- a/src/SqlClient/Extensions.fs +++ b/src/SqlClient/Extensions.fs @@ -6,8 +6,8 @@ open System.Data.SqlClient [] module Extensions = - - type SqlDataReader with + + type IDataReader with member internal this.MapRowValues<'TItem>( rowMapping) = seq { use _ = this @@ -49,7 +49,7 @@ module Extensions = yield mapper cursor } - type SqlConnection with + type IDbConnection with //address an issue when regular Dispose on SqlConnection needed for async computation //wipes out all properties like ConnectionString in addition to closing connection to db @@ -63,7 +63,7 @@ module Extensions = member this.IsSqlAzure = assert (this.State = ConnectionState.Open) - use cmd = new SqlCommand("SELECT SERVERPROPERTY('edition')", this) + use cmd = this.CreateCommand(CommandText = "SELECT SERVERPROPERTY('edition')") cmd.ExecuteScalar().Equals("SQL Azure") diff --git a/src/SqlClient/ISqlCommand.fs b/src/SqlClient/ISqlCommand.fs index 62e7a773..060001e2 100644 --- a/src/SqlClient/ISqlCommand.fs +++ b/src/SqlClient/ISqlCommand.fs @@ -12,7 +12,7 @@ type ISqlCommand = abstract ToTraceString: parameters: (string * obj)[] -> string - abstract Raw: SqlCommand with get + abstract Raw: System.Data.IDbCommand with get namespace FSharp.Data.SqlClient.Internals @@ -49,27 +49,45 @@ type DesignTimeConfig = { ExpectedDataReaderColumns: (string * string)[] } -type internal Connection = Choice - +[] +[] +type Connection = + + | [] ConnectionString of string + | SystemDataSqlClientConnection of SqlConnection + | SystemDataSqlClientTransaction of SqlTransaction + | [] SystemDataIDbConnection of IDbConnection + | [] SystemDataIDbTransaction of IDbTransaction + +#nowarn "44" +#nowarn "57" [] type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connection, commandTimeout) = - let cmd = new SqlCommand(cfg.SqlStatement, CommandTimeout = commandTimeout) - let manageConnection = - match connection with - | Choice1Of3 connectionString -> - cmd.Connection <- new SqlConnection(connectionString) - true - | Choice2Of3 instance -> - cmd.Connection <- instance - false - | Choice3Of3 tran -> - cmd.Transaction <- tran - cmd.Connection <- tran.Connection - false + let manageConnection= + match connection with + | Connection.ConnectionString _ -> true + | Connection.SystemDataSqlClientConnection _ + | Connection.SystemDataSqlClientTransaction _ + | Connection.SystemDataIDbConnection _ + | Connection.SystemDataIDbTransaction _ -> false + let cmd : IDbCommand = + match connection with + | Connection.ConnectionString connectionString -> + let cnx = new SqlConnection(connectionString) + cnx.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout) + | Connection.SystemDataSqlClientConnection cnx -> + cnx.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout) + | Connection.SystemDataSqlClientTransaction trx -> + trx.Connection.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout, Transaction = trx) + | Connection.SystemDataIDbConnection cnx -> + cnx.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout) + | Connection.SystemDataIDbTransaction trx -> + trx.Connection.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout, Transaction = trx) do cmd.CommandType <- if cfg.IsStoredProcedure then CommandType.StoredProcedure else CommandType.Text - cmd.Parameters.AddRange( cfg.Parameters) + for parameter in cfg.Parameters do + cmd.Parameters.Add parameter |> ignore let getReaderBehavior() = seq { @@ -90,19 +108,19 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio match cfg.ResultType with | ResultType.DataReader -> ``ISqlCommand Implementation``.ExecuteReader >> box, - ``ISqlCommand Implementation``.AsyncExecuteReader >> box, + ``ISqlCommand Implementation``.AsyncExecuteReader connection >> box, notImplemented, notImplemented | ResultType.DataTable -> ``ISqlCommand Implementation``.ExecuteDataTable >> box, - ``ISqlCommand Implementation``.AsyncExecuteDataTable >> box, + ``ISqlCommand Implementation``.AsyncExecuteDataTable connection >> box, notImplemented, notImplemented | ResultType.Records | ResultType.Tuples -> match box cfg.RowMapping, cfg.ItemTypeName with | null, null -> ``ISqlCommand Implementation``.ExecuteNonQuery manageConnection >> box, - ``ISqlCommand Implementation``.AsyncExecuteNonQuery manageConnection >> box, + ``ISqlCommand Implementation``.AsyncExecuteNonQuery connection manageConnection >> box, notImplemented, notImplemented | rowMapping, itemTypeName -> @@ -215,40 +233,42 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio cmd.Connection.Dispose() cmd.Dispose() - static member internal SetParameters(cmd: SqlCommand, parameters: (string * obj)[]) = + static member internal SetParameters(cmd: IDbCommand, parameters: (string * obj)[]) = for name, value in parameters do - let p = cmd.Parameters.[name] - - if p.Direction.HasFlag(ParameterDirection.Input) - then + let p = cmd.Parameters.[name] :?> IDbDataParameter + let cmdIsSystemDataSqlClient = cmd :? SqlCommand + if p.Direction.HasFlag(ParameterDirection.Input) then match value with - | null -> - p.Value <- DBNull.Value + | null -> p.Value <- DBNull.Value | _ -> - match p.SqlDbType with - | SqlDbType.Structured -> - // TODO: Maybe make this lazy? + if cmdIsSystemDataSqlClient then + let p = p :?> SqlParameter + match p.SqlDbType with + | SqlDbType.Structured -> + // TODO: Maybe make this lazy? - //p.Value <- value |> unbox |> Seq.cast + //p.Value <- value |> unbox |> Seq.cast - //done via reflection because not implemented on Mono + //done via reflection because not implemented on Mono - let sqlDataRecordType = typeof.Assembly.GetType("Microsoft.SqlServer.Server.SqlDataRecord", throwOnError = true) - let records = typeof.GetMethod("Cast").MakeGenericMethod(sqlDataRecordType).Invoke(null, [| value |]) - let hasAny = - typeof - .GetMethods(BindingFlags.Static ||| BindingFlags.Public) - .First(fun m -> m.Name = "Any" && m.GetParameters().Count() = 1) - .MakeGenericMethod(sqlDataRecordType).Invoke(null, [| records |]) :?> bool - p.Value <- if not hasAny then null else records - | _ -> p.Value <- value - + let sqlDataRecordType = typeof.Assembly.GetType("Microsoft.SqlServer.Server.SqlDataRecord", throwOnError = true) + let records = typeof.GetMethod("Cast").MakeGenericMethod(sqlDataRecordType).Invoke(null, [| value |]) + let hasAny = + typeof + .GetMethods(BindingFlags.Static ||| BindingFlags.Public) + .First(fun m -> m.Name = "Any" && m.GetParameters().Count() = 1) + .MakeGenericMethod(sqlDataRecordType).Invoke(null, [| records |]) :?> bool + p.Value <- if not hasAny then null else records + | _ -> p.Value <- value + else + // happy go lucky... + p.Value <- value elif p.Direction.HasFlag(ParameterDirection.Output) && value :? Array then p.Size <- (value :?> Array).Length //Execute/AsyncExecute versions - static member internal VerifyResultsetColumns(cursor: SqlDataReader, expected) = + static member internal VerifyResultsetColumns(cursor: IDataReader, expected) = if FsharpDataSqlClientConfiguration.Current.ResultsetRuntimeVerification then if cursor.FieldCount < Array.length expected @@ -266,32 +286,48 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio cursor.Close() invalidOp message - static member internal ExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) = + static member internal ExecuteReader(cmd : IDbCommand, getReaderBehavior, parameters, expectedDataReaderColumns) = ``ISqlCommand Implementation``.SetParameters(cmd, parameters) let cursor = cmd.ExecuteReader( getReaderBehavior()) ``ISqlCommand Implementation``.VerifyResultsetColumns(cursor, expectedDataReaderColumns) cursor - static member internal AsyncExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) = + static member internal AsyncExecuteReader connection (cmd : IDbCommand, getReaderBehavior, parameters, expectedDataReaderColumns) = async { ``ISqlCommand Implementation``.SetParameters(cmd, parameters) - let! cursor = cmd.AsyncExecuteReader( getReaderBehavior()) - ``ISqlCommand Implementation``.VerifyResultsetColumns(cursor, expectedDataReaderColumns) - return cursor + match connection with + | Connection.ConnectionString _ + | Connection.SystemDataSqlClientConnection _ + | Connection.SystemDataSqlClientTransaction _ -> + let! cursor = (cmd :?> SqlCommand).AsyncExecuteReader( getReaderBehavior()) + ``ISqlCommand Implementation``.VerifyResultsetColumns(cursor, expectedDataReaderColumns) + return cursor + | Connection.SystemDataIDbConnection _ + | Connection.SystemDataIDbTransaction _ -> + raise (NotImplementedException "not supported") + return Unchecked.defaultof<_> } - static member internal ExecuteDataTable(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) = + static member internal ExecuteDataTable(cmd : IDbCommand, getReaderBehavior, parameters, expectedDataReaderColumns) = use cursor = ``ISqlCommand Implementation``.ExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) let result = new DataTable(cmd) result.Load(cursor) result - static member internal AsyncExecuteDataTable(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) = + static member internal AsyncExecuteDataTable connection (cmd, getReaderBehavior, parameters, expectedDataReaderColumns) = async { - use! reader = ``ISqlCommand Implementation``.AsyncExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) - let result = new DataTable(cmd) - result.Load(reader) - return result +// match connection with +// | Connection.ConnectionString _ +// | Connection.SystemDataSqlClientConnection _ +// | Connection.SystemDataSqlClientTransaction _ -> + use! reader = ``ISqlCommand Implementation``.AsyncExecuteReader connection (cmd, getReaderBehavior, parameters, expectedDataReaderColumns) + let result = new DataTable(cmd) + result.Load(reader) + return result +// | Connection.SystemDataIDbConnection _ +// | Connection.SystemDataIDbTransaction _ -> +// raise (NotImplementedException "not supported") +// return Unchecked.defaultof<_> } static member internal ExecuteSeq<'TItem> (rank, rowMapper) = fun(cmd: SqlCommand, getReaderBehavior, parameters, expectedDataReaderColumns) -> @@ -331,10 +367,10 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio box resultset - static member internal AsyncExecuteSeq<'TItem> (rank, rowMapper) = fun(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) -> + static member internal AsyncExecuteSeq<'TItem> (rank, rowMapper) connection = fun(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) -> let xs = async { - let! reader = ``ISqlCommand Implementation``.AsyncExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) + let! reader = ``ISqlCommand Implementation``.AsyncExecuteReader connection (cmd, getReaderBehavior, parameters, expectedDataReaderColumns) return reader.MapRowValues<'TItem>( rowMapper) } @@ -362,17 +398,26 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio let recordsAffected = cmd.ExecuteNonQuery() for i = 0 to parameters.Length - 1 do let name, _ = parameters.[i] - let p = cmd.Parameters.[name] + let p = cmd.Parameters.[name] :?> IDbDataParameter if p.Direction.HasFlag( ParameterDirection.Output) then parameters.[i] <- name, p.Value recordsAffected - static member internal AsyncExecuteNonQuery manageConnection (cmd, _, parameters, _) = + static member internal AsyncExecuteNonQuery connection manageConnection (cmd, _, parameters, _) = ``ISqlCommand Implementation``.SetParameters(cmd, parameters) async { - use _ = cmd.Connection.UseLocally(manageConnection ) - return! cmd.AsyncExecuteNonQuery() + match connection with + | Connection.ConnectionString _ + | Connection.SystemDataSqlClientConnection _ + | Connection.SystemDataSqlClientTransaction _ -> + let cmd = unbox cmd + use _ = cmd.Connection.UseLocally(manageConnection ) + return! cmd.AsyncExecuteNonQuery() + | Connection.SystemDataIDbConnection _ + | Connection.SystemDataIDbTransaction _ -> + failwithf "unsupported with other client" + return! Unchecked.defaultof<_> } #if WITH_LEGACY_NAMESPACE