Skip to content

Commit

Permalink
Pragma to force function inlining (#462)
Browse files Browse the repository at this point in the history
  • Loading branch information
eldritchconundrum authored Sep 24, 2024
1 parent 3982ed1 commit cfeae09
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/ast.fs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Ident(name: string) =
member this.Rename(n) = newName <- n
member val ToBeInlined = newName.StartsWith("i_") with get, set
// This prefix disables function inlining and variable inlining.
member this.DoNotInline = this.OldName.StartsWith("noinline_")
member val DoNotInline = newName.StartsWith("noinline_") with get, set

member val Loc = {line = -1; col = -1} with get, set

Expand Down Expand Up @@ -197,7 +197,7 @@ and FunctionType = {

and TopLevel =
| TLVerbatim of string
| TLDirective of string list
| TLDirective of string list * Location
| Function of FunctionType * Stmt
| TLDecl of Decl
| TypeDecl of StructOrInterfaceBlock // struct declaration, or interface block that introduce a set of external global variables.
Expand Down
6 changes: 2 additions & 4 deletions src/inlining.fs
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,9 @@ type FunctionInlining(options: Options.Options) =

let tryMarkFunctionToInline (funcInfo: FuncInfo) (callSites: CallSite list) =
if not funcInfo.funcType.fName.DoNotInline && verifyVarsAndParams funcInfo callSites then
// Mark both the call site (so that simplifyExpr can remove it) and the function (to remember to remove it).
// We cannot simply rely on unused functions removal, because it might be disabled through its own flag.
// Mark only the function's ident, not the call sites.
// simplifyExpr can't rely on call sites to be marked anyway, to support the function inline pragma.
options.trace $"{funcInfo.funcType.fName.Loc}: inlining function '{Printer.debugFunc funcInfo.funcType}' into {callSites.Length} call sites"
for callSite in callSites do
callSite.ident.ToBeInlined <- true
funcInfo.funcType.fName.ToBeInlined <- true

let markInlinableFunctions code =
Expand Down
2 changes: 1 addition & 1 deletion src/parse.fs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ type private ParseImpl(options: Options.Options) =
let toplevel =
let decl = declaration .>> ch ';'
let item = choice [
macro |>> Ast.TLDirective
pipe2 macro getPosition (fun ss pos -> Ast.TLDirective (ss, {line = int pos.Line; col = int pos.Column}))
template |>> Ast.TLVerbatim
verbatim |>> Ast.TLVerbatim
attribute |>> Ast.TLVerbatim
Expand Down
4 changes: 2 additions & 2 deletions src/printer.fs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ type PrinterImpl(withLocations) =
// add a space at the end when it seems to be needed
let trailing = if s.Length > 0 && isIdentChar s.[s.Length - 1] then " " else ""
out "%s%s" s trailing
| TLDirective d -> directiveToS d
| TLDirective (d, _) -> directiveToS d
| Function (fct, Block []) -> out "%s%s{}" (funToS fct) (nl 0)
| Function (fct, (Block _ as body)) -> out "%s%s" (funToS fct) (stmtToS 0 body)
| Function (fct, body) -> out "%s%s{%s%s}" (funToS fct) (nl 0) (stmtToS 1 body) (nl 0)
Expand Down Expand Up @@ -322,7 +322,7 @@ type PrinterImpl(withLocations) =
| TypeDecl { name = Some n } -> n.OldName
| TypeDecl _ -> "*type decl*" // struct or unnamed interface block
| Precision _ -> "*precision*"
| TLDirective ("#define"::_) -> "#define"
| TLDirective (("#define"::_), _) -> "#define"
| TLDirective _ -> "*directive*"
| TLVerbatim _ -> "*verbatim*" // HLSL attribute, //[ skipped //]
symbolMap.AddMapping tlString symbolName
Expand Down
42 changes: 39 additions & 3 deletions src/rewriter.fs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ open Builtin
open Ast
open Inlining
open Analyzer
open System.Text.RegularExpressions

let private commaSeparatedExprs = List.reduce (fun a b -> FunCall(Op ",", [a; b]))

Expand Down Expand Up @@ -331,8 +332,13 @@ type private RewriterImpl(options: Options.Options, optimizationPass: Optimizati
else e
| _ -> e

/// Tells whether the identifier `fn` resolves to a user-defined function
/// whose declaration ident carries the ToBeInlined mark.
/// The mark is read from the declaration, not from `fn` itself.
let isFuncDeclarationToInline (fn: Ident) =
    match fn.Declaration with
    | Declaration.UserFunction declared ->
        let declaredName = declared.funcType.fName
        declaredName.ToBeInlined
    | _ ->
        // Any other kind of declaration is never a function to inline.
        false

let simplifyExpr (didInline: bool ref) env = function
| FunCall(Var v, passedArgs) as e when v.ToBeInlined ->
| FunCall(Var v, passedArgs) as e when isFuncDeclarationToInline v ->
match env.fns.TryFind (v.Name, passedArgs.Length) with
| Some ([{args = declArgs}, body]) ->
if List.length declArgs <> List.length passedArgs then
Expand Down Expand Up @@ -899,7 +905,7 @@ type private RewriterImpl(options: Options.Options, optimizationPass: Optimizati
let li = declsNotToInline li
if li = [] then None else TLDecl (rwType ty, li) |> Some
| TLVerbatim s -> TLVerbatim (stripSpaces s) |> Some
| TLDirective d -> TLDirective (stripDirectiveSpaces d) |> Some
| TLDirective (d, loc) -> TLDirective (stripDirectiveSpaces d, loc) |> Some
| Function (fct, _) when fct.fName.ToBeInlined -> None
| Function (fct, body) -> Function (rwFType fct, body) |> Some
| e -> e |> Some
Expand Down Expand Up @@ -964,8 +970,38 @@ let rec private iterateSimplifyAndInline (options: Options.Options) optimization
iterateSimplifyAndInline options optimizationPass (passCount + 1) li
else li

let simplify options li =
/// Consumes the `#pragma function inline` and `#pragma function noinline`
/// directives found in the top-level list `li`.
///
/// Each recognized pragma is removed from the returned list and applies to the
/// next `Function` top-level that follows it:
///   - `inline` sets `ToBeInlined` on the function's ident;
///   - `noinline` sets `DoNotInline` on the function's ident.
/// A pragma that is superseded by another pragma before any function is seen,
/// or that is still pending at the end of the list, is reported through
/// `options.trace` as ignored. All other top-levels are returned unchanged.
let processPragmas (options: Options.Options) li =
    // Pending pragma waiting for the next function: (forceInline, directive text, location).
    let mutable forceInlineNextFunction = None
    let warnIgnoredPragma (_, ss, loc) = options.trace $"{loc}: ignored pragma {ss}"

    // The trailing `\b` makes sure the keyword is a whole word, so that an
    // unrelated directive such as `#pragma function inline_hint` is not
    // mistaken for the pragma and silently dropped from the output.
    // A verbatim string is required: in a regular string "\b" is a backspace.
    let isFunctionPragma (keyword: string) (s: string) =
        Regex.IsMatch(s, @"#pragma +function +" + keyword + @"\b", RegexOptions.IgnoreCase)

    // Records a new pending pragma, warning about the one it replaces (if any).
    let registerPragma (forceInline: bool) (s: string) loc =
        forceInlineNextFunction |> Option.iter warnIgnoredPragma
        forceInlineNextFunction <- Some (forceInline, [s], loc)

    let processPragma tl =
        match tl with
        | TLDirective ([s], loc) when isFunctionPragma "inline" s ->
            registerPragma true s loc
            None
        | TLDirective ([s], loc) when isFunctionPragma "noinline" s ->
            registerPragma false s loc
            None
        | Function (ft, _) as tl ->
            match forceInlineNextFunction with
            | Some (true, _, _) ->
                options.trace $"{ft.fName.Loc}: pragma forces inlining of '{Printer.debugFunc ft}'"
                ft.fName.ToBeInlined <- true
            | Some (false, _, _) ->
                options.trace $"{ft.fName.Loc}: pragma prevents inlining of '{Printer.debugFunc ft}'"
                ft.fName.DoNotInline <- true
            | None -> ()
            // The pragma only applies to the first function that follows it.
            forceInlineNextFunction <- None
            Some tl
        | tl -> Some tl

    let li = li |> List.choose processPragma
    // A pragma with no function after it has nothing to apply to.
    forceInlineNextFunction |> Option.iter warnIgnoredPragma
    li

let simplify (options: Options.Options) li =
li
|> processPragmas options
|> iterateSimplifyAndInline options OptimizationPass.First 1
|> iterateSimplifyAndInline options OptimizationPass.Second 1
|> RewriterImpl(options, OptimizationPass.First).Cleanup
6 changes: 5 additions & 1 deletion tests/unit/inline-fn.aggro.expected
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ float F3_PRESERVED(inout float f)
{
return 7.;
}
float PRAGMA_PRESERVED()
{
return 9.;
}
float setup()
{
return shadowedVar++;
Expand All @@ -108,7 +112,7 @@ float f()
setup();
shadowedVar++;
shadowedFunc++;
return shadowedVar+shadowedFunc+four+five+six+ten+A1_PRESERVED()+1.+_A3+_A4+(B1_PRESERVED(3.)+B1_PRESERVED(4.))+4.+C1_PRESERVED()+(3.+sin(0.))+_D1+_D2+_D3+_E1+_E2+_E3+_E4+_E5+7.+_F2+o;
return shadowedVar+shadowedFunc+four+five+six+ten+A1_PRESERVED()+1.+_A3+_A4+(B1_PRESERVED(3.)+B1_PRESERVED(4.))+4.+C1_PRESERVED()+(3.+sin(0.))+_D1+_D2+_D3+_E1+_E2+_E3+_E4+_E5+7.+_F2+o+(vec3(9).x+vec3(8).x)+PRAGMA_PRESERVED();
}
float g()
{
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/inline-fn.expected
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ float F3_PRESERVED(inout float f)
{
return 7.;
}
float PRAGMA_PRESERVED()
{
return 9.;
}
float setup()
{
return shadowedVar++;
Expand All @@ -108,7 +112,7 @@ float f()
setup();
shadowedVar++;
shadowedFunc++;
return shadowedVar+shadowedFunc+four+five+six+ten+A1_PRESERVED()+1.+_A3+_A4+(B1_PRESERVED(3.)+B1_PRESERVED(4.))+4.+C1_PRESERVED()+(3.+sin(0.))+_D1+_D2+_D3+_E1+_E2+_E3+_E4+_E5+7.+_F2+o;
return shadowedVar+shadowedFunc+four+five+six+ten+A1_PRESERVED()+1.+_A3+_A4+(B1_PRESERVED(3.)+B1_PRESERVED(4.))+4.+C1_PRESERVED()+(3.+sin(0.))+_D1+_D2+_D3+_E1+_E2+_E3+_E4+_E5+7.+_F2+o+(vec3(9).x+vec3(8).x)+PRAGMA_PRESERVED();
}
float g()
{
Expand Down
10 changes: 9 additions & 1 deletion tests/unit/inline-fn.frag
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ float F1_INLINED(in float f) { return 7.0; }
float F2_PRESERVED(out float ff) { return 7.0; }
float F3_PRESERVED(inout float f) { return 7.0; }

#pragma function inline
float PRAGMA_INLINED(float x) { return vec3(x).x; }
#pragma function noinline
float PRAGMA_PRESERVED(float x) { return x; }

float setup() {
return shadowedVar++; // prevent inlining of the global shadowedVar
}
Expand Down Expand Up @@ -122,6 +127,8 @@ float f() {
float _F3 = F3_PRESERVED(o); // not inlined

sep++;
float _P1 = PRAGMA_INLINED(9.0)+PRAGMA_INLINED(8.0);
float _P2 = PRAGMA_PRESERVED(9.0);

setup();
shadowedVar++;
Expand All @@ -133,7 +140,8 @@ float f() {
_C1+_C2+
_D1+_D2+_D3+
_E1+_E2+_E3+_E4+_E5+
_F1+_F2+_F3;
_F1+_F2+_F3+
_P1+_P2;
}


Expand Down

0 comments on commit cfeae09

Please sign in to comment.