Skip to content

Commit

Permalink
Implement PTX syntax highlighting in Julia.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ellipse0934 authored and maleadt committed Jan 19, 2022
1 parent 09f8b5a commit bf64c20
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 92 deletions.
72 changes: 0 additions & 72 deletions res/pygments/ptx.py

This file was deleted.

114 changes: 94 additions & 20 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,107 @@ const Cthulhu = Base.PkgId(UUID("f68482b8-f384-11e8-15f7-abe071a5a75f"), "Cthulh
# syntax highlighting
#

const _pygmentize = Ref{Union{String,Nothing}}()
function pygmentize()
if !isassigned(_pygmentize)
_pygmentize[] = Sys.which("pygmentize")
end
return _pygmentize[]
end
# https://github.com/JuliaLang/julia/blob/dacd16f068fb27719b31effbe8929952ee2d5b32/stdlib/InteractiveUtils/src/codeview.jl
const hlscheme = Dict{Symbol, Tuple{Bool, Union{Symbol, Int}}}(
:default => (false, :normal), # e.g. comma, equal sign, unknown token
:comment => (false, :light_black),
:label => (false, :light_red),
:instruction => ( true, :light_cyan),
:type => (false, :cyan),
:number => (false, :yellow),
:bracket => (false, :yellow),
:variable => (false, :normal), # e.g. variable, register
:keyword => (false, :light_magenta),
:funcname => (false, :light_yellow),
)

function highlight(io::IO, code, lexer)
highlighter = pygmentize()
have_color = get(io, :color, false)
if highlighter === nothing || !have_color
if !haskey(io, :color)
print(io, code)
elseif lexer == "llvm"
InteractiveUtils.print_llvm(io, code)
elseif lexer == "ptx"
highlight_ptx(io, code)
else
custom_lexer = joinpath(dirname(@__DIR__), "res", "pygments", "$lexer.py")
if isfile(custom_lexer)
lexer = `$custom_lexer -x`
end

pipe = open(`$highlighter -f terminal -P bg=dark -l $lexer`, "r+")
print(pipe, code)
close(pipe.in)
print(io, read(pipe, String))
print(io, code)
end
return
end

const ptx_instructions = [
"abs", "cvt", "min", "shfl", "vadd", "activemask", "cvta", "mma", "shl", "vadd2",
"add", "discard", "mov", "shr", "vadd4", "addc", "div", "mul", "sin", "vavrg2",
"alloca", "dp2a", "mul24", "slct", "vavrg4", "and", "dp4a", "nanosleep", "sqrt",
"vmad", "applypriority", "ex2", "neg", "st", "vmax", "atom", "exit", "not",
"stackrestore", "vmax2", "bar", "fence", "or", "stacksave", "vmax4", "barrier",
"fma", "pmevent", "sub", "vmin", "bfe", "fns", "popc", "subc", "vmin2", "bfi",
"isspacep", "prefetch", "suld", "vmin4", "bfind", "istypep", "prefetchu", "suq",
"vote", "bmsk", "ld", "prmt", "sured", "vset", "bra", "ldmatrix", "rcp", "sust",
"vset2", "brev", "ldu", "red", "szext", "vset4", "brkpt", "lg2", "redux", "tanh",
"vshl", "brx", "lop3", "rem", "testp", "vshr", "call", "mad", "ret", "tex", "vsub",
"clz", "mad24", "rsqrt", "tld4", "vsub2", "cnot", "madc", "sad", "trap", "vsub4",
"copysign", "match", "selp", "txq", "wmma", "cos", "max", "set", "vabsdiff", "xor",
"cp", "mbarrier", "setp", "vabsdiff2", "createpolicy", "membar", "shf", "vabsdiff4"]

# simple regex-based highlighter
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
function highlight_ptx(io::IO, code)
function get_token(s)
# TODO: doesn't handle `ret;`, `{1`, etc; not properly tokenizing symbols
m = match(r"(\s*)([^\s]+)(.*)", s)
m !== nothing && (return m.captures[1:3])
return nothing, nothing, nothing
end
print_tok(token, type) = Base.printstyled(io,
token,
bold = hlscheme[type][1],
color = hlscheme[type][2])
buf = IOBuffer(code)
while !eof(buf)
line = readline(buf)
indent, tok, line = get_token(line)
istok(regex) = match(regex, tok) !== nothing
isinstr() = first(split(tok, '.')) in ptx_instructions
while (tok !== nothing)
print(io, indent)

# comments
if istok(r"^\/\/")
print_tok(tok, :comment)
print_tok(line, :comment)
break
# labels
elseif istok(r"^[\w]+:")
print_tok(tok, :label)
# instructions
elseif isinstr()
print_tok(tok, :instruction)
# directives
elseif istok(r"^\.[\w]+")
print_tok(tok, :type)
# guard predicates
elseif istok(r"^@!?%p.+")
print_tok(tok, :keyword)
# registers
elseif istok(r"^%[\w]+")
print_tok(tok, :variable)
# constants
elseif istok(r"^0[xX][A-F]+U?") || # hexadecimal
istok(r"^0[0-8]+U?") || # octal
istok(r"^0[bB][01]+U?") || # binary
istok(r"^[0-9]+U?") || # decimal
istok(r"^0[fF]{hexdigit}{8}") || # single-precision floating point
istok(r"^0[dD]{hexdigit}{16}") # double-precision floating point
print_tok(tok, :number)
# TODO: function names
# TODO: labels as RHS
else
print_tok(tok, :default)
end
indent, tok, line = get_token(line)
end
print(io, '\n')
end
end

#
# code_* replacements
Expand Down

0 comments on commit bf64c20

Please sign in to comment.