diff --git a/src/ast.fs b/src/ast.fs index ccc500a..9b971af 100644 --- a/src/ast.fs +++ b/src/ast.fs @@ -23,7 +23,7 @@ type Ident(name: string) = member this.Rename(n) = newName <- n member val ToBeInlined = newName.StartsWith("i_") with get, set // This prefix disables function inlining and variable inlining. - member this.DoNotInline = this.OldName.StartsWith("noinline_") + member val DoNotInline = newName.StartsWith("noinline_") with get, set member val Loc = {line = -1; col = -1} with get, set @@ -197,7 +197,7 @@ and FunctionType = { and TopLevel = | TLVerbatim of string - | TLDirective of string list + | TLDirective of string list * Location | Function of FunctionType * Stmt | TLDecl of Decl | TypeDecl of StructOrInterfaceBlock // struct declaration, or interface block that introduce a set of external global variables. diff --git a/src/inlining.fs b/src/inlining.fs index ff903a3..a9c44cf 100644 --- a/src/inlining.fs +++ b/src/inlining.fs @@ -248,11 +248,9 @@ type FunctionInlining(options: Options.Options) = let tryMarkFunctionToInline (funcInfo: FuncInfo) (callSites: CallSite list) = if not funcInfo.funcType.fName.DoNotInline && verifyVarsAndParams funcInfo callSites then - // Mark both the call site (so that simplifyExpr can remove it) and the function (to remember to remove it). - // We cannot simply rely on unused functions removal, because it might be disabled through its own flag. + // Mark only the function's ident, not the call sites. + // simplifyExpr can't rely on call sites to be marked anyway, to support the function inline pragma. options.trace $"{funcInfo.funcType.fName.Loc}: inlining function '{Printer.debugFunc funcInfo.funcType}' into {callSites.Length} call sites" - for callSite in callSites do - callSite.ident.ToBeInlined <- true funcInfo.funcType.fName.ToBeInlined <- true let markInlinableFunctions code = diff --git a/src/parse.fs b/src/parse.fs index 4bdb992..f30e654 100644 --- a/src/parse.fs +++ b/src/parse.fs @@ -371,7 +371,7 @@ type private ParseImpl(options: Options.Options) = let toplevel = let decl = declaration .>> ch ';' let item = choice [ - macro |>> Ast.TLDirective + pipe2 macro getPosition (fun ss pos -> Ast.TLDirective (ss, {line = int pos.Line; col = int pos.Column})) template |>> Ast.TLVerbatim verbatim |>> Ast.TLVerbatim attribute |>> Ast.TLVerbatim diff --git a/src/printer.fs b/src/printer.fs index 8f7603a..7cecaa0 100644 --- a/src/printer.fs +++ b/src/printer.fs @@ -287,7 +287,7 @@ type PrinterImpl(withLocations) = // add a space at the end when it seems to be needed let trailing = if s.Length > 0 && isIdentChar s.[s.Length - 1] then " " else "" out "%s%s" s trailing - | TLDirective d -> directiveToS d + | TLDirective (d, _) -> directiveToS d | Function (fct, Block []) -> out "%s%s{}" (funToS fct) (nl 0) | Function (fct, (Block _ as body)) -> out "%s%s" (funToS fct) (stmtToS 0 body) | Function (fct, body) -> out "%s%s{%s%s}" (funToS fct) (nl 0) (stmtToS 1 body) (nl 0) @@ -322,7 +322,7 @@ type PrinterImpl(withLocations) = | TypeDecl { name = Some n } -> n.OldName | TypeDecl _ -> "*type decl*" // struct or unnamed interface block | Precision _ -> "*precision*" - | TLDirective ("#define"::_) -> "#define" + | TLDirective (("#define"::_), _) -> "#define" | TLDirective _ -> "*directive*" | TLVerbatim _ -> "*verbatim*" // HLSL attribute, //[ skipped //] symbolMap.AddMapping tlString symbolName diff --git a/src/rewriter.fs b/src/rewriter.fs index dfd7f2e..14256f5 100644 --- a/src/rewriter.fs +++ b/src/rewriter.fs @@ -6,6 +6,7 @@ open Builtin open Ast open Inlining open Analyzer +open System.Text.RegularExpressions let private commaSeparatedExprs = List.reduce (fun a b -> FunCall(Op ",", [a; b])) @@ -331,8 +332,13 @@ type private RewriterImpl(options: Options.Options, optimizationPass: Optimizati else e | _ -> e + let isFuncDeclarationToInline (fn: Ident) = + match fn.Declaration with + | Declaration.UserFunction uf -> uf.funcType.fName.ToBeInlined + | _ -> false + let simplifyExpr (didInline: bool ref) env = function - | FunCall(Var v, passedArgs) as e when v.ToBeInlined -> + | FunCall(Var v, passedArgs) as e when isFuncDeclarationToInline v -> match env.fns.TryFind (v.Name, passedArgs.Length) with | Some ([{args = declArgs}, body]) -> if List.length declArgs <> List.length passedArgs then @@ -899,7 +905,7 @@ type private RewriterImpl(options: Options.Options, optimizationPass: Optimizati let li = declsNotToInline li if li = [] then None else TLDecl (rwType ty, li) |> Some | TLVerbatim s -> TLVerbatim (stripSpaces s) |> Some - | TLDirective d -> TLDirective (stripDirectiveSpaces d) |> Some + | TLDirective (d, loc) -> TLDirective (stripDirectiveSpaces d, loc) |> Some | Function (fct, _) when fct.fName.ToBeInlined -> None | Function (fct, body) -> Function (rwFType fct, body) |> Some | e -> e |> Some @@ -964,8 +970,38 @@ let rec private iterateSimplifyAndInline (options: Options.Options) optimization iterateSimplifyAndInline options optimizationPass (passCount + 1) li else li -let simplify options li = +let processPragmas (options: Options.Options) li = + let mutable forceInlineNextFunction = None + let warnIgnoredPragma (_, ss, loc) = options.trace $"{loc}: ignored pragma {ss}" + let processPragma tl = + match tl with + | TLDirective ([s], loc) when Regex.IsMatch(s, "#pragma +function +inline", RegexOptions.IgnoreCase) -> + forceInlineNextFunction |> Option.iter warnIgnoredPragma + forceInlineNextFunction <- Some (true, [s], loc) + None + | TLDirective ([s], loc) when Regex.IsMatch(s, "#pragma +function +noinline", RegexOptions.IgnoreCase) -> + forceInlineNextFunction |> Option.iter warnIgnoredPragma + forceInlineNextFunction <- Some (false, [s], loc) + None + | Function (ft, _) as tl -> + match forceInlineNextFunction with + | Some (true, _, _) -> + options.trace $"{ft.fName.Loc}: pragma forces inlining of '{Printer.debugFunc ft}'" + ft.fName.ToBeInlined <- true + | Some (false, _, _) -> + options.trace $"{ft.fName.Loc}: pragma prevents inlining of '{Printer.debugFunc ft}'" + ft.fName.DoNotInline <- true + | None -> () + forceInlineNextFunction <- None + Some tl + | tl -> Some tl + let li = li |> List.choose processPragma + forceInlineNextFunction |> Option.iter warnIgnoredPragma + li + +let simplify (options: Options.Options) li = li + |> processPragmas options |> iterateSimplifyAndInline options OptimizationPass.First 1 |> iterateSimplifyAndInline options OptimizationPass.Second 1 |> RewriterImpl(options, OptimizationPass.First).Cleanup diff --git a/tests/unit/inline-fn.aggro.expected b/tests/unit/inline-fn.aggro.expected index afa1363..ef7141f 100644 --- a/tests/unit/inline-fn.aggro.expected +++ b/tests/unit/inline-fn.aggro.expected @@ -86,6 +86,10 @@ float F3_PRESERVED(inout float f) { return 7.; } +float PRAGMA_PRESERVED() +{ + return 9.; +} float setup() { return shadowedVar++; @@ -108,7 +112,7 @@ float f() setup(); shadowedVar++; shadowedFunc++; - return shadowedVar+shadowedFunc+four+five+six+ten+A1_PRESERVED()+1.+_A3+_A4+(B1_PRESERVED(3.)+B1_PRESERVED(4.))+4.+C1_PRESERVED()+(3.+sin(0.))+_D1+_D2+_D3+_E1+_E2+_E3+_E4+_E5+7.+_F2+o; + return shadowedVar+shadowedFunc+four+five+six+ten+A1_PRESERVED()+1.+_A3+_A4+(B1_PRESERVED(3.)+B1_PRESERVED(4.))+4.+C1_PRESERVED()+(3.+sin(0.))+_D1+_D2+_D3+_E1+_E2+_E3+_E4+_E5+7.+_F2+o+(vec3(9).x+vec3(8).x)+PRAGMA_PRESERVED(); } float g() { diff --git a/tests/unit/inline-fn.expected b/tests/unit/inline-fn.expected index afa1363..ef7141f 100644 --- a/tests/unit/inline-fn.expected +++ b/tests/unit/inline-fn.expected @@ -86,6 +86,10 @@ float F3_PRESERVED(inout float f) { return 7.; } +float PRAGMA_PRESERVED() +{ + return 9.; +} float setup() { return shadowedVar++; @@ -108,7 +112,7 @@ float f() setup(); shadowedVar++; shadowedFunc++; - return shadowedVar+shadowedFunc+four+five+six+ten+A1_PRESERVED()+1.+_A3+_A4+(B1_PRESERVED(3.)+B1_PRESERVED(4.))+4.+C1_PRESERVED()+(3.+sin(0.))+_D1+_D2+_D3+_E1+_E2+_E3+_E4+_E5+7.+_F2+o; + return shadowedVar+shadowedFunc+four+five+six+ten+A1_PRESERVED()+1.+_A3+_A4+(B1_PRESERVED(3.)+B1_PRESERVED(4.))+4.+C1_PRESERVED()+(3.+sin(0.))+_D1+_D2+_D3+_E1+_E2+_E3+_E4+_E5+7.+_F2+o+(vec3(9).x+vec3(8).x)+PRAGMA_PRESERVED(); } float g() { diff --git a/tests/unit/inline-fn.frag b/tests/unit/inline-fn.frag index 349ee1e..13978cd 100644 --- a/tests/unit/inline-fn.frag +++ b/tests/unit/inline-fn.frag @@ -74,6 +74,11 @@ float F1_INLINED(in float f) { return 7.0; } float F2_PRESERVED(out float ff) { return 7.0; } float F3_PRESERVED(inout float f) { return 7.0; } +#pragma function inline +float PRAGMA_INLINED(float x) { return vec3(x).x; } +#pragma function noinline +float PRAGMA_PRESERVED(float x) { return x; } + float setup() { return shadowedVar++; // prevent inlining of the global shadowedVar } @@ -122,6 +127,8 @@ float f() { float _F3 = F3_PRESERVED(o); // not inlined sep++; + float _P1 = PRAGMA_INLINED(9.0)+PRAGMA_INLINED(8.0); + float _P2 = PRAGMA_PRESERVED(9.0); setup(); shadowedVar++; @@ -133,7 +140,8 @@ float f() { _C1+_C2+ _D1+_D2+_D3+ _E1+_E2+_E3+_E4+_E5+ - _F1+_F2+_F3; + _F1+_F2+_F3+ + _P1+_P2; }