Skip to content

enable escape hatch to pass System.Data.IDbConnection/Transaction at runtime #432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/SqlClient.DesignTime/DesignTime.fs
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ type DesignTime private() =
]

let body1 (args: _ list) =
Expr.NewObject(ctorImpl, designTimeConfig :: <@@ Connection.Choice1Of3 %%args.Head @@> :: args.Tail)
Expr.NewObject(ctorImpl, designTimeConfig :: <@@ Connection.ConnectionString %%args.Head @@> :: args.Tail)

yield ProvidedConstructor(parameters1, invokeCode = body1) :> MemberInfo

Expand All @@ -663,10 +663,12 @@ type DesignTime private() =
let connArg =
<@@
if box (%%args.[1]: SqlTransaction) <> null
then Connection.Choice3Of3 %%args.[1]
then Connection.SystemDataSqlClientTransaction %%args.[1]
elif box (%%args.[0]: SqlConnection) <> null
then Connection.Choice2Of3 %%args.Head
else Connection.Choice1Of3( %%connectionStringExpr)
then Connection.SystemDataSqlClientConnection %%args.Head
elif box (%%connectionStringExpr: string) <> null
then Connection.ConnectionString( %%connectionStringExpr)
else failwithf "design time doesn't support IDb* constructors"
@@>
Expr.NewObject(ctorImpl, [ designTimeConfig ; connArg; args.[2] ])

Expand Down
21 changes: 13 additions & 8 deletions src/SqlClient/DataTable.fs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ open FSharp.Data.SqlClient.Internals

[<Sealed>]
[<CompilerMessageAttribute("This API supports the FSharp.Data.SqlClient infrastructure and is not intended to be used directly from your code.", 101, IsHidden = true)>]
type DataTable<'T when 'T :> DataRow>(selectCommand: SqlCommand, ?connectionString: Lazy<string>) =
type DataTable<'T when 'T :> DataRow>(selectCommand: IDbCommand, ?connectionString: Lazy<string>) =
inherit DataTable()

let rows = base.Rows
Expand Down Expand Up @@ -42,10 +42,12 @@ type DataTable<'T when 'T :> DataRow>(selectCommand: SqlCommand, ?connectionStri

member this.Update(?connection, ?transaction, ?batchSize, ?continueUpdateOnError, ?timeout: TimeSpan) =
// not supported on all DataTable instances
match selectCommand with
| null -> failwith "This command wasn't constructed from SqlProgrammabilityProvider, call to Update is not supported."
| _ -> ()

let selectCommand =
match selectCommand with
| null -> failwith "This command wasn't constructed from SqlProgrammabilityProvider, call to Update is not supported."
| :? SqlCommand as selectCommand -> selectCommand
| _ -> failwithf "This command has type %s, this is only supported for commands instanciated with System.Data.SqlClient db types." (selectCommand.GetType().FullName)

connection |> Option.iter selectCommand.set_Connection
transaction |> Option.iter selectCommand.set_Transaction

Expand Down Expand Up @@ -94,9 +96,12 @@ type DataTable<'T when 'T :> DataRow>(selectCommand: SqlCommand, ?connectionStri
| _, Some(t: SqlTransaction) -> t.Connection, t
| Some c, None -> c, null
| None, None ->
match selectCommand with
| null -> failwith "To issue BulkCopy on this table, you need to provide your own connection or transaction"
| _ -> ()
let selectCommand =
match selectCommand with
| null -> failwith "To issue BulkCopy on this table, you need to provide your own connection or transaction"
| :? SqlCommand as selectCommand -> selectCommand
| _ -> failwithf "This command has type %s, this is only supported for commands instanciated with System.Data.SqlClient db types." (selectCommand.GetType().FullName)

if this.IsDirectTable
then
assert(connectionString.IsSome)
Expand Down
8 changes: 4 additions & 4 deletions src/SqlClient/Extensions.fs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ open System.Data.SqlClient

[<AutoOpen>]
module Extensions =

type SqlDataReader with
type IDataReader with
member internal this.MapRowValues<'TItem>( rowMapping) =
seq {
use _ = this
Expand Down Expand Up @@ -49,7 +49,7 @@ module Extensions =
yield mapper cursor
}

type SqlConnection with
type IDbConnection with

//address an issue when regular Dispose on SqlConnection needed for async computation
//wipes out all properties like ConnectionString in addition to closing connection to db
Expand All @@ -63,7 +63,7 @@ module Extensions =

member this.IsSqlAzure =
assert (this.State = ConnectionState.Open)
use cmd = new SqlCommand("SELECT SERVERPROPERTY('edition')", this)
use cmd = this.CreateCommand(CommandText = "SELECT SERVERPROPERTY('edition')")
cmd.ExecuteScalar().Equals("SQL Azure")


Expand Down
165 changes: 105 additions & 60 deletions src/SqlClient/ISqlCommand.fs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type ISqlCommand =

abstract ToTraceString: parameters: (string * obj)[] -> string

abstract Raw: SqlCommand with get
abstract Raw: System.Data.IDbCommand with get


namespace FSharp.Data.SqlClient.Internals
Expand Down Expand Up @@ -49,27 +49,45 @@ type DesignTimeConfig = {
ExpectedDataReaderColumns: (string * string)[]
}

type internal Connection = Choice<string, SqlConnection, SqlTransaction>

[<CompilerMessageAttribute("This API supports the FSharp.Data.SqlClient infrastructure and is not intended to be used directly from your code.", 101, IsHidden = true)>]
[<RequireQualifiedAccess>]
type Connection =

| [<Obsolete "Leaving the connection management to FSharp.Data.SqlClient is not recommended practice.">] ConnectionString of string
| SystemDataSqlClientConnection of SqlConnection
| SystemDataSqlClientTransaction of SqlTransaction
| [<Experimental "This is not supported feature.">] SystemDataIDbConnection of IDbConnection
| [<Experimental "This is not supported feature.">] SystemDataIDbTransaction of IDbTransaction

#nowarn "44"
#nowarn "57"
[<CompilerMessageAttribute("This API supports the FSharp.Data.SqlClient infrastructure and is not intended to be used directly from your code.", 101, IsHidden = true)>]
type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connection, commandTimeout) =
let cmd = new SqlCommand(cfg.SqlStatement, CommandTimeout = commandTimeout)
let manageConnection =
match connection with
| Choice1Of3 connectionString ->
cmd.Connection <- new SqlConnection(connectionString)
true
| Choice2Of3 instance ->
cmd.Connection <- instance
false
| Choice3Of3 tran ->
cmd.Transaction <- tran
cmd.Connection <- tran.Connection
false
let manageConnection=
match connection with
| Connection.ConnectionString _ -> true
| Connection.SystemDataSqlClientConnection _
| Connection.SystemDataSqlClientTransaction _
| Connection.SystemDataIDbConnection _
| Connection.SystemDataIDbTransaction _ -> false
let cmd : IDbCommand =
match connection with
| Connection.ConnectionString connectionString ->
let cnx = new SqlConnection(connectionString)
cnx.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout)
| Connection.SystemDataSqlClientConnection cnx ->
cnx.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout)
| Connection.SystemDataSqlClientTransaction trx ->
trx.Connection.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout, Transaction = trx)
| Connection.SystemDataIDbConnection cnx ->
cnx.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout)
| Connection.SystemDataIDbTransaction trx ->
trx.Connection.CreateCommand(CommandText = cfg.SqlStatement, CommandTimeout = commandTimeout, Transaction = trx)

do
cmd.CommandType <- if cfg.IsStoredProcedure then CommandType.StoredProcedure else CommandType.Text
cmd.Parameters.AddRange( cfg.Parameters)
for parameter in cfg.Parameters do
cmd.Parameters.Add parameter |> ignore<int>

let getReaderBehavior() =
seq {
Expand All @@ -90,19 +108,19 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio
match cfg.ResultType with
| ResultType.DataReader ->
``ISqlCommand Implementation``.ExecuteReader >> box,
``ISqlCommand Implementation``.AsyncExecuteReader >> box,
``ISqlCommand Implementation``.AsyncExecuteReader connection >> box,
notImplemented,
notImplemented
| ResultType.DataTable ->
``ISqlCommand Implementation``.ExecuteDataTable >> box,
``ISqlCommand Implementation``.AsyncExecuteDataTable >> box,
``ISqlCommand Implementation``.AsyncExecuteDataTable connection >> box,
notImplemented,
notImplemented
| ResultType.Records | ResultType.Tuples ->
match box cfg.RowMapping, cfg.ItemTypeName with
| null, null ->
``ISqlCommand Implementation``.ExecuteNonQuery manageConnection >> box,
``ISqlCommand Implementation``.AsyncExecuteNonQuery manageConnection >> box,
``ISqlCommand Implementation``.AsyncExecuteNonQuery connection manageConnection >> box,
notImplemented,
notImplemented
| rowMapping, itemTypeName ->
Expand Down Expand Up @@ -215,40 +233,42 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio
cmd.Connection.Dispose()
cmd.Dispose()

static member internal SetParameters(cmd: SqlCommand, parameters: (string * obj)[]) =
static member internal SetParameters(cmd: IDbCommand, parameters: (string * obj)[]) =
for name, value in parameters do
let p = cmd.Parameters.[name]

if p.Direction.HasFlag(ParameterDirection.Input)
then
let p = cmd.Parameters.[name] :?> IDbDataParameter
let cmdIsSystemDataSqlClient = cmd :? SqlCommand
if p.Direction.HasFlag(ParameterDirection.Input) then
match value with
| null ->
p.Value <- DBNull.Value
| null -> p.Value <- DBNull.Value
| _ ->
match p.SqlDbType with
| SqlDbType.Structured ->
// TODO: Maybe make this lazy?
if cmdIsSystemDataSqlClient then
let p = p :?> SqlParameter
match p.SqlDbType with
| SqlDbType.Structured ->
// TODO: Maybe make this lazy?

//p.Value <- value |> unbox |> Seq.cast<Microsoft.SqlServer.Server.SqlDataRecord>
//p.Value <- value |> unbox |> Seq.cast<Microsoft.SqlServer.Server.SqlDataRecord>

//done via reflection because not implemented on Mono
//done via reflection because not implemented on Mono

let sqlDataRecordType = typeof<SqlCommand>.Assembly.GetType("Microsoft.SqlServer.Server.SqlDataRecord", throwOnError = true)
let records = typeof<Linq.Enumerable>.GetMethod("Cast").MakeGenericMethod(sqlDataRecordType).Invoke(null, [| value |])
let hasAny =
typeof<Linq.Enumerable>
.GetMethods(BindingFlags.Static ||| BindingFlags.Public)
.First(fun m -> m.Name = "Any" && m.GetParameters().Count() = 1)
.MakeGenericMethod(sqlDataRecordType).Invoke(null, [| records |]) :?> bool
p.Value <- if not hasAny then null else records
| _ -> p.Value <- value

let sqlDataRecordType = typeof<SqlCommand>.Assembly.GetType("Microsoft.SqlServer.Server.SqlDataRecord", throwOnError = true)
let records = typeof<Linq.Enumerable>.GetMethod("Cast").MakeGenericMethod(sqlDataRecordType).Invoke(null, [| value |])
let hasAny =
typeof<Linq.Enumerable>
.GetMethods(BindingFlags.Static ||| BindingFlags.Public)
.First(fun m -> m.Name = "Any" && m.GetParameters().Count() = 1)
.MakeGenericMethod(sqlDataRecordType).Invoke(null, [| records |]) :?> bool
p.Value <- if not hasAny then null else records
| _ -> p.Value <- value
else
// happy go lucky...
p.Value <- value
elif p.Direction.HasFlag(ParameterDirection.Output) && value :? Array then
p.Size <- (value :?> Array).Length

//Execute/AsyncExecute versions

static member internal VerifyResultsetColumns(cursor: SqlDataReader, expected) =
static member internal VerifyResultsetColumns(cursor: IDataReader, expected) =
if FsharpDataSqlClientConfiguration.Current.ResultsetRuntimeVerification
then
if cursor.FieldCount < Array.length expected
Expand All @@ -266,32 +286,48 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio
cursor.Close()
invalidOp message

static member internal ExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) =
static member internal ExecuteReader(cmd : IDbCommand, getReaderBehavior, parameters, expectedDataReaderColumns) =
``ISqlCommand Implementation``.SetParameters(cmd, parameters)
let cursor = cmd.ExecuteReader( getReaderBehavior())
``ISqlCommand Implementation``.VerifyResultsetColumns(cursor, expectedDataReaderColumns)
cursor

static member internal AsyncExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) =
static member internal AsyncExecuteReader connection (cmd : IDbCommand, getReaderBehavior, parameters, expectedDataReaderColumns) =
async {
``ISqlCommand Implementation``.SetParameters(cmd, parameters)
let! cursor = cmd.AsyncExecuteReader( getReaderBehavior())
``ISqlCommand Implementation``.VerifyResultsetColumns(cursor, expectedDataReaderColumns)
return cursor
match connection with
| Connection.ConnectionString _
| Connection.SystemDataSqlClientConnection _
| Connection.SystemDataSqlClientTransaction _ ->
let! cursor = (cmd :?> SqlCommand).AsyncExecuteReader( getReaderBehavior())
``ISqlCommand Implementation``.VerifyResultsetColumns(cursor, expectedDataReaderColumns)
return cursor
| Connection.SystemDataIDbConnection _
| Connection.SystemDataIDbTransaction _ ->
raise (NotImplementedException "not supported")
return Unchecked.defaultof<_>
}

static member internal ExecuteDataTable(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) =
static member internal ExecuteDataTable(cmd : IDbCommand, getReaderBehavior, parameters, expectedDataReaderColumns) =
use cursor = ``ISqlCommand Implementation``.ExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns)
let result = new DataTable<DataRow>(cmd)
result.Load(cursor)
result

static member internal AsyncExecuteDataTable(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) =
static member internal AsyncExecuteDataTable connection (cmd, getReaderBehavior, parameters, expectedDataReaderColumns) =
async {
use! reader = ``ISqlCommand Implementation``.AsyncExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns)
let result = new DataTable<DataRow>(cmd)
result.Load(reader)
return result
// match connection with
// | Connection.ConnectionString _
// | Connection.SystemDataSqlClientConnection _
// | Connection.SystemDataSqlClientTransaction _ ->
use! reader = ``ISqlCommand Implementation``.AsyncExecuteReader connection (cmd, getReaderBehavior, parameters, expectedDataReaderColumns)
let result = new DataTable<DataRow>(cmd)
result.Load(reader)
return result
// | Connection.SystemDataIDbConnection _
// | Connection.SystemDataIDbTransaction _ ->
// raise (NotImplementedException "not supported")
// return Unchecked.defaultof<_>
}

static member internal ExecuteSeq<'TItem> (rank, rowMapper) = fun(cmd: SqlCommand, getReaderBehavior, parameters, expectedDataReaderColumns) ->
Expand Down Expand Up @@ -331,10 +367,10 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio

box resultset

static member internal AsyncExecuteSeq<'TItem> (rank, rowMapper) = fun(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) ->
static member internal AsyncExecuteSeq<'TItem> (rank, rowMapper) connection = fun(cmd, getReaderBehavior, parameters, expectedDataReaderColumns) ->
let xs =
async {
let! reader = ``ISqlCommand Implementation``.AsyncExecuteReader(cmd, getReaderBehavior, parameters, expectedDataReaderColumns)
let! reader = ``ISqlCommand Implementation``.AsyncExecuteReader connection (cmd, getReaderBehavior, parameters, expectedDataReaderColumns)
return reader.MapRowValues<'TItem>( rowMapper)
}

Expand Down Expand Up @@ -362,17 +398,26 @@ type ``ISqlCommand Implementation``(cfg: DesignTimeConfig, connection: Connectio
let recordsAffected = cmd.ExecuteNonQuery()
for i = 0 to parameters.Length - 1 do
let name, _ = parameters.[i]
let p = cmd.Parameters.[name]
let p = cmd.Parameters.[name] :?> IDbDataParameter
if p.Direction.HasFlag( ParameterDirection.Output)
then
parameters.[i] <- name, p.Value
recordsAffected

static member internal AsyncExecuteNonQuery manageConnection (cmd, _, parameters, _) =
static member internal AsyncExecuteNonQuery connection manageConnection (cmd, _, parameters, _) =
``ISqlCommand Implementation``.SetParameters(cmd, parameters)
async {
use _ = cmd.Connection.UseLocally(manageConnection )
return! cmd.AsyncExecuteNonQuery()
match connection with
| Connection.ConnectionString _
| Connection.SystemDataSqlClientConnection _
| Connection.SystemDataSqlClientTransaction _ ->
let cmd = unbox<SqlCommand> cmd
use _ = cmd.Connection.UseLocally(manageConnection )
return! cmd.AsyncExecuteNonQuery()
| Connection.SystemDataIDbConnection _
| Connection.SystemDataIDbTransaction _ ->
failwithf "unsupported with other client"
return! Unchecked.defaultof<_>
}

#if WITH_LEGACY_NAMESPACE
Expand Down