diff --git a/src/FunSQL.jl b/src/FunSQL.jl index ca799261..05ffa507 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -53,8 +53,10 @@ export funsql_from, funsql_fun, funsql_group, + funsql_hide, funsql_highlight, funsql_in, + funsql_into, funsql_iterate, funsql_is_not_null, funsql_is_null, @@ -80,6 +82,7 @@ export funsql_rank, funsql_row_number, funsql_select, + funsql_show, funsql_sort, funsql_sum, funsql_with diff --git a/src/link.jl b/src/link.jl index d4d12487..881c7fc0 100644 --- a/src/link.jl +++ b/src/link.jl @@ -22,16 +22,28 @@ struct LinkContext knot_refs) end -function link(n::SQLNode) - @dissect(n, WithContext(over = over, catalog = catalog)) || throw(ILLFormedError()) - ctx = LinkContext(catalog) - t = row_type(over) +function _select(t::RowType) refs = SQLNode[] + t.visible || return refs for (f, ft) in t.fields if ft isa ScalarType + ft.visible || continue push!(refs, Get(f)) + else + nested_refs = _select(ft) + for nested_ref in nested_refs + push!(refs, Nested(over = nested_ref, name = f)) + end end end + refs +end + +function link(n::SQLNode) + @dissect(n, WithContext(over = over, catalog = catalog)) || throw(ILLFormedError()) + ctx = LinkContext(catalog) + t = row_type(over) + refs = _select(t) over′ = Linked(refs, over = link(dismantle(over, ctx), ctx, refs)) WithContext(over = over′, catalog = catalog, defs = ctx.defs) end @@ -114,19 +126,15 @@ function dismantle(n::GroupNode, ctx) Group(over = over′, by = by′, sets = n.sets, name = n.name, label_map = n.label_map) end -function dismantle(n::IterateNode, ctx) +function dismantle(n::IntoNode, ctx) over′ = dismantle(n.over, ctx) - iterator′ = dismantle(n.iterator, ctx) - Iterate(over = over′, iterator = iterator′) + Into(over = over′, name = n.name) end -function dismantle(n::JoinNode, ctx) - rt = row_type(n.joinee) - router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType)) +function dismantle(n::IterateNode, ctx) over′ = dismantle(n.over, ctx) - joinee′ = dismantle(n.joinee, ctx) - on′ = dismantle_scalar(n.on, ctx) - RoutedJoin(over = over′, joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional) + iterator′ = dismantle(n.iterator, ctx) + Iterate(over = over′, iterator = iterator′) end function dismantle(n::LimitNode, ctx) @@ -172,6 +180,13 @@ function dismantle_scalar(n::ResolvedNode, ctx) end end +function dismantle(n::RoutedJoinNode, ctx) + over′ = dismantle(n.over, ctx) + joinee′ = dismantle(n.joinee, ctx) + on′ = dismantle_scalar(n.on, ctx) + RoutedJoin(over = over′, joinee = joinee′, on = on′, name = n.name, left = n.left, right = n.right, optional = n.optional) +end + function dismantle(n::SelectNode, ctx) over′ = dismantle(n.over, ctx) args′ = dismantle_scalar(n.args, ctx) @@ -219,16 +234,7 @@ function link(n::AppendNode, ctx) end function link(n::AsNode, ctx) - refs = SQLNode[] - for ref in ctx.refs - if @dissect(ref, over |> Nested(name = name)) - @assert name == n.name - push!(refs, over) - else - error() - end - end - over′ = link(n.over, ctx, refs) + over′ = link(n.over, ctx) As(over = over′, name = n.name) end @@ -276,10 +282,8 @@ function link(n::FromIterateNode, ctx) end function link(n::FromTableExpressionNode, ctx) - refs = ctx.cte_refs[(n.name, n.depth)] - for ref in ctx.refs - push!(refs, Nested(over = ref, name = n.name)) - end + cte_refs = ctx.cte_refs[(n.name, n.depth)] + append!(cte_refs, ctx.refs) n end @@ -320,6 +324,20 @@ function link(n::GroupNode, ctx) Group(over = over′, by = n.by, sets = n.sets, name = n.name, label_map = n.label_map) end +function link(n::IntoNode, ctx) + refs = SQLNode[] + for ref in ctx.refs + if @dissect(ref, over |> Nested(name = name)) + @assert name == n.name + push!(refs, over) + else + error() + end + end + over′ = link(n.over, ctx, refs) + Into(over = over′, name = n.name) +end + function link(n::IterateNode, ctx) iterator′ = n.iterator defs = copy(ctx.defs) @@ -351,53 +369,6 @@ function link(n::IterateNode, ctx) Padding(over = n′) end -function route(r::JoinRouter, ref::SQLNode) - if @dissect(ref, over |> Nested(name = name)) && name in r.label_set - return 1 - end - if @dissect(ref, Get(name = name)) && name in r.label_set - return 1 - end - if @dissect(ref, over |> Agg()) && r.group - return 1 - end - return -1 -end - -function link(n::RoutedJoinNode, ctx) - lrefs = SQLNode[] - rrefs = SQLNode[] - for ref in ctx.refs - turn = route(n.router, ref) - push!(turn < 0 ? lrefs : rrefs, ref) - end - if n.optional && isempty(rrefs) - return link(n.over, ctx) - end - ln_ext_refs = length(lrefs) - rn_ext_refs = length(rrefs) - refs′ = SQLNode[] - lateral_refs = SQLNode[] - gather!(n.joinee, ctx, lateral_refs) - append!(lrefs, lateral_refs) - lateral = !isempty(lateral_refs) - gather!(n.on, ctx, refs′) - for ref in refs′ - turn = route(n.router, ref) - push!(turn < 0 ? lrefs : rrefs, ref) - end - over′ = Linked(lrefs, ln_ext_refs, over = link(n.over, ctx, lrefs)) - joinee′ = Linked(rrefs, rn_ext_refs, over = link(n.joinee, ctx, rrefs)) - RoutedJoinNode( - over = over′, - joinee = joinee′, - on = n.on, - router = n.router, - left = n.left, - right = n.right, - lateral = lateral) -end - function link(n::LimitNode, ctx) over′ = Linked(ctx.refs, over = link(n.over, ctx)) Limit(over = over′, offset = n.offset, limit = n.limit) @@ -446,6 +417,46 @@ function link(n::PartitionNode, ctx) Partition(over = over′, by = n.by, order_by = n.order_by, frame = n.frame, name = n.name) end +function link(n::RoutedJoinNode, ctx) + lrefs = SQLNode[] + rrefs = SQLNode[] + for ref in ctx.refs + if @dissect(ref, over |> Nested(name = name)) && name === n.name + push!(rrefs, ref) + else + push!(lrefs, ref) + end + end + if n.optional && isempty(rrefs) + return link(n.over, ctx) + end + ln_ext_refs = length(lrefs) + rn_ext_refs = length(rrefs) + refs′ = SQLNode[] + lateral_refs = SQLNode[] + gather!(n.joinee, ctx, lateral_refs) + append!(lrefs, lateral_refs) + lateral = !isempty(lateral_refs) + gather!(n.on, ctx, refs′) + for ref in refs′ + if @dissect(ref, over |> Nested(name = name)) && name === n.name + push!(rrefs, ref) + else + push!(lrefs, ref) + end + end + over′ = Linked(lrefs, ln_ext_refs, over = link(n.over, ctx, lrefs)) + joinee′ = Linked(rrefs, rn_ext_refs, over = link(Into(over = n.joinee, name = n.name), ctx, rrefs)) + RoutedJoinNode( + over = over′, + joinee = joinee′, + on = n.on, + name = n.name, + left = n.left, + right = n.right, + lateral = lateral) +end + function link(n::SelectNode, ctx) refs = SQLNode[] gather!(n.args, ctx, refs) @@ -540,12 +551,9 @@ end function gather!(n::IsolatedNode, ctx) def = ctx.defs[n.idx] !@dissect(def, Linked()) || return - refs = SQLNode[] - for (f, ft) in n.type.fields - if ft isa ScalarType - push!(refs, Get(f)) - break - end + refs = _select(n.type) + if !isempty(refs) + refs = refs[1:1] end def′ = Linked(refs, over = link(def, ctx, refs)) ctx.defs[n.idx] = def′ diff --git a/src/nodes.jl b/src/nodes.jl index 58d955a8..c862d8c8 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -696,6 +696,7 @@ include("nodes/get.jl") include("nodes/group.jl") include("nodes/highlight.jl") include("nodes/internal.jl") +include("nodes/into.jl") include("nodes/iterate.jl") include("nodes/join.jl") include("nodes/limit.jl") @@ -704,6 +705,7 @@ include("nodes/order.jl") include("nodes/over.jl") include("nodes/partition.jl") include("nodes/select.jl") +include("nodes/show.jl") include("nodes/sort.jl") include("nodes/variable.jl") include("nodes/where.jl") diff --git a/src/nodes/as.jl b/src/nodes/as.jl index ca52308e..9a0d82f1 100644 --- a/src/nodes/as.jl +++ b/src/nodes/as.jl @@ -18,8 +18,7 @@ AsNode(name; over = nothing) = As(name; over = nothing) name => over -In a scalar context, `As` specifies the name of the output column. When -applied to tabular data, `As` wraps the data in a nested record. +`As` specifies the name of the output column. The arrow operator (`=>`) is a shorthand notation for `As`. @@ -37,19 +36,19 @@ SELECT "person_1"."person_id" AS "id" FROM "person" AS "person_1" ``` -*Show all patients together with their state of residence.* +*Show all patients together with their primary care provider.* ```jldoctest -julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]); +julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :provider_id]); -julia> location = SQLTable(:location, columns = [:location_id, :state]); +julia> provider = SQLTable(:provider, columns = [:provider_id, :provider_name]); julia> q = From(:person) |> - Join(From(:location) |> As(:location), - on = Get.location_id .== Get.location.location_id) |> - Select(Get.person_id, Get.location.state); + Join(From(:provider) |> As(:pcp), + on = Get.provider_id .== Get.pcp.provider_id) |> + Select(Get.person_id, Get.pcp.provider_name); -julia> print(render(q, tables = [person, location])) +julia> print(render(q, tables = [person, provider])) SELECT "person_1"."person_id", "location_1"."state" diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl index 8a5f1025..c96dc000 100644 --- a/src/nodes/internal.jl +++ b/src/nodes/internal.jl @@ -202,30 +202,22 @@ PrettyPrinting.quoteof(n::FromFunctionNode, ctx::QuoteContext) = Expr(:kw, :columns, Expr(:vect, [QuoteNode(col) for col in n.columns]...))) # Annotated Join node. -struct JoinRouter - label_set::Set{Symbol} - group::Bool -end - -PrettyPrinting.quoteof(r::JoinRouter) = - Expr(:call, nameof(JoinRouter), quoteof(r.label_set), quoteof(r.group)) - mutable struct RoutedJoinNode <: TabularNode over::Union{SQLNode, Nothing} joinee::SQLNode on::SQLNode - router::JoinRouter + name::Symbol left::Bool right::Bool lateral::Bool optional::Bool - RoutedJoinNode(; over, joinee, on, router, left, right, lateral = false, optional = false) = - new(over, joinee, on, router, left, right, lateral, optional) + RoutedJoinNode(; over, joinee, on, name = label(joinee), left, right, lateral = false, optional = false) = + new(over, joinee, on, name, left, right, lateral, optional) end -RoutedJoinNode(joinee, on; over = nothing, router, left = false, right = false, lateral = false, optional = false) = - RoutedJoinNode(over = over, joinee = joinee, on = on, router, left = left, right = right, lateral = lateral, optional = optional) +RoutedJoinNode(joinee, on; over = nothing, name = label(joinee), left = false, right = false, lateral = false, optional = false) = + RoutedJoinNode(over = over, name = name, on = on, router, left = left, right = right, lateral = lateral, optional = optional) RoutedJoin(args...; kws...) = RoutedJoinNode(args...; kws...) |> SQLNode @@ -235,7 +227,7 @@ function PrettyPrinting.quoteof(n::RoutedJoinNode, ctx::QuoteContext) if !ctx.limit push!(ex.args, quoteof(n.joinee, ctx)) push!(ex.args, quoteof(n.on, ctx)) - push!(ex.args, Expr(:kw, :router, quoteof(n.router))) + push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) if n.left push!(ex.args, Expr(:kw, :left, n.left)) end diff --git a/src/nodes/into.jl b/src/nodes/into.jl new file mode 100644 index 00000000..8f7e3c8e --- /dev/null +++ b/src/nodes/into.jl @@ -0,0 +1,39 @@ +# Wrap the output into a nested record. + +mutable struct IntoNode <: TabularNode + over::Union{SQLNode, Nothing} + name::Symbol + + IntoNode(; + over = nothing, + name::Union{Symbol, AbstractString}) = + new(over, Symbol(name)) +end + +IntoNode(name; over = nothing) = + IntoNode(over = over, name = name) + +""" + Into(; over = nothing, name) + Into(name; over = nothing) + +`Into` wraps output columns in a nested record. +""" +Into(args...; kws...) = + IntoNode(args...; kws...) |> SQLNode + +const funsql_into = Into + +dissect(scr::Symbol, ::typeof(Into), pats::Vector{Any}) = + dissect(scr, IntoNode, pats) + +function PrettyPrinting.quoteof(n::IntoNode, ctx::QuoteContext) + ex = Expr(:call, nameof(Into), quoteof(n.name)) + if n.over !== nothing + ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) + end + ex +end + +label(n::IntoNode) = + n.name diff --git a/src/nodes/join.jl b/src/nodes/join.jl index b5e56536..0335cfb3 100644 --- a/src/nodes/join.jl +++ b/src/nodes/join.jl @@ -7,21 +7,22 @@ mutable struct JoinNode <: TabularNode left::Bool right::Bool optional::Bool + swap::Bool - JoinNode(; over = nothing, joinee, on, left = false, right = false, optional = false) = - new(over, joinee, on, left, right, optional) + JoinNode(; over = nothing, joinee, on, left = false, right = false, optional = false, swap = false) = + new(over, joinee, on, left, right, optional, swap) end -JoinNode(joinee; over = nothing, on, left = false, right = false, optional = false) = - JoinNode(over = over, joinee = joinee, on = on, left = left, right = right, optional = optional) +JoinNode(joinee; over = nothing, on, left = false, right = false, optional = false, swap = false) = + JoinNode(over = over, joinee = joinee, on = on, left = left, right = right, optional = optional, swap = swap) -JoinNode(joinee, on; over = nothing, left = false, right = false, optional = false) = - JoinNode(over = over, joinee = joinee, on = on, left = left, right = right, optional = optional) +JoinNode(joinee, on; over = nothing, left = false, right = false, optional = false, swap = false) = + JoinNode(over = over, joinee = joinee, on = on, left = left, right = right, optional = optional, swap = swap) """ - Join(; over = nothing, joinee, on, left = false, right = false, optional = false) - Join(joinee; over = nothing, on, left = false, right = false, optional = false) - Join(joinee, on; over = nothing, left = false, right = false, optional = false) + Join(; over = nothing, joinee, on, left = false, right = false, optional = false, swap = false) + Join(joinee; over = nothing, on, left = false, right = false, optional = false, swap = false) + Join(joinee, on; over = nothing, left = false, right = false, optional = false, swap = false) `Join` correlates two input datasets. @@ -107,6 +108,9 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) if n.optional push!(ex.args, Expr(:kw, :optional, n.optional)) end + if n.swap + push!(ex.args, Expr(:kw, :swap, n.swap)) + end else push!(ex.args, :…) end @@ -115,3 +119,6 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) end ex end + +label(n::JoinNode) = + n.swap ? label(n.joinee) : label(n.over) diff --git a/src/nodes/show.jl b/src/nodes/show.jl new file mode 100644 index 00000000..312207b5 --- /dev/null +++ b/src/nodes/show.jl @@ -0,0 +1,47 @@ +# Show/Hide nodes + +mutable struct ShowNode <: TabularNode + over::Union{SQLNode, Nothing} + names::Vector{Symbol} + visible::Bool + label_map::FunSQL.OrderedDict{Symbol, Int} + + function ShowNode(; over = nothing, names = [], visible = true, label_map = nothing) + if label_map !== nothing + new(over, names, visible, label_map) + else + n = new(over, names, visible, FunSQL.OrderedDict{Symbol, Int}()) + for (i, name) in enumerate(n.names) + if name in keys(n.label_map) + err = FunSQL.DuplicateLabelError(name, path = [n]) + throw(err) + end + n.label_map[name] = i + end + n + end + end +end + +ShowNode(names...; over = nothing, visible = true) = + ShowNode(over = over, names = Symbol[names...], visible = visible) + +Show(args...; kws...) = + ShowNode(args...; kws...) |> SQLNode + +Hide(args...; kws...) = + ShowNode(args...; kws..., visible = false) |> SQLNode + +const funsql_show = Show +const funsql_hide = Hide + +dissect(scr::Symbol, ::typeof(Show), pats::Vector{Any}) = + dissect(scr, ShowNode, pats) + +function FunSQL.PrettyPrinting.quoteof(n::ShowNode, ctx::FunSQL.QuoteContext) + ex = Expr(:call, nameof(n.visible ? Show : Hide), quoteof(n.names, ctx)...) + if n.over !== nothing + ex = Expr(:call, :|>, FunSQL.quoteof(n.over, ctx), ex) + end + ex +end diff --git a/src/resolve.jl b/src/resolve.jl index b2097429..0112e64f 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -160,9 +160,8 @@ end function resolve(n::AsNode, ctx) over′ = resolve(n.over, ctx) - t = row_type(over′) n′ = As(name = n.name, over = over′) - Resolved(RowType(FieldTypeMap(n.name => t)), over = n′) + Resolved(type(over′), over = n′) end function resolve_scalar(n::AsNode, ctx) @@ -357,6 +356,13 @@ resolve(n::HighlightNode, ctx) = resolve_scalar(n::HighlightNode, ctx) = resolve_scalar(n.over, ctx) +function resolve(n::IntoNode, ctx) + over′ = resolve(n.over, ctx) + t = row_type(over′) + n′ = Into(name = n.name, over = over′) + Resolved(RowType(FieldTypeMap(n.name => t)), over = n′) +end + function resolve(n::IterateNode, ctx) over′ = resolve(n.over, ResolveContext(ctx, knot_type = nothing, implicit_knot = false)) t = row_type(over′) @@ -372,23 +378,23 @@ function resolve(n::IterateNode, ctx) end function resolve(n::JoinNode, ctx) + if n.swap + return resolve(JoinNode(over = n.joinee, joinee = n.over, on = n.on, left = n.right, right = n.left, optional = n.optional), ctx) + end over′ = resolve(n.over, ctx) lt = row_type(over′) + name = label(n.joinee) joinee′ = resolve(n.joinee, ResolveContext(ctx, row_type = lt, implicit_knot = false)) rt = row_type(joinee′) fields = FieldTypeMap() for (f, ft) in lt.fields - fields[f] = get(rt.fields, f, ft) + fields[f] = ft end - for (f, ft) in rt.fields - if !haskey(fields, f) - fields[f] = ft - end - end - group = rt.group isa EmptyType ? lt.group : rt.group + fields[name] = rt + group = lt.group t = RowType(fields, group) on′ = resolve_scalar(n.on, ctx, t) - n′ = Join(over = over′, joinee = joinee′, on = on′, left = n.left, right = n.right, optional = n.optional) + n′ = RoutedJoin(over = over′, joinee = joinee′, on = on′, name = name, left = n.left, right = n.right, optional = n.optional) Resolved(t, over = n′) end @@ -458,6 +464,33 @@ function resolve(n::SelectNode, ctx) Resolved(RowType(fields), over = n′) end +function resolve(n::ShowNode, ctx) + over′ = resolve(n.over, ctx) + t = row_type(over′) + for name in n.names + ft = get(t.fields, name, EmptyType()) + if ft isa EmptyType + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.UNDEFINED_NAME, + name = name, + path = get_path(ctx))) + end + end + fields = FieldTypeMap() + for (f, ft) in t.fields + if f in keys(n.label_map) + if ft isa ScalarType + ft = ScalarType(visible = n.visible) + else + ft = RowType(ft.fields, ft.group, visible = n.visible) + end + end + fields[f] = ft + end + Resolved(RowType(fields, t.group, visible = t.visible), over = over′) +end + function resolve_scalar(n::SortNode, ctx) over′ = resolve_scalar(n.over, ctx) n′ = Sort(over = over′, value = n.value, nulls = n.nulls) @@ -491,16 +524,7 @@ function resolve(n::Union{WithNode, WithExternalNode}, ctx) v = get(ctx.cte_types, name, nothing) depth = 1 + (v !== nothing ? v[1] : 0) t = row_type(args′[i]) - cte_t = get(t.fields, name, EmptyType()) - if !(cte_t isa RowType) - throw( - ReferenceError( - REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE, - name = name, - path = get_path(ctx))) - - end - cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, cte_t)) + cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, t)) end ctx′ = ResolveContext(ctx, cte_types = cte_types′) over′ = resolve(n.over, ctx′) diff --git a/src/translate.jl b/src/translate.jl index cc2f2128..c7325247 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -417,26 +417,8 @@ function assemble(n::AppendNode, ctx) Assemblage(a_name, c, repl = repl, cols = dummy_cols) end -function assemble(n::AsNode, ctx) - refs′ = SQLNode[] - for ref in ctx.refs - if @dissect(ref, over |> Nested()) - push!(refs′, over) - else - push!(refs′, ref) - end - end - base = assemble(n.over, TranslateContext(ctx, refs = refs′)) - repl′ = Dict{SQLNode, Symbol}() - for ref in ctx.refs - if @dissect(ref, over |> Nested()) - repl′[ref] = base.repl[over] - else - repl′[ref] = base.repl[ref] - end - end - Assemblage(n.name, base.clause, cols = base.cols, repl = repl′) -end +assemble(n::AsNode, ctx) = + assemble(n.over, ctx) function assemble(n::BindNode, ctx) vars′ = ctx.vars @@ -519,21 +501,12 @@ end assemble(::FromNothingNode, ctx) = assemble(nothing, ctx) -function unwrap_repl(a::Assemblage) - repl′ = Dict{SQLNode, Symbol}() - for (ref, name) in a.repl - @dissect(ref, over |> Nested()) || error() - repl′[over] = name - end - Assemblage(a.name, a.clause, cols = a.cols, repl = repl′) -end - function assemble(n::FromTableExpressionNode, ctx) cte_a = ctx.ctes[ctx.cte_map[(n.name, n.depth)]] alias = allocate_alias(ctx, n.name) tbl = ID(cte_a.qualifiers, cte_a.name) c = FROM(AS(over = tbl, name = alias)) - subs = make_subs(unwrap_repl(cte_a.a), alias) + subs = make_subs(cte_a.a, alias) trns = Pair{SQLNode, SQLClause}[] for ref in ctx.refs push!(trns, ref => subs[ref]) @@ -664,6 +637,27 @@ function assemble(n::GroupNode, ctx) return Assemblage(base.name, c, cols = cols, repl = repl) end +function assemble(n::IntoNode, ctx) + refs′ = SQLNode[] + for ref in ctx.refs + if @dissect(ref, over |> Nested()) + push!(refs′, over) + else + push!(refs′, ref) + end + end + base = assemble(n.over, TranslateContext(ctx, refs = refs′)) + repl′ = Dict{SQLNode, Symbol}() + for ref in ctx.refs + if @dissect(ref, over |> Nested()) + repl′[ref] = base.repl[over] + else + repl′[ref] = base.repl[ref] + end + end + Assemblage(n.name, base.clause, cols = base.cols, repl = repl′) +end + function assemble(n::IterateNode, ctx) ctx′ = TranslateContext(ctx, vars = Base.ImmutableDict{Tuple{Symbol, Int}, SQLClause}()) left = assemble(n.over, ctx) @@ -870,22 +864,16 @@ function assemble(n::RoutedJoinNode, ctx) right = assemble(n.joinee, ctx) end if @dissect(right.clause, (joinee := (ID() || AS())) |> FROM()) && (!n.left || _outer_safe(right)) - for (ref, name) in right.repl - subs[ref] = right.cols[name] - end + right_alias = nothing if ctx.catalog.dialect.has_implicit_lateral lateral = false end else right_alias = allocate_alias(ctx, right) joinee = AS(over = complete(right), name = right_alias) - right_cache = Dict{Symbol, SQLClause}() - for (ref, name) in right.repl - subs[ref] = get(right_cache, name) do - ID(over = right_alias, name = name) - end - end end + right_subs = make_subs(right, right_alias) + merge!(subs, right_subs) on = translate(n.on, ctx, subs) c = JOIN(over = tail, joinee = joinee, on = on, left = n.left, right = n.right, lateral = lateral) trns = Pair{SQLNode, SQLClause}[] diff --git a/src/types.jl b/src/types.jl index 856821ed..641ac01b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -13,17 +13,27 @@ PrettyPrinting.quoteof(::EmptyType) = Expr(:call, nameof(EmptyType)) struct ScalarType <: AbstractSQLType + visible::Bool + + ScalarType(; visible = true) = + new(visible) end -PrettyPrinting.quoteof(::ScalarType) = - Expr(:call, nameof(ScalarType)) +function PrettyPrinting.quoteof(::ScalarType) + ex = Expr(:call, nameof(ScalarType)) + if !t.visible + push!(ex.args, Expr(:kw, :visible, t.visible)) + end + ex +end struct RowType <: AbstractSQLType fields::OrderedDict{Symbol, Union{ScalarType, RowType}} group::Union{EmptyType, RowType} + visible::Bool - RowType(fields, group = EmptyType()) = - new(fields, group) + RowType(fields, group = EmptyType(); visible = true) = + new(fields, group, visible) end const FieldTypeMap = OrderedDict{Symbol, Union{ScalarType, RowType}} @@ -43,6 +53,9 @@ function PrettyPrinting.quoteof(t::RowType) if !(t.group isa EmptyType) push!(ex.args, Expr(:kw, :group, quoteof(t.group))) end + if !t.visible + push!(ex.args, Expr(:kw, :visible, t.visible)) + end ex end @@ -54,8 +67,8 @@ const EMPTY_ROW = RowType() Base.intersect(::AbstractSQLType, ::AbstractSQLType) = EmptyType() -Base.intersect(::ScalarType, ::ScalarType) = - ScalarType() +Base.intersect(t1::ScalarType, t2::ScalarType) = + ScalarType(visible = t1.visible || t2.visible) function Base.intersect(t1::RowType, t2::RowType) if t1 === t2 @@ -71,7 +84,7 @@ function Base.intersect(t1::RowType, t2::RowType) end end group = intersect(t1.group, t2.group) - RowType(fields, group) + RowType(fields, group, visible = t1.visible || t2.visible) end @@ -98,5 +111,8 @@ function Base.issubset(t1::RowType, t2::RowType) if !issubset(t1.group, t2.group) return false end + if !t1.visible && t2.visible + return false + end return true end