From fbf88471908c38759c8d67dc3ebc9097e4bb1c65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 25 Jul 2024 10:38:17 +0200 Subject: [PATCH] Customize `show` for `EinExpr`, `SizedEinExpr` --- src/EinExpr.jl | 2 ++ src/SizedEinExpr.jl | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/src/EinExpr.jl b/src/EinExpr.jl index ab6b481..3677939 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -17,6 +17,8 @@ EinExpr(head::NTuple, args::NTuple) = EinExpr(collect(head), collect(args)) Base.copy(expr::EinExpr) = EinExpr(copy(expr.head), copy(expr.args)) +Base.show(io::IO, path::EinExpr) = print_tree((io, node) -> print(io, head(node)), io, path) + """ head(path::EinExpr) diff --git a/src/SizedEinExpr.jl b/src/SizedEinExpr.jl index 8b9b29c..9783f90 100644 --- a/src/SizedEinExpr.jl +++ b/src/SizedEinExpr.jl @@ -13,6 +13,16 @@ end EinExpr(path::Vector{L}, size::Dict{L}) where {L} = SizedEinExpr(EinExpr(path), size) +function Base.show(io::IO, path::SizedEinExpr) + print_tree(io, path.path) do io, node + print(io, head(node)) + print( + io, + " " * join(["$(flops(node, path.size)) flops", "$(length(SizedEinExpr(node, path.size))) elems"], ", "), + ) + end +end + head(sexpr::SizedEinExpr) = head(sexpr.path) """